// This file contains the code for accessing MySQL.
#if defined(_MYSQL_)
#include "HLDataBase.h"
#include "ServerLog.h"
#include "ServerConf.h"
#include <stdarg.h>
#include <string.h>
#include <sys/param.h>
#include "FileUtils.h"
#include "HLProtocol.h"

// Starts out with no MySQL handle; nothing is opened until Connect() runs.
HLDataBase::HLDataBase()
	: mDB(NULL)
{
}

// Releases the MySQL client handle (allocated by mysql_init() in Connect()),
// which previously leaked together with its connection.
// NOTE(review): assumes HLDataBase instances are not copied (a copy would
// share mDB and close it twice) — confirm against the class declaration.
HLDataBase::~HLDataBase()
{
	if (mDB != NULL)
	{
		mysql_close(mDB);
		mDB = NULL;
	}
}

void HLDataBase::Connect(const string &inHost, const string &inUser,
			const string &inPassword, const string &inDataBase)
{
	MYSQL *db;
	
	DEBUG_CALL(printf("mysql_thread_safe(): %u\n", mysql_thread_safe()));

	mDB = mysql_init(mDB);
	if (!mDB)
	{
        ServerLog::ErrorLog(__FILE__, __LINE__, "mysql_init() failed: %s", mysql_error(mDB));
		return;
	}
	
	db = mysql_real_connect(mDB, inHost.c_str(), inUser.c_str(),
		inPassword.c_str(), inDataBase.c_str(), 0, 0, 0);
	if (!db)
	{
        ServerLog::ErrorLog(__FILE__, __LINE__, "mysql_real_connect() failed: %s\n", mysql_error(mDB));
		mDB = 0;
		return;
	}
	
	mDB = db;
}

// Clears the userlist table at startup, unless there is no connection or the
// database is configured read-only.
void HLDataBase::Init()
{
	const bool canWrite = (mDB != NULL) && !gServer->Config().databaseReadOnly;
	if (canWrite)
		Query("DELETE FROM userlist");
}

bool HLDataBase::Query(const char *inFormat, ...)
{
	if (mDB == NULL)
		return false;
	
	va_list ap;
	char buf[16384];

	va_start(ap, inFormat);
	vsnprintf(buf, sizeof(buf), inFormat, ap);
	va_end(ap);
	if (mysql_query(mDB, buf) != 0)
	{
        ServerLog::ErrorLog(__FILE__, __LINE__, "mysql_query() failed: %s query: %s", mysql_error(mDB), buf);
		return false;
	}
	return true;
}

// Records a newly connected user: inserts a userlist row and, if that
// succeeds, a connection row whose auto-increment id is stored on the user
// via SetConnectionID(). Skipped when there is no connection or the database
// is read-only.
void HLDataBase::AddUser(HLUser &inUser, const string &inAddress)
{
	if (mDB == NULL || gServer->Config().databaseReadOnly)
		return;
	
	// Escape into buffers sized from the real input (escaping may double the
	// length, plus a terminating NUL) rather than a compile-time maximum.
	string login(inUser.Login().length() * 2 + 1, '\0');
	login.resize(mysql_real_escape_string(mDB, &login[0],
		inUser.Login().c_str(), inUser.Login().length()));
	string name(inUser.Name().length() * 2 + 1, '\0');
	name.resize(mysql_real_escape_string(mDB, &name[0],
		inUser.Name().c_str(), inUser.Name().length()));
	string address(inAddress.length() * 2 + 1, '\0');
	address.resize(mysql_real_escape_string(mDB, &address[0],
		inAddress.c_str(), inAddress.length()));
	
	// time_t may be wider than unsigned int; convert explicitly for %u.
	unsigned int now = (unsigned int)time(0);
	
	if (Query("INSERT INTO userlist VALUES(%u, '%s', '%s', '%s', %u, %u, %u)", inUser.ID(),
		address.c_str(), login.c_str(), name.c_str(), inUser.Icon(), inUser.Status(), now))
	{
		if (Query("INSERT INTO connection VALUES(NULL, %u, %u, '%s', '%s', '%s', %u)",
			now, 0, address.c_str(), login.c_str(), name.c_str(), inUser.Icon()))
		{
			inUser.SetConnectionID((u_int32_t)mysql_insert_id(mDB));
		}
	}
}

