#include "comserver.h"
#include <new>
#include <string.h>
#pragma comment(linker, "/EXPORT:DllRegisterServer=_DllRegisterServer@0,PRIVATE")
#pragma comment(linker, "/EXPORT:DllUnregisterServer=_DllUnregisterServer@0,PRIVATE")
#pragma comment(linker, "/EXPORT:DllCanUnloadNow=_DllCanUnloadNow@0,PRIVATE")
#pragma comment(linker, "/EXPORT:DllGetClassObject=_DllGetClassObject@12,PRIVATE")

//Standard COM server
//See "comserver.h" for usage

// CServer class holds global object count
class CServer
{
public:
  CServer() : m_hInstance(NULL), m_dwRef(0)
  {}
  HINSTANCE m_hInstance;
  DWORD m_dwRef;
};
CServer m_Server;

// Template class that provides basic IUnknown implementation.
//   T    - the COM interface being implemented (derives from IUnknown).
//   piid - address of T's IID; QueryInterface answers IID_IUnknown and *piid only.
// Every instance is counted in m_Server.m_dwRef (constructor/destructor) so that
// DllCanUnloadNow can tell whether any object is still alive. The per-object
// reference count starts at 0; the object deletes itself when Release brings it
// back to 0. Counters use plain ++/-- (not interlocked).
template <class T, const IID* piid>
class CInterface : public T
{
public:
  CInterface() : m_dwRef(0)
  { m_Server.m_dwRef++; }
  virtual ~CInterface()
  { m_Server.m_dwRef--; }

  // Returns this object as T* (taking one reference) for IID_IUnknown or *piid.
  // COM requires the out-pointer to be NULL on failure, so it is cleared on the
  // E_NOINTERFACE path; a NULL ppvObject is rejected with E_POINTER.
  STDMETHOD(QueryInterface)(REFIID riid, void** ppvObject)
  {
    if (!ppvObject)
      return E_POINTER;
    if ((riid == IID_IUnknown) || (riid == *piid))
    {
      *ppvObject = static_cast<T*>(this);
      m_dwRef++;
      return S_OK;
    }
    *ppvObject = NULL;
    return E_NOINTERFACE;
  }
  // Adds a reference and returns the new count.
  STDMETHOD_(ULONG,AddRef)()
  { return ++m_dwRef; }
  // Drops a reference; deletes the object when the count reaches 0.
  STDMETHOD_(ULONG,Release)()
  {
    if (!(--m_dwRef))
    {
      delete this;
      return 0;
    }
    return m_dwRef;
  }
  DWORD m_dwRef;
};

// The COM class.
// Exposes IDispatch only. Name lookup and invocation are forwarded to
// COM_GetIDsOfNames / COM_Invoke together with this object's COMDATA.
class ComClass : public CInterface<IDispatch,&IID_IDispatch>
{
// IDispatch interface implementation
public:
    // Type information is not provided; both methods return E_NOTIMPL.
    STDMETHOD(GetTypeInfoCount)(UINT* pctinfo)
    { return E_NOTIMPL;  }
    STDMETHOD(GetTypeInfo)(UINT, LCID, ITypeInfo**)
    { return E_NOTIMPL;  }

    // Maps member names to DISPIDs via COM_GetIDsOfNames.
    // The REFIID and LCID arguments are ignored.
    STDMETHOD(GetIDsOfNames)(REFIID, LPOLESTR* rgszNames,
               UINT cNames, LCID, DISPID* rgDispId)
    {
      return COM_GetIDsOfNames(rgszNames, cNames, rgDispId, data);
    }

    // Performs a method call / property access via COM_Invoke, passing every
    // argument through unchanged plus this object's COMDATA.
    STDMETHOD(Invoke)(DISPID dispIdMember, REFIID riid, LCID lcid, WORD wFlags,
            DISPPARAMS* pDispParams, VARIANT* pVarResult,
            EXCEPINFO* pExcepInfo, UINT *puArgErr)
    {
      return COM_Invoke(dispIdMember, riid, lcid, wFlags,
         pDispParams, pVarResult, pExcepInfo, puArgErr, data);
    }

    // Each object owns exactly one COMDATA. It is constructed with `this`
    // (presumably as a back-reference to the owning object -- confirm in comserver.h).
    ComClass(){
        data=new COMDATA(this);
    }
    ~ComClass(){
        delete data;
    }
    COMDATA* data;

private:
    // ComClass owns `data` through a raw pointer, so a compiler-generated copy
    // would delete it twice. Declared but never defined to forbid copying
    // (works with pre-C++11 compilers).
    ComClass(const ComClass&);
    ComClass& operator=(const ComClass&);
};

