#!/usr/bin/python

import sys, os, glob, os.path, shutil, re
from optparse import OptionParser

# Rebind the name `glob` from the imported module to its glob() function so
# the rest of the script can call glob(pattern) directly.
glob = glob.glob

def cmd(c):
    """Run the shell command string ``c`` through os.system().

    Returns None when the command reports a zero status; raises a plain
    Exception naming the command otherwise.
    """
    status = os.system(c)
    if status == 0:
        return
    raise Exception('command execution failed: ' + c)

# Command-line interface: -v sets the release version string and -l the
# Linux kernel tree to read from.  Both are optional string options; their
# fallback values are registered on the parser in one place below.
parser = OptionParser(usage = 'usage: %prog [-v VERSION][-l LINUX]')
parser.add_option('-v', action = 'store', type = 'string', dest = 'version',
                  help = 'kvm-kmod release version')
parser.add_option('-l', action = 'store', type = 'string', dest = 'linux',
                  help = 'Linux kernel tree to sync from')
parser.set_defaults(version = 'kvm-devel', linux = 'linux-2.6')

# Parsing happens at module level; the two values are published as globals
# that the functions further down read directly.
options, args = parser.parse_args()
version, linux = options.version, options.linux

# Pattern string -> compiled regular expression, filled on demand.
_re_cache = {}

def re_cache(regexp):
    """Return the compiled form of ``regexp``, compiling it on first use.

    Compiled patterns are memoized in the module-level ``_re_cache`` dict,
    keyed by the pattern string.
    """
    try:
        return _re_cache[regexp]
    except KeyError:
        compiled = _re_cache[regexp] = re.compile(regexp)
        return compiled