// Pushes a user's current name, icon and status to the userlist row, and if
// that succeeds, the name and icon to the user's connection row.
// Skipped when there is no connection or the database is read-only.
void HLDataBase::UpdateUser(HLUser &inUser)
{
	if (mDB == NULL || gServer->Config().databaseReadOnly)
		return;
	
	// escape buffer sized from the actual name, not USERNAME_MAXLEN
	string name(inUser.Name().length() * 2 + 1, '\0');
	name.resize(mysql_real_escape_string(mDB, &name[0],
		inUser.Name().c_str(), inUser.Name().length()));
	
	if (Query("UPDATE userlist SET name='%s', icon=%u, status=%d WHERE id=%u",
		name.c_str(), inUser.Icon(), inUser.Status(), inUser.ID()))
	{
		Query("UPDATE connection SET name='%s', icon=%u WHERE id=%u",
			name.c_str(), inUser.Icon(), inUser.ConnectionID());
	}
}

// Removes a user from userlist and, if that succeeds, stamps the disconnect
// time on the user's connection row. Skipped when there is no connection or
// the database is read-only.
void HLDataBase::RemoveUser(HLUser &inUser)
{
	if (mDB == NULL || gServer->Config().databaseReadOnly)
		return;
	
	if (Query("DELETE FROM userlist WHERE id=%u", inUser.ID()))
	{
		// time_t may be wider than unsigned int; convert explicitly for %u.
		Query("UPDATE connection SET disconnect=%u WHERE id=%u",
			(unsigned int)time(0), inUser.ConnectionID());
	}
}

void HLDataBase::AddHostBan(const string &inHost, const HLUser &inUser,
    const string &inReason, const u_int32_t inTimeout)
{
	if (mDB == NULL)
		return;

	if (!gServer->Config().databaseReadOnly)
	{
		time_t currentTime = time(0);
		char reason[512];
        
		mysql_real_escape_string(mDB, reason, inReason.c_str(),
                inReason.length() > 255 ? 255 : inReason.length());
        
		Query("INSERT INTO banlist_ip VALUES('%s', %u, %u, %u, '%s')", inHost.c_str(), currentTime,
            inTimeout == 0 ? 0 : (currentTime + inTimeout), inUser.ConnectionID(), reason);
	}
}

// Appends a news row (name, login, post text, current timestamp) for inUser.
// Skipped when there is no connection or the database is read-only.
void HLDataBase::AddNewsPost(HLUser &inUser, const string &inPost)
{
	if (mDB == NULL || gServer->Config().databaseReadOnly)
		return;
	
	// Escape buffers are sized from each input (escaping may double the
	// length, plus a NUL); std::string also frees the post buffer automatically.
	string login(inUser.Login().length() * 2 + 1, '\0');
	login.resize(mysql_real_escape_string(mDB, &login[0],
		inUser.Login().c_str(), inUser.Login().length()));
	string name(inUser.Name().length() * 2 + 1, '\0');
	name.resize(mysql_real_escape_string(mDB, &name[0],
		inUser.Name().c_str(), inUser.Name().length()));
	string post(inPost.length() * 2 + 1, '\0');
	post.resize(mysql_real_escape_string(mDB, &post[0], inPost.c_str(), inPost.length()));
	
	// time_t may be wider than unsigned int; convert explicitly for %u.
	Query("INSERT INTO news VALUES(NULL, '%s', '%s', '%s', %u)",
		name.c_str(), login.c_str(), post.c_str(), (unsigned int)time(0));
}