// IDispatch handling for getting data from object.
// Returns the COMDATA owned by pDisp, or NULL when pDisp is NULL.
// The downcast is unchecked: pDisp must really be a ComClass created by this
// server. Passing any other IDispatch (e.g. a proxy or a foreign object) is
// undefined behavior.
COMDATA* GetDataFromComClass(IDispatch* pDisp){
	if (!pDisp) return NULL;
	return static_cast<ComClass*>(pDisp)->data;
}

// Returns the full path of this DLL, as reported by GetModuleFileNameA for
// m_Server.m_hInstance. The path is fetched on the first call and cached in a
// static buffer; the same writable buffer is returned on every call.
LPSTR ModulePath(){
	static char szModulePath[_MAX_PATH];
	static LPSTR pModulePath=NULL;
	if (!pModulePath) {
		GetModuleFileNameA(m_Server.m_hInstance,szModulePath,_MAX_PATH);
		// On Windows XP a truncated result is not NUL-terminated; force a
		// terminator so callers never read past the end of the buffer.
		szModulePath[_MAX_PATH-1]='\0';
		pModulePath=szModulePath;
	}
	return pModulePath;
}
// Returns the directory containing this DLL, including the trailing backslash
// (e.g. "C:\dir\"), or "" when the module path contains no backslash.
// Uses its own static buffer, separate from ModulePath()'s. The trim runs on
// every call (not only the first), so anything written after the directory
// part of the returned buffer is cut off again by the next call.
LPSTR ModuleDir(){
	static char szModulePath[_MAX_PATH];
	static LPSTR pModulePath=NULL;
	if (!pModulePath) {
		GetModuleFileNameA(m_Server.m_hInstance,szModulePath,_MAX_PATH);
		// On Windows XP a truncated result is not NUL-terminated.
		szModulePath[_MAX_PATH-1]='\0';
		pModulePath=szModulePath;
	}
	// Keep everything up to and including the last backslash.
	LPSTR lastSep=strrchr(pModulePath,'\\');
	if (lastSep) lastSep[1]='\0';
	else pModulePath[0]='\0';
	return pModulePath;
}


// Class factory to create COM objects.
// Creates ComClass instances only; aggregation is not supported.
class CClassFactory :
  public CInterface<IClassFactory,&IID_IClassFactory>
{
public:
// IClassFactory interface implementation

  // Creates a new ComClass and returns the interface requested by riid.
  // *ppvObject is NULL on every failure path, as COM requires for out-pointers.
  STDMETHOD(CreateInstance)(IUnknown* pUnkOuter, REFIID riid,
                void** ppvObject)
  {
    if (!ppvObject)
      return E_POINTER;
    *ppvObject = NULL;
    if (pUnkOuter)
      return CLASS_E_NOAGGREGATION;
    // nothrow: an exception must not escape across the COM boundary. This only
    // covers allocating the ComClass itself; its constructor allocates COMDATA
    // with ordinary new.
    ComClass* pObject = new (std::nothrow) ComClass;
    if (!pObject)
      return E_OUTOFMEMORY;
    // The new object's reference count starts at 0; a successful QI takes the
    // caller's reference. On failure nobody holds a reference, so delete directly.
    HRESULT hr = pObject->QueryInterface(riid,ppvObject);
    if (FAILED(hr))
      delete pObject;
    return hr;
  }
  // Locks/unlocks the server by (un)locking this factory object through
  // CoLockObjectExternal. While the factory exists it is counted in
  // m_Server.m_dwRef, which keeps DllCanUnloadNow returning S_FALSE.
  STDMETHOD(LockServer)(BOOL fLock)
  { return CoLockObjectExternal(this,fLock,TRUE); }
};

// DLL entry point.
// On process attach: remember this module's handle in m_Server (later passed to
// GetModuleFileNameA) and switch off DLL_THREAD_ATTACH/DETACH notifications.
// No other notification needs work; every call reports success.
BOOL APIENTRY DllMain(HMODULE hInstance, DWORD ul_reason_for_call, LPVOID lpReserved)
{
	if (ul_reason_for_call == DLL_PROCESS_ATTACH)
	{
		m_Server.m_hInstance = hInstance;
		DisableThreadLibraryCalls(hInstance);
	}
	return TRUE;
}