def hack_content(data):
    """Apply the script's line-by-line textual rewrites to C source text.

    ``data`` is the full text of one file.  It is split into lines, each
    line is run through a fixed sequence of regex matches/substitutions
    (renaming identifiers to ``kvm_``-prefixed variants, dropping or
    replacing some ``#include`` lines, wrapping some constructs in
    preprocessor conditionals, and injecting extra lines before or after
    specific statements), and the resulting lines are joined back together.

    Returns the rewritten text; every emitted line is terminated with a
    newline.  Reads the module-level ``version`` string when emitting the
    "loaded kvm module" printk and the MODULE_INFO line.

    The order of the steps matters: ``match()`` always inspects the current
    (possibly already rewritten) value of ``line``, so later steps see the
    effects of earlier ones.
    """
    # Identifiers that get a 'kvm_' prefix wherever they appear as a word.
    compat_apis = str.split(
        'INIT_WORK desc_struct ldttss_desc64 desc_ptr '
        'hrtimer_add_expires_ns hrtimer_get_expires '
        'hrtimer_get_expires_ns hrtimer_start_expires '
        'hrtimer_expires_remaining smp_send_reschedule '
        'on_each_cpu relay_open request_irq free_irq '
        'init_srcu_struct cleanup_srcu_struct srcu_read_lock '
        'srcu_read_unlock synchronize_srcu srcu_batches_completed '
        'do_machine_check get_desc_base get_desc_limit '
        'vma_kernel_pagesize native_read_tsc user_return_notifier '
        'user_return_notifier_register user_return_notifier_unregister '
        'synchronize_srcu_expedited getboottime monotonic_to_bootbased '
        'check_tsc_unstable native_store_gdt native_store_idt '
        'set_desc_base set_desc_limit '
        )
    # State flags: "currently inside function X" / "region X was opened".
    kvm_init = kvm_exit = False
    # Must be initialised here: it is read below as soon as a line mentions
    # tsc_khz, which may happen before any line has assigned it.  Leaving it
    # unset raised UnboundLocalError on such input.
    kvm_arch_init = False
    mce = False
    eventfd_file = eventfd_ioctl = False
    result = []
    pr_fmt = ''
    def sub(regexp, repl, text):
        return re_cache(regexp).sub(repl, text)
    def match(regexp):
        # Searches the *current* value of the loop variable ``line``.
        return re_cache(regexp).search(line)
    def w(line, result = result):
        # Emit one output line.
        result.append(line)
    for line in data.splitlines():
        # Whitespace-split fields of the unmodified input line.
        f = line.split()
        # Remember the file's pr_fmt string, drop the #define, and splice the
        # remembered prefix into every pr_debug( call instead.
        if match(r'^#define pr_fmt'):
            pr_fmt = sub(r'#define pr_fmt\([^)]*\) ("[^"]*").*', r'\1', line) + ' '
            line = ''
        line = sub(r'pr_debug\(([^),]*)', r'pr_debug(' + pr_fmt + r'\1', line)
        # Inside kvm_init(): add extra setup before kvm_arch_init(), a load
        # message before 'return 0;', and matching teardown before 'return r;'.
        if match(r'^int kvm_init\('): kvm_init = True
        if match(r'r = kvm_arch_init\(opaque\);') and kvm_init:
            w('\tr = kvm_init_anon_inodes();')
            w('\tif (r)')
            w('\t\treturn r;\n')
            w('\tr = kvm_init_srcu();')
            w('\tif (r)')
            w('\t\tgoto cleanup_anon_inodes;\n')
            w('\tpreempt_notifier_sys_init();')
            w('\thrtimer_kallsyms_resolve();\n')
        if match(r'return 0;') and kvm_init:
            w('\tprintk("loaded kvm module (%s)\\n");\n' % (version,))
            w('\tkvm_clock_warn_suspend_bug();\n')
        if match(r'return r;') and kvm_init:
            w('\tpreempt_notifier_sys_exit();')
            w('\tkvm_exit_srcu();')
            w('cleanup_anon_inodes:')
            w('\tkvm_exit_anon_inodes();')
            kvm_init = False
        # Inside kvm_exit(): add teardown calls before the first line that
        # contains a closing brace.
        if match(r'^void kvm_exit'): kvm_exit = True
        if match(r'\}') and kvm_exit:
            w('\tpreempt_notifier_sys_exit();')
            w('\tkvm_exit_srcu();')
            w('\tkvm_exit_anon_inodes();')
            kvm_exit = False
        # Inside kvm_arch_init() (until a column-0 '}'): tsc_khz -> kvm_tsc_khz.
        if match(r'^int kvm_arch_init'): kvm_arch_init = True
        if match(r'\btsc_khz\b') and kvm_arch_init:
            line = sub(r'\btsc_khz\b', 'kvm_tsc_khz', line)
        if match(r'^}'): kvm_arch_init = False
        if match(r'MODULE_AUTHOR'):
            w('MODULE_INFO(version, "%s");' % (version,))
        line = sub(r'(\w+)->dev->msi_enabled',
                   r'kvm_pcidev_msi_enabled(\1->dev)', line)
        if match(r'atomic_inc\(&kvm->mm->mm_count\);'):
            line = '\tmmget(&kvm->mm->mm_count);'
        if match(r'^\t\.fault = '):
            # f[2] is the handler name with its trailing comma.
            fcn = sub(r',', '', f[2])
            line = '\t.VMA_OPS_FAULT(fault) = VMA_OPS_FAULT_FUNC(' + fcn + '),'
        if match(r'^static int (.*_stat_get|lost_records_get)'):
            # 'static int ' is 11 characters: prefix the function name with '__'.
            line = line[0:11] + '__' + line[11:]
        if match(r'DEFINE_SIMPLE_ATTRIBUTE.*(_stat_get|lost_records_get)'):
            name = sub(r',', '', f[1])
            w('MAKE_SIMPLE_ATTRIBUTE_GETTER(' + name + ')')
        line = sub(r'linux/mm_types\.h', 'linux/mm.h', line)
        line = sub(r'\b__user\b', ' ', line)
        if match(r'^\t\.name = "kvm"'):
            line = '\tset_kset_name("kvm"),'
        if match(r'#include <linux/compiler.h>'):
            line = ''
        if match(r'#include <linux/clocksource.h>'):
            line = ''
        if match(r'#include <linux\/types.h>'):
            line = '#include <asm/types.h>'
        if match(r'\t\.change_pte.*kvm_mmu_notifier_change_pte,'):
            line = '#ifdef MMU_NOTIFIER_HAS_CHANGE_PTE\n' + line + '\n#endif'
        if match(r'static void kvm_mmu_notifier_change_pte'):
            line = sub(r'static ', '', line)
            line = '#ifdef MMU_NOTIFIER_HAS_CHANGE_PTE\n' + 'static\n' + '#endif\n' + line
        line = sub(r'\bhrtimer_init\b', 'hrtimer_init_p', line)
        line = sub(r'\bhrtimer_start\b', 'hrtimer_start_p', line)
        line = sub(r'\bhrtimer_cancel\b', 'hrtimer_cancel_p', line)
        if match(r'case KVM_CAP_SYNC_MMU'):
            line = '#ifdef CONFIG_MMU_NOTIFIER\n' + line + '\n#endif'
        for ident in compat_apis:
            line = sub(r'\b' + ident + r'\b', 'kvm_' + ident, line)
        if match(r'kvm_.*_fops\.owner = module;'):
            line = 'IF_ANON_INODES_DOES_REFCOUNTS(' + line + ')'
        if not match(r'#include'):
            # NOTE(review): the pattern ends in '\n', not '\b'.  Lines produced
            # by splitlines() carry no newline, so this only can match inside
            # the multi-line strings built above.  Possibly a typo for
            # r'\blapic\b' -- left unchanged to keep the generated output
            # identical; confirm before altering.
            line = sub(r'\blapic\n', 'l_apic', line)
        # After a 'struct pt_regs regs' line, rename the next .cs / .flags uses.
        if match(r'struct pt_regs regs'):
            mce = True
        if mce and match(r'\.cs'):
            line = sub(r'cs', r'kvm_pt_regs_cs', line)
        if mce and match(r'\.flags'):
            line = sub(r'flags', r'kvm_pt_regs_flags', line)
            mce = False
        line = sub(r'boot_cpu_data.x86_phys_bits', 'kvm_x86_phys_bits', line)
        if match(r'^static const struct vm_operations_struct kvm_'):
            line = sub(r' const ', ' ', line)
        # nested_svm_map()/nested_svm_unmap(): thread an extra mapped_page
        # argument through the definitions and through every call site.
        if line == 'static void *nested_svm_map(struct vcpu_svm *svm, u64 gpa, enum km_type idx)':
            line = sub(r'\)', ', struct page **mapped_page)', line)
        if line == '\treturn kmap_atomic(page, idx);':
            line = '\t*mapped_page = page;\n' + line
        if line == 'static void nested_svm_unmap(void *addr, enum km_type idx)':
            line = sub(r'\)', ', struct page *mapped_page)', line)
        if line == '\tpage = kmap_atomic_to_page(addr);':
            line = '\tpage = mapped_page;'
        if match(r'= nested_svm_map(.*);'):
            line = '\t{ struct page *mapped_page;\n' + sub(r'\);', ', &mapped_page);', line)
        if match('nested_svm_unmap(.*);'):
            line = sub(r'\);', ', mapped_page); }', line)
        # Guard eventfd-related code with a kernel-version check; the matching
        # #else/#endif for the struct case is appended after the loop.
        if line == 'struct _irqfd {':
            w('#if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,33)')
            eventfd_file = True
        if line in ['\tcase KVM_IRQFD: {', '\tcase KVM_IOEVENTFD: {']:
            w('#if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,33)')
            eventfd_ioctl = True
        if line == '\tcase KVM_CAP_IRQFD:':
            w('#if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,33)')
        if match(r'#include <asm/desc_defs.h>'):
            line = ''
        # Emit the (rewritten) line itself; everything below is appended
        # *after* it.
        w(line)
        if match(r'apic->timer.dev.function ='):
            w('\thrtimer_data_pointer(&apic->timer.dev);')
        if match(r'pt->timer.function ='):
            w('\thrtimer_data_pointer(&pt->timer);')
        if line == '\tkvm_arch_vcpu_put(vcpu);':
            w('\tkvm_fire_urn();')
        if line == '\t}' and eventfd_ioctl:
            w('#endif')
            eventfd_ioctl = False
        if line == '\tcase KVM_CAP_IOEVENTFD:':
            w('#endif')
    if eventfd_file:
        # Close the version guard opened at 'struct _irqfd {' and provide
        # empty stand-ins for the two functions in the #else branch.
        result.append('#else\n'
                      'void kvm_eventfd_init(struct kvm *kvm) { }\n'
                      'void kvm_irqfd_release(struct kvm *kvm) { }\n'
                      '#endif')
    data = ''.join([out_line + '\n' for out_line in result])
    return data