// Builds the news text from up to 50 of the newest news rows with a non-zero
// timestamp, each formatted by HLServer::FormatNewsPost(). Stops before the
// accumulated text would exceed 0xFFFF bytes. outNews is cleared first, unless
// there is no connection, in which case it is left untouched.
void HLDataBase::GetNewsFile(string &outNews)
{
	if (mDB == NULL)
		return;
	
	outNews.erase();
	if (!Query("select name, timestamp, post from news where timestamp!=0 order by timestamp desc limit 50"))
		return;
	
	MYSQL_RES *rows = mysql_store_result(mDB);
	if (rows == NULL)
		return;
	
	for (MYSQL_ROW row = mysql_fetch_row(rows); row; row = mysql_fetch_row(rows))
	{
		unsigned long *len = mysql_fetch_lengths(rows);
		string poster(row[0], len[0]);
		string body(row[2], len[2]);
		time_t when = atol(string(row[1], len[1]).c_str());
		string formatted;
		
		HLServer::FormatNewsPost(poster, when, body, formatted);
		
		if (outNews.length() + formatted.length() > 0xFFFF)
			break;
		outNews += formatted;
	}
	mysql_free_result(rows);
}

// Loads the agreement text, stored as the news row whose timestamp is 0, and
// normalises its line endings with FileUtils::fixLineEndings(). outAgreement
// is cleared first, unless there is no connection.
void HLDataBase::GetAgreement(string &outAgreement)
{
	if (mDB == NULL)
		return;
	
	outAgreement.erase();
	if (!Query("select post from news where timestamp=0 limit 1"))
		return;
	
	MYSQL_RES *rows = mysql_store_result(mDB);
	if (rows == NULL)
		return;
	
	MYSQL_ROW row = mysql_fetch_row(rows);
	if (row != NULL)
	{
		unsigned long *len = mysql_fetch_lengths(rows);
		outAgreement.assign(row[0], len[0]);
		FileUtils::fixLineEndings(outAgreement);
	}
	mysql_free_result(rows);
}

bool HLDataBase::IsHostBanned(const string &inHost)
{
	if (mDB == NULL)
		return false;
	
	bool isBanned = false;
	if (Query("select expire from banlist_ip where '%s' like concat(ip, '%%')", inHost.c_str()))
	{
		MYSQL_RES *result = mysql_store_result(mDB);
		if (result)
		{
            time_t currentTime = time(0);
			MYSQL_ROW row;
            unsigned int banExpire;
			while ((row = mysql_fetch_row(result)) && !isBanned)
            {
                banExpire = strtol(row[0], 0, 0);
                DEBUG_CALL(printf("banExpire: %u\n", banExpire); fflush(stdout));
                if (banExpire == 0 || banExpire >= (unsigned int)currentTime)
                    isBanned = true;
                else if (banExpire < (unsigned int)currentTime)
                {
                    DEBUG_CALL(printf("ban has expired for: %s\n", inHost.c_str()); fflush(stdout));
                    // i think this works, but i haven't tested it enough
                    Query("DELETE from banlist_ip WHERE (expire != 0 and expire < %u)", currentTime);
                }
			}
            mysql_free_result(result);
		}
	}
	return isBanned;
}

bool HLDataBase::IsNameBanned(const string &inName)
{
	if (mDB == NULL)
		return false;
	
	char name[(USERNAME_MAXLEN + 1) * 2];
	
	mysql_real_escape_string(mDB, name, inName.c_str(), inName.length());
	
	bool isBanned = false;
	if (Query("select 1 from banlist_name where upper(name)=upper('%s') LIMIT 1", name))
	{
		MYSQL_RES *result = mysql_store_result(mDB);
		if (result)
		{
			MYSQL_ROW row;
			if ((row = mysql_fetch_row(result)))
            	isBanned = true;
			mysql_free_result(result);
		}
	}
	return isBanned;
}