// Required COM in-proc server exports
STDAPI DllUnregisterServer(void)
{
    HRESULT hr = S_OK;
    HKEY key = NULL;
    if (!RegCreateKeyExA(HKEY_CLASSES_ROOT,CLSIDKEY,0,NULL,REG_OPTION_NON_VOLATILE,
                        KEY_ALL_ACCESS,NULL,&key,NULL))
    {
        char szModulePath[_MAX_PATH];
        GetModuleFileNameA(m_Server.m_hInstance,szModulePath,_MAX_PATH);
        if(!RegSetValueExA(key,NULL,0,REG_SZ,(const unsigned char*)szModulePath,
                          strlen(szModulePath)+1))
        {
          if (!RegCreateKeyExA(HKEY_CLASSES_ROOT,PRODKEY,0,NULL,REG_OPTION_NON_VOLATILE,
                            KEY_ALL_ACCESS,NULL,&key,NULL))
          {
            if (RegDeleteKeyA(HKEY_CLASSES_ROOT,PRODIDKEY)) hr = E_FAIL;
            if (RegDeleteKeyA(HKEY_CLASSES_ROOT,PRODKEY)) hr = E_FAIL;
          }
          RegCloseKey(key);
        }
        if (RegDeleteKeyA(HKEY_CLASSES_ROOT,CLASSKEY)) hr = E_FAIL;
        if (RegDeleteKeyA(HKEY_CLASSES_ROOT,CLSIDKEY)) hr = E_FAIL;
    }
    RegCloseKey(key);
    return hr;
}
STDAPI DllRegisterServer(void)
{
    HRESULT hr = E_FAIL;
    HKEY key = NULL;
    if (!RegCreateKeyExA(HKEY_CLASSES_ROOT,CLASSKEY,0,NULL,REG_OPTION_NON_VOLATILE,
                        KEY_ALL_ACCESS,NULL,&key,NULL))
    {
        char szModulePath[_MAX_PATH];
        GetModuleFileNameA(m_Server.m_hInstance,szModulePath,_MAX_PATH);
        if(!RegSetValueExA(key,NULL,0,REG_SZ,(const unsigned char*)szModulePath,
                          strlen(szModulePath)+1))
        {
          RegCloseKey(key);
          if (!RegCreateKeyExA(HKEY_CLASSES_ROOT,PRODIDKEY,0,NULL,REG_OPTION_NON_VOLATILE,
                            KEY_ALL_ACCESS,NULL,&key,NULL))
          {
            if (!RegSetValueExA(key,NULL,0,REG_SZ,(const unsigned char*)CLSIDVAL,
                          strlen(CLSIDVAL)+1))
              hr = S_OK;
          }
        }
    }
    RegCloseKey(key);
    return hr;
}

// Returns this server's class factory for CLSID_COM, queried for riid.
// *ppv is NULL on every failure path. Results:
//   S_OK                       - factory returned through *ppv
//   CLASS_E_CLASSNOTAVAILABLE  - rclsid is not CLSID_COM
//   E_POINTER / E_OUTOFMEMORY  - NULL ppv / allocation failure
//   anything QueryInterface reports (e.g. E_NOINTERFACE) - propagated as-is
STDAPI DllGetClassObject(REFCLSID rclsid, REFIID riid, LPVOID* ppv)
{
  if (!ppv)
    return E_POINTER;
  *ppv = NULL;
  HRESULT hr = CLASS_E_CLASSNOTAVAILABLE;
  if (rclsid == CLSID_COM)
  {
    // nothrow: an exception must not escape across the COM boundary.
    CClassFactory* pFactory = new (std::nothrow) CClassFactory;
    if (!pFactory)
      return E_OUTOFMEMORY;
    // The QI result must be returned as-is. Reporting S_OK after a failed QI
    // would hand the caller no valid pointer while the factory has been deleted.
    hr = pFactory->QueryInterface(riid,ppv);
    if (FAILED(hr))
      delete pFactory;
  }
  return hr;
}

// Tells COM whether this DLL may be unloaded: S_OK once no CInterface-derived
// object (class factory or ComClass) is alive, S_FALSE while any still exists.
STDAPI DllCanUnloadNow()
{
  if (m_Server.m_dwRef == 0)
    return S_OK;
  return S_FALSE;
}