def hack_file(T, fname):
    """Rewrite the file ``T``/``fname`` in place through hack_content().

    The file is read completely, transformed, and written back.  Handles are
    closed explicitly via ``with`` (rather than left to garbage collection),
    and ``open`` is used instead of the Python-2-only ``file`` builtin.
    """
    fname = T + '/' + fname
    with open(fname) as src:
        data = src.read()
    data = hack_content(data)
    with open(fname, 'w') as dst:
        dst.write(data)

def unifdef(fname):
    """Prepend the contents of ``unifdef.h`` to the file ``fname`` in place.

    ``unifdef.h`` is looked up relative to the current working directory.
    Both inputs are read fully before ``fname`` is reopened for writing.
    Handles are closed explicitly via ``with``, and ``open`` replaces the
    Python-2-only ``file`` builtin.
    """
    with open('unifdef.h') as header:
        prefix = header.read()
    with open(fname) as src:
        body = src.read()
    with open(fname, 'w') as dst:
        dst.write(prefix + body)

# Per-architecture list of file names (inside the staging directory) that
# source_sync() passes through hack_file().
hack_files = {
    'x86': [
        'kvm_main.c', 'mmu.c', 'vmx.c', 'svm.c', 'x86.c', 'irq.h',
        'lapic.c', 'i8254.c', 'timer.c', 'eventfd.c', 'emulate.c',
    ],
    'ia64': [
        'kvm_main.c', 'kvm_fw.c', 'kvm_lib.c', 'kvm-ia64.c',
    ],
}