// Inserts a transfers row describing a starting transfer and stores the new
// row id on inTransfer via SetTransferID(). On insert failure the id is set to
// 0, so EndTransfer() (which skips id 0) never updates an unrelated row.
// Transfer type codes written here:
//   0 = download/preview, 1 = resumed download/preview,
//   2 = upload, 3 = resumed upload.
// Skipped when there is no connection or the database is read-only.
void HLDataBase::AddTransfer(HLTransfer &inTransfer)
{
	if (mDB == NULL || gServer->Config().databaseReadOnly)
		return;
	
	u_int8_t transferType = 0;
	u_int32_t transferSize = 0;
	if (inTransfer.Info().type == kDownloadTransfer || inTransfer.Info().type == kPreviewTransfer)
	{
		// a non-zero remote fork size means the client is resuming
		transferType = inTransfer.Info().remoteDataForkSize ? 1 : 0;
		transferSize =
			(inTransfer.Info().localDataForkSize - inTransfer.Info().remoteDataForkSize);
	}
	else if (inTransfer.Info().type == kUploadTransfer)
	{
		// a non-zero local fork size means the upload is being resumed
		transferType = inTransfer.Info().localDataForkSize ? 3 : 2;
		transferSize = inTransfer.Info().transferSize;
	}
	
	string tempPath = inTransfer.Info().filePath;
	if (transferType == 2 || transferType == 3)
	{
		// uploads carry a .hpf extension on disk; log the final file name
		FileUtils::removeExtension(tempPath);
	}
	
	// escape buffer sized from the real path rather than MAXPATHLEN
	string filePath(tempPath.length() * 2 + 1, '\0');
	filePath.resize(mysql_real_escape_string(mDB, &filePath[0], tempPath.c_str(), tempPath.length()));
	
	// time_t may be wider than unsigned int; convert explicitly for %u.
	unsigned int now = (unsigned int)time(0);
	if (Query("INSERT INTO transfers VALUES(NULL, %u, '%s', %u, %u, %u, %u, %u)",
		transferType, filePath.c_str(), now, 0, transferSize,
		0, inTransfer.Info().user.ConnectionID()))
	{
		inTransfer.SetTransferID((u_int32_t)mysql_insert_id(mDB));
	}
	else
	{
		// mysql_insert_id() is not meaningful after a failed INSERT
		inTransfer.SetTransferID(0);
	}
}

// Records the byte count and end time on the transfer's transfers row.
// Transfers with no recorded row (TransferID() == 0) are skipped, as is
// everything when there is no connection or the database is read-only.
void HLDataBase::EndTransfer(HLTransfer &inTransfer)
{
	if (mDB == NULL)
		return;
	
	if (!gServer->Config().databaseReadOnly && inTransfer.TransferID())
	{
		// time_t may be wider than unsigned int; convert explicitly for %u.
		Query("UPDATE transfers SET transfered_bytes=%u, end_time=%u WHERE id=%u",
			inTransfer.TransferedBytes(), (unsigned int)time(0), inTransfer.TransferID());
	}
}

// Inserts a new account row: login, MD5-hashed password (hashed by MySQL),
// name, NULL email, creation time and permission bits.
// Skipped when there is no connection or the database is read-only.
void HLDataBase::CreateAccount(HLAccount &inAccount)
{
	if (mDB == NULL || gServer->Config().databaseReadOnly)
		return;
	
	// Escape buffers sized from each input; the password previously used a
	// USERNAME_MAXLEN-sized buffer, which overflows for longer passwords.
	string login(inAccount.Login().length() * 2 + 1, '\0');
	login.resize(mysql_real_escape_string(mDB, &login[0],
		inAccount.Login().c_str(), inAccount.Login().length()));
	string name(inAccount.Name().length() * 2 + 1, '\0');
	name.resize(mysql_real_escape_string(mDB, &name[0],
		inAccount.Name().c_str(), inAccount.Name().length()));
	string password(inAccount.Password().length() * 2 + 1, '\0');
	password.resize(mysql_real_escape_string(mDB, &password[0],
		inAccount.Password().c_str(), inAccount.Password().length()));
	
	// Encode the access bits the same way ModifyAccount() does (ntohll of the
	// raw bytes), because FetchAccount() decodes the column with htonll().
	// Storing the raw bytes made new accounts read back byte-swapped on
	// little-endian hosts. memcpy avoids the strict-aliasing cast.
	u_int64_t rawBits;
	memcpy(&rawBits, &inAccount.Access(), sizeof(rawBits));
	unsigned long long permBits = ntohll(rawBits);
	DEBUG_CALL(printf("create acct access bits: 0x%16llX\n", permBits); fflush(stdout));

	// id, login, password, name, email, create_date, permission_bits
	Query("INSERT INTO account VALUES(NULL, '%s', MD5('%s'), '%s', NULL, %u, %llu)",
		login.c_str(), password.c_str(), name.c_str(), (unsigned int)time(0), permBits);
}