def mkdir(dir):
    """Create directory ``dir`` (including parents) unless the path exists.

    Nothing is done when anything already exists at that path.
    """
    if os.path.exists(dir):
        return
    os.makedirs(dir)

def cp(src, dst):
    """Copy the contents of file ``src`` to ``dst``.

    The parent directory of ``dst`` is created first when ``dst`` has a
    directory component.  The source is read completely *before* the
    destination is opened for writing, so a failure to read ``src`` no
    longer leaves behind a truncated ``dst``.  Data is copied as raw bytes
    and both handles are closed explicitly.
    """
    parent = os.path.dirname(dst)
    if parent:
        # A bare file name has no directory part; makedirs('') would fail.
        mkdir(parent)
    with open(src, 'rb') as fin:
        data = fin.read()
    with open(dst, 'wb') as fout:
        fout.write(data)

def copy_if_changed(src, dst):
    """Mirror the tree under ``src`` into ``dst``, rewriting only changed files.

    Every directory found under ``src`` is created under ``dst``.  A file is
    copied when the destination file cannot be read (for example because it
    does not exist yet) or when its bytes differ from the source; identical
    files are left untouched.

    Only I/O errors on the *destination* are treated as "needs copying";
    other exceptions (including KeyboardInterrupt) now propagate instead of
    being swallowed by a bare ``except``.
    """
    for cur_dir, subdirs, files in os.walk(src):
        # Path of cur_dir relative to src, re-rooted under dst.
        ndir = dst + '/' + cur_dir[len(src)+1:]
        mkdir(ndir)
        for fname in files:
            old = ndir + '/' + fname
            new = cur_dir + '/' + fname
            with open(new, 'rb') as fresh:
                new_data = fresh.read()
            try:
                with open(old, 'rb') as existing:
                    unchanged = existing.read() == new_data
            except (IOError, OSError):
                # Destination missing or unreadable: treat as changed.
                unchanged = False
            if not unchanged:
                cp(new, old)