// Looks up ioAccount.Login() in the account table and fills in ioAccount's
// name and access bits (and files path from the server config).
// When inVerifyPassword is true the password must also match: MD5() of the
// supplied password, or an empty column when the supplied password is empty.
// If the account row has no non-zero permission_bits, access and max bps are
// taken from the account's group (roster -> groop), but only when the account
// was actually found and has a roster entry.
// Returns true if a matching account row was found.
bool HLDataBase::FetchAccount(HLAccount &ioAccount, bool inVerifyPassword)
{
	if (mDB == NULL)
		return false;
	
	// unused so far
	ioAccount.SetFilesPath(gServer->Config().rootPath);
	
	// escape buffers are sized from each input (escaping may double it, plus NUL)
	string login(ioAccount.Login().length() * 2 + 1, '\0');
	login.resize(mysql_real_escape_string(mDB, &login[0],
		ioAccount.Login().c_str(), ioAccount.Login().length()));
	
	bool queryOk = false;
	if (!inVerifyPassword)
	{
		queryOk = Query("select id, name, permission_bits from account where login='%s'",
			login.c_str());
	}
	else if (ioAccount.Password().length() == 0)
	{
		queryOk = Query("select id, name, permission_bits from account where login='%s' and password=''",
			login.c_str());
	}
	else
	{
		string password(ioAccount.Password().length() * 2 + 1, '\0');
		password.resize(mysql_real_escape_string(mDB, &password[0],
			ioAccount.Password().c_str(), ioAccount.Password().length()));
		queryOk = Query("select id, name, permission_bits from account where login='%s' and password=MD5('%s')",
			login.c_str(), password.c_str());
	}
	if (!queryOk)
		return false;
	
	MYSQL_RES *result = mysql_store_result(mDB);
	if (result == NULL)
		return false;
	
	bool foundAccount = false;
	bool gotPermissions = false;
	u_int32_t accountID = 0;
	
	MYSQL_ROW row = mysql_fetch_row(result);
	if (row != NULL)
	{
		unsigned long *lengths = mysql_fetch_lengths(result);
		
		string idText;
		if (row[0] != NULL)
			idText.assign(row[0], lengths[0]);
		accountID = strtoul(idText.c_str(), 0, 10);
		
		string acctName;
		if (row[1] != NULL)
			acctName.assign(row[1], lengths[1]);
		ioAccount.SetName(acctName);
		
		if (row[2] != NULL && lengths[2] > 0)
		{
			string permissions(row[2], lengths[2]);
			u_int64_t permBits = htonll(strtoull(permissions.c_str(), 0, 10));
			if (permBits)
			{
				DEBUG_CALL(printf("account access bits: 0x%16llX\n", (unsigned long long)permBits); fflush(stdout));
				// memcpy instead of a pointer cast (strict aliasing)
				struct hl_access_bits realAccess;
				memcpy(&realAccess, &permBits, sizeof(realAccess));
				ioAccount.SetAccess(realAccess);
				gotPermissions = true;
			}
		}
		foundAccount = true;
	}
	mysql_free_result(result);
	
	// Fall back to group permissions only for an account that exists; before,
	// a missing account still looked up roster account_id=0.
	if (foundAccount && !gotPermissions &&
		Query("select groop_id from roster where account_id=%u", accountID))
	{
		result = mysql_store_result(mDB);
		if (result != NULL)
		{
			// Only one group per account is supported for now; when several
			// roster rows exist, the last one returned wins.
			bool foundGroup = false;
			u_int32_t groupID = 0;
			while ((row = mysql_fetch_row(result)) != NULL)
			{
				unsigned long *lengths = mysql_fetch_lengths(result);
				string groupText;
				if (row[0] != NULL)
					groupText.assign(row[0], lengths[0]);
				groupID = strtoul(groupText.c_str(), 0, 10);
				foundGroup = true;
			}
			mysql_free_result(result);
			
			// without a roster row there is no group to query (previously id=0)
			if (foundGroup &&
				Query("select permission_bits, max_bps from groop where id=%u", groupID))
			{
				result = mysql_store_result(mDB);
				if (result != NULL)
				{
					row = mysql_fetch_row(result);
					if (row != NULL)
					{
						unsigned long *lengths = mysql_fetch_lengths(result);
						
						string bitsText;
						if (row[0] != NULL)
							bitsText.assign(row[0], lengths[0]);
						u_int64_t bits = htonll(strtoull(bitsText.c_str(), 0, 10));
						DEBUG_CALL(printf("mysql access bits: 0x%16llX\n", (unsigned long long)bits); fflush(stdout));
						struct hl_access_bits realAccess;
						memcpy(&realAccess, &bits, sizeof(realAccess));
						ioAccount.SetAccess(realAccess);
						
						string bpsText;
						if (row[1] != NULL)
							bpsText.assign(row[1], lengths[1]);
						ioAccount.SetMaxBps(strtoul(bpsText.c_str(), 0, 10));
					}
					mysql_free_result(result);
				}
			}
		}
	}
	return foundAccount;
}

void HLDataBase::DeleteAccount(const string &inLogin)
{
    if (mDB == NULL)
		return;
	
	if (gServer->Config().databaseReadOnly)
		return;
	
	char login[(USERNAME_MAXLEN + 1) * 2];
    mysql_real_escape_string(mDB, login, inLogin.c_str(), inLogin.length());
    // eventually this should get the group id and if it is a custom
    // group for only this user, then it should delete that group also
    Query("delete from account where login='%s'", login);
}

// Updates the name and permission bits of the account identified by
// inAccount.Login(), plus its MD5-hashed password when inChangePassword is
// true. Skipped when there is no connection or the database is read-only.
void HLDataBase::ModifyAccount(HLAccount &inAccount, bool inChangePassword)
{
	if (mDB == NULL || gServer->Config().databaseReadOnly)
		return;
	
	// escape buffers sized from each input (escaping may double it, plus NUL)
	string login(inAccount.Login().length() * 2 + 1, '\0');
	login.resize(mysql_real_escape_string(mDB, &login[0],
		inAccount.Login().c_str(), inAccount.Login().length()));
	string name(inAccount.Name().length() * 2 + 1, '\0');
	name.resize(mysql_real_escape_string(mDB, &name[0],
		inAccount.Name().c_str(), inAccount.Name().length()));
	
	// ntohll of the raw access bytes; FetchAccount() decodes with htonll().
	// memcpy avoids the strict-aliasing pointer cast.
	u_int64_t rawBits;
	memcpy(&rawBits, &inAccount.Access(), sizeof(rawBits));
	unsigned long long permBits = ntohll(rawBits);
	DEBUG_CALL(printf("modify acct access bits: 0x%16llX\n", permBits); fflush(stdout));

	// id, login, password, name, email, create_date, permission_bits
	if (inChangePassword)
	{
		// previously a USERNAME_MAXLEN-sized buffer, too small for long passwords
		string password(inAccount.Password().length() * 2 + 1, '\0');
		password.resize(mysql_real_escape_string(mDB, &password[0],
			inAccount.Password().c_str(), inAccount.Password().length()));
		Query("UPDATE account SET password=MD5('%s'), name='%s', permission_bits=%llu WHERE login='%s' LIMIT 1",
			password.c_str(), name.c_str(), permBits, login.c_str());
	}
	else
	{
		Query("UPDATE account SET name='%s', permission_bits=%llu WHERE login='%s' LIMIT 1",
			name.c_str(), permBits, login.c_str());
	}
}

// Not implemented yet: outAccountList is left exactly as the caller passed it.
void HLDataBase::ListAccounts(AccountList &outAccountList)
{
	// TODO: need to implement this - jjt
	(void)outAccountList;
}

#endif // _MYSQL_