def rmtree(path):
    """Recursively delete ``path`` via shutil.rmtree() if it exists.

    A path that does not exist is silently ignored.
    """
    if not os.path.exists(path):
        return
    shutil.rmtree(path)

def header_sync(arch):
    """Stage the KVM-related headers for ``arch`` and merge them into '.'.

    Headers matching fixed glob patterns under the ``linux`` tree are copied
    into a scratch directory named 'header' (generic ones under
    include/linux and include/trace/events, per-arch ones under
    include/asm-<arch>), each gets unifdef.h prepended, a few are passed
    through hack_file(), and the result is merged into the current directory
    with copy_if_changed().  The scratch directory is removed before and
    after.
    """
    T = 'header'
    rmtree(T)
    subst = { 'linux': linux, 'arch': arch, 'T': T }

    def pull(patterns, out_fmt):
        # Resolve every pattern first, then copy + unifdef each hit.
        found = []
        for pattern in patterns:
            found.extend(glob(pattern % subst))
        for src in found:
            dest = out_fmt % dict(subst, name = os.path.basename(src))
            cp(src, dest)
            unifdef(dest)

    pull(['%(linux)s/include/linux/kvm*.h'],
         '%(T)s/include/linux/%(name)s')
    pull(['%(linux)s/include/trace/events/kvm*.h'],
         '%(T)s/include/trace/events/%(name)s')
    pull(['%(linux)s/arch/%(arch)s/include/asm/./kvm*.h',
          '%(linux)s/arch/%(arch)s/include/asm/./vmx*.h',
          '%(linux)s/arch/%(arch)s/include/asm/./svm*.h',
          '%(linux)s/arch/%(arch)s/include/asm/./virtext*.h',
          '%(linux)s/arch/%(arch)s/include/asm/./hyperv.h'],
         '%(T)s/include/asm-%(arch)s/%(name)s')

    to_hack = [
        'include/linux/kvm.h',
        'include/asm-%(arch)s/kvm.h' % { 'arch': arch },
        'include/asm-%(arch)s/kvm_host.h' % { 'arch': arch },
    ]
    if arch == 'x86':
        to_hack.append('include/asm-x86/kvm_emulate.h')
    for rel_path in to_hack:
        hack_file(T, rel_path)

    copy_if_changed(T, '.')
    rmtree(T)

def source_sync(arch):
    """Stage the KVM sources for ``arch`` and merge them into ``arch``/.

    Files matching ``*.[cSh]`` under arch/<arch>/kvm and virt/kvm in the
    ``linux`` tree (except names ending in '.mod.c') are copied flat into a
    scratch directory named 'source'.  Every staged .c file gets unifdef.h
    prepended, the files listed in hack_files[arch] go through hack_file(),
    and the result is merged into the directory named after the
    architecture with copy_if_changed().  The scratch directory is removed
    before and after.
    """
    T = 'source'
    rmtree(T)

    subst = { 'linux': linux, 'arch': arch }
    candidates = []
    for pattern in ['%(linux)s/arch/%(arch)s/kvm/*.[cSh]',
                    '%(linux)s/virt/kvm/*.[cSh]']:
        candidates.extend(glob(pattern % subst))

    for path in candidates:
        if path.endswith('.mod.c'):
            continue
        dest = '%(T)s/%(name)s' % { 'T': T, 'name': os.path.basename(path) }
        cp(path, dest)

    for c_file in glob(T + '/*.c'):
        unifdef(c_file)

    for name in hack_files[arch]:
        hack_file(T, name)

    copy_if_changed(T, arch)
    rmtree(T)

# Script entry: for each supported architecture, sync headers first and
# then sources.
for arch in ('x86', 'ia64'):
    header_sync(arch)
    source_sync(arch)
