#!/usr/bin/env python3
#------------------------------------------------------------------------------#
#  DFTB+: general package for performing fast atomistic simulations            #
#  Copyright (C) 2006 - 2025  DFTB+ developers group                           #
#                                                                              #
#  See the LICENSE file for terms of usage and distribution.                   #
#------------------------------------------------------------------------------#
#
"""
Calculates forces and lattice derivatives with finite differences.
"""

import argparse
from derivtools import readgen, writegen, cart2frac, frac2cart, stress2latderivs, AA__BOHR, exists, create_temporary_copy
import numpy as np
import os
import re
import subprocess
import shutil

DESCRIPTION = """
Calculates the forces using the specified DFTB+ binary by finite differences
displacing the atoms along every axis. The geometry of the configuration must
be specified in a file called 'geo.gen.template'. The DFTB+ input file should
include the geometry from the file 'geo.gen'. The input file must specify the
option for writing a results file and if the reference should be calculated
then also the option for calculating the forces.
"""

# Regular expressions applied to the text of the tagged results file.
# NOTE(review): the skipped ':'-separated fields look like a
# "name:type:rank:shape" tag header -- confirm against the DFTB+ tagged
# output format.

# Finds "mermin_energy", skips ahead to the third ':' that follows and
# captures the next whitespace-delimited token as 'value'.
ENERGY_PATTERN = re.compile(r"mermin_energy[^:]*:[^:]*:[^:]*:\s*(?P<value>\S+)")

# Finds "forces", skips to the third ':' that follows, expects "<int>,<int>"
# and captures the following run of whitespace-separated numbers as 'values'.
# Numbers may carry a sign, a fractional part and an upper-case 'E' exponent.
FORCES_PATTERN = re.compile(
    r"forces[^:]*:[^:]*:[^:]*:\d+,\d+\s*"
    r"(?P<values>(?:\s*[+-]?\d+(?:\.\d+(?:E[+-]?\d+)?)?)+)",
    re.MULTILINE)

# Same layout as FORCES_PATTERN, but for an entry whose line starts with
# "stress" (anchored at the beginning of a line via re.MULTILINE).
LATTICE_DERIV_PATTERN = re.compile(
    r"^stress[^:]*:[^:]*:[^:]*:\d+,\d+\s*"
    r"(?P<values>(?:\s*[+-]?\d+(?:\.\d+(?:E[+-]?\d+)?)?)+)", re.MULTILINE)

# Finds "dipole_moments", expects "<int>,<int>" after the third ':' and
# captures the next three whitespace-delimited tokens as the x, y and z
# components.
DIPOLE_PATTERN = re.compile(r"dipole_moments\s*:[^:]*:[^:]*:\d+,\d+\s*"
    r"(?P<dipoleX>\S+)\s+(?P<dipoleY>\S+)\s+(?P<dipoleZ>\S+)")

# Finds "gross_atomic_charges", expects a single integer after the third ':'
# and captures the following run of whitespace-separated numbers as 'values'.
CHARGE_PATTERN = re.compile(r"gross_atomic_charges\s*:[^:]*:[^:]*:\d+\s*"
    r"(?P<values>(?:\s*[+-]?\d+(?:\.\d+(?:E[+-]?\d+)?)?)+)",
    re.MULTILINE)

# File names used in the working directory.
# Results file read back after every run of the binary.
TAGGED_RESULTS = 'results.tag'
# Name under which the results of the reference calculation are stored.
REFERENCE_RESULTS = 'results0.tag'
# Charge file that is backed up once and copied back before every run.
CHARGES_FILE = "charges.bin"

def main():
    """Run the finite-difference workflow selected on the command line.

    Reads the geometry from 'geo.gen.template', optionally calculates or
    loads reference results, performs the requested finite-difference
    calculations, prints them and finally copies the charge backup (if any)
    and the template geometry back into place.
    """

    args = parse_arguments()
    directions = directions_from_args(args.directions)
    specienames, species, coords, origin, latvecs = readgen("geo.gen.template")
    atoms = atoms_from_args(args.atoms, coords.shape[0])

    # Back up an existing charge file once, so that every displaced
    # calculation can be started from the same charges.
    charge_restart = (create_temporary_copy(CHARGES_FILE)
                      if exists(CHARGES_FILE) else None)

    do_forces = not args.skipforces
    # Lattice derivatives are only possible when lattice vectors were read.
    do_lattice = latvecs is not None and not args.skiplattice
    do_born = args.doborn
    do_charges = args.docharges

    # A freshly calculated reference takes precedence; otherwise use the file
    # given by the user (None when no comparison was requested).
    reffile = REFERENCE_RESULTS if args.calcref else args.ref

    stale_results = exists(TAGGED_RESULTS) and reffile != TAGGED_RESULTS
    if stale_results:
        print(f"Removing pre-existing tagged file, {TAGGED_RESULTS}, before calculations")
        os.remove(TAGGED_RESULTS)

    disp = args.disp * AA__BOHR
    binary = args.binary
    print(f"DFTB+ BINARY: {binary}")

    if args.calcref:
        calculate_reference(binary, reffile, coords, specienames, species,
                            origin, latvecs, charge_restart)

    forces0 = latderivs0 = None
    if reffile is not None:
        forces0, latderivs0 = read_reference_results(
            reffile, do_forces, do_lattice, latvecs)

    # Run all requested calculations first ...
    if do_forces:
        forces = calculate_forces(binary, disp, coords, specienames,
                                  species, origin, latvecs, charge_restart,
                                  directions, atoms)
    if do_lattice:
        latderivs = calculate_latderivs(binary, disp, coords, latvecs,
                                        specienames, species, origin,
                                        charge_restart, directions)
    if do_born:
        born = calculate_born(binary, disp, coords, specienames, species,
                              origin, latvecs, charge_restart)
    if do_charges:
        charges = calculate_charges(binary, disp, coords, specienames,
                                    species, origin, latvecs, charge_restart)

    # ... and report them together afterwards.
    if do_forces:
        print_forces(forces, forces0, directions, atoms)
    if do_lattice:
        print_latderivs(latderivs, latderivs0, directions)
    if do_charges:
        print_charges(charges)
    if do_born:
        print_born(born)

    # Leave the working directory with the original charges and geometry.
    if charge_restart:
        shutil.copy2(charge_restart.name, CHARGES_FILE)
    shutil.copy2("geo.gen.template", "geo.gen")

def parse_arguments():
    """Parse the command line arguments.

    Returns:
        The argparse namespace with the attributes ``disp``, ``ref``,
        ``calcref``, ``skiplattice``, ``skipforces``, ``doborn``,
        ``docharges``, ``directions``, ``atoms`` and ``binary``.

    Specifying a reference file together with the request to calculate the
    reference is rejected via ``parser.error``.
    """

    parser = argparse.ArgumentParser(description=DESCRIPTION)

    parser.add_argument(
        "-d", "--displacement", type=float, dest="disp", default=1e-5,
        help="Specify the displacement of the atoms (unit: ANGSTROM)")

    parser.add_argument(
        "-r", "--reference", dest="ref",
        help="Compare derivatives with those in reference results file")

    msg = ("Calculate reference system (and compare derivatives with it),"
           f" resulting results file will be saved as '{REFERENCE_RESULTS}'")
    parser.add_argument("-c", "--calc-reference", dest="calcref",
                        action="store_true", default=False, help=msg)

    msg = ("Skip the calculation of the lattice derivatives in the case of"
           " periodic systems")
    parser.add_argument("-L", "--skip-lattice", dest="skiplattice",
                        action="store_true", default=False, help=msg)

    msg = ("Skip the calculation of forces (useful when only lattice"
           " derivatives should be calculated)")
    parser.add_argument("-F", "--skip-forces", dest="skipforces",
                        action="store_true", default=False, help=msg)

    parser.add_argument("-b", "--born-charges", dest="doborn",
                        action="store_true", default=False,
                        help="Calculate the effective Born charges")

    parser.add_argument("-q", "--charges", dest="docharges",
                        action="store_true", default=False,
                        help="Calculate derivatives of atomic charges")

    # The two fragments used to be joined without any separator, which
    # rendered as "etc.)input order" in the help output.
    msg = ("Directions of calculation (arguments like 'xyz' 'xy', etc.),"
           " input order not maintained, default: xyz")
    parser.add_argument("--directions", dest="directions", default="xyz",
                        help=msg)

    msg = ("Atoms for calculation, starting with one and including both limits,"
           " use ':' for ranges, ',' for separation and '1:-x' for excluding"
           " the last x atoms. Output order will be sorted smallest to "
           "highest (arguments like '1,2' '1:5', '2,4:6,9' and '1,3:4,9:-4')")
    parser.add_argument("-a", "--atoms", dest="atoms", default=None,
                        help=msg)

    parser.add_argument("binary", help="DFTB+ binary")

    args = parser.parse_args()
    if args.ref is not None and args.calcref:
        msg = ("Specifying a reference file and requesting a calculation of the"
               " reference system are mutually exclusive options")
        parser.error(msg)
    return args


def read_reference_results(results0, calcforces, calclatderivs, latvecs):
    """Read reference forces and lattice derivatives from a results file.

    Args:
        results0: Name of the tagged results file with the reference data.
        calcforces: Whether the reference forces should be extracted.
        calclatderivs: Whether the reference lattice derivatives should be
            extracted.
        latvecs: Lattice vectors passed on to ``stress2latderivs`` when the
            stress entry is converted.

    Returns:
        Tuple ``(forces0, latderivs0)``. ``forces0`` is an array with three
        columns; ``latderivs0`` is whatever ``stress2latderivs`` returns for
        the stress reshaped to three columns. An entry is None if it was not
        requested.

    Raises:
        ValueError: If a requested quantity is not found in the file.
    """

    # The context manager closes the file even if reading fails.
    with open(results0, "r") as fp:
        txt = fp.read()

    forces0 = None
    latderivs0 = None

    if calcforces:
        match = FORCES_PATTERN.search(txt)
        if not match:
            raise ValueError("No forces found in reference file!")
        values = np.fromstring(match.group("values"), count=-1, dtype=float,
                               sep=" ")
        forces0 = values.reshape((-1, 3))

    if calclatderivs:
        match = LATTICE_DERIV_PATTERN.search(txt)
        if not match:
            raise ValueError("No lattice derivatives found in reference file!")
        values = np.fromstring(match.group("values"), count=-1, dtype=float,
                               sep=" ")
        stress0 = values.reshape((-1, 3))
        latderivs0 = stress2latderivs(stress0, latvecs)

    return forces0, latderivs0


def calculate_reference(binary, reffile, coords, specienames, species, origin,
                        latvecs, charge_restart):
    """Run the binary on the undisplaced geometry and keep its results.

    The geometry is written to 'geo.gen', the charge backup (if any) is
    copied to the charge file, the binary is executed and the tagged results
    file it produced is moved to `reffile`.
    """

    geometry = (specienames, species, coords, origin, latvecs)
    writegen("geo.gen", geometry)
    if charge_restart:
        restart_source = charge_restart.name
        shutil.copy2(restart_source, CHARGES_FILE)
    command = [binary]
    subprocess.call(command)
    shutil.move(TAGGED_RESULTS, reffile)


def calculate_forces(binary, disp, coords, specienames, species, origin,
                     latvecs, charge_restart, directions, atoms):
    """Calculate forces by central finite differences of the energy.

    Args:
        binary: Program to execute for every displaced geometry.
        disp: Displacement added to / subtracted from a coordinate.
        coords: Atomic coordinates; copied before every displacement.
        specienames, species, origin, latvecs: Remaining geometry data,
            passed unchanged to ``writegen``.
        charge_restart: Object whose ``name`` attribute is the path of the
            charge backup to copy in before every run, or a false value.
        directions: Indices (0, 1, 2) of the Cartesian components to displace.
        atoms: Zero-based indices of the atoms to displace.

    Returns:
        Array of shape (len(atoms), len(directions)) with
        (E(-disp) - E(+disp)) / (2 * disp) for every atom and direction.

    Raises:
        ValueError: If no energy is found in the tagged results file.
    """

    cart = ('x', 'y', 'z')
    delta = ('-', '+')
    energy = np.empty((2,), dtype=float)
    forces = np.empty((len(atoms), len(directions)), dtype=float)
    for aa, iat in enumerate(atoms):
        for ii, coord in enumerate(directions):
            for jj in range(2):
                # jj = 0 displaces in the negative, jj = 1 in the positive
                # direction.
                newcoords = np.array(coords)
                newcoords[iat][coord] += float(2 * jj - 1) * disp
                writegen("geo.gen", (specienames, species, newcoords, origin,
                                     latvecs))
                if charge_restart:
                    shutil.copy2(charge_restart.name, CHARGES_FILE)
                subprocess.call([binary])
                # Context manager: the file is closed even if reading fails.
                with open(TAGGED_RESULTS, "r") as fp:
                    txt = fp.read()
                match = ENERGY_PATTERN.search(txt)
                print("iat: %2d, dir: %s, delta: %s" % (iat + 1, cart[coord], delta[jj]))
                if match:
                    energy[jj] = float(match.group("value"))
                    print("energy:", energy[jj])
                else:
                    raise ValueError(f"No energy match found in {TAGGED_RESULTS}!")
            # Force is the negative energy gradient.
            forces[aa][ii] = (energy[0] - energy[1]) / (2.0 * disp)
    return forces


def calculate_latderivs(binary, disp, coords, latvecs, specienames, species,
                        origin, charge_restart, directions):
    """Calculate lattice derivatives by central finite differences.

    For every displaced lattice vector component the atomic coordinates are
    converted with ``cart2frac`` using the original lattice vectors and back
    with ``frac2cart`` using the displaced ones before the geometry is
    written.

    Args:
        binary: Program to execute for every displaced geometry.
        disp: Displacement added to / subtracted from a lattice component.
        coords: Atomic coordinates.
        latvecs: Lattice vectors; copied before every displacement.
        specienames, species, origin: Remaining geometry data, passed
            unchanged to ``writegen``.
        charge_restart: Object whose ``name`` attribute is the path of the
            charge backup to copy in before every run, or a false value.
        directions: Indices (0, 1, 2) of the Cartesian components to displace.

    Returns:
        Array of shape (3, len(directions)) with
        (E(+disp) - E(-disp)) / (2 * disp) for every lattice vector and
        direction.

    Raises:
        ValueError: If no energy is found in the tagged results file.
    """

    cart = ('x', 'y', 'z')
    delta = ('-', '+')
    energy = np.empty((2,), dtype=float)
    latderivs = np.empty((3, len(directions)), dtype=float)
    for ii, coord in enumerate(directions):
        for jj in range(3):
            for kk in range(2):
                # kk = 0 displaces in the negative, kk = 1 in the positive
                # direction.
                newcoords = np.array(coords)
                newcoords = cart2frac(latvecs, newcoords)
                newvecs = np.array(latvecs)
                newvecs[jj][coord] += float(2 * kk - 1) * disp
                newcoords = frac2cart(newvecs, newcoords)
                writegen("geo.gen", (specienames, species, newcoords, origin,
                                     newvecs))
                if charge_restart:
                    shutil.copy2(charge_restart.name, CHARGES_FILE)
                subprocess.call([binary])
                # Context manager: the file is closed even if reading fails.
                with open(TAGGED_RESULTS, "r") as fp:
                    txt = fp.read()
                match = ENERGY_PATTERN.search(txt)
                print("dir: %s, ilatvec: %2d, delta: %s" % (cart[coord], jj + 1, delta[kk]))
                if match:
                    energy[kk] = float(match.group("value"))
                    # Report the energy only after it has been stored;
                    # printing it earlier showed an uninitialised or stale
                    # value from a previous displacement.
                    print("energy:", energy[kk])
                else:
                    raise ValueError(f"No energy match found in {TAGGED_RESULTS}!")
            latderivs[jj][ii] = (energy[1] - energy[0]) / (2.0 * disp)
    return latderivs


def calculate_born(binary, disp, coords, specienames, species, origin,
                     latvecs, charge_restart):
    """Calculate effective Born charges by finite differences of the dipole.

    Every atom is displaced along x, y and z in both directions and the
    dipole moment is read from the tagged results file after each run.

    Args:
        binary: Program to execute for every displaced geometry.
        disp: Displacement added to / subtracted from a coordinate.
        coords: Atomic coordinates; copied before every displacement.
        specienames, species, origin, latvecs: Remaining geometry data,
            passed unchanged to ``writegen``.
        charge_restart: Object whose ``name`` attribute is the path of the
            charge backup to copy in before every run, or a false value.

    Returns:
        Array of shape (len(coords), 3, 3); element [iat][ii][jj] is
        (dipole_jj(-disp) - dipole_jj(+disp)) / (2 * disp) for atom iat
        displaced along direction ii.

    Raises:
        ValueError: If no dipole moment is found in the tagged results file.
    """

    cart = ('x', 'y', 'z')
    delta = ('-', '+')
    dipole = np.empty((2, 3), dtype=float)
    born = np.empty((len(coords), 3, 3), dtype=float)
    for iat in range(len(coords)):
        for ii in range(3):
            for jj in range(2):
                # jj = 0 displaces in the negative, jj = 1 in the positive
                # direction.
                newcoords = np.array(coords)
                newcoords[iat][ii] += float(2 * jj - 1) * disp
                writegen("geo.gen", (specienames, species, newcoords, origin,
                                     latvecs))
                if charge_restart:
                    shutil.copy2(charge_restart.name, CHARGES_FILE)
                subprocess.call([binary])
                # Context manager: the file is closed even if reading fails.
                with open(TAGGED_RESULTS, "r") as fp:
                    txt = fp.read()
                match = DIPOLE_PATTERN.search(txt)
                # One-based atom index, consistent with the progress output
                # of the force and charge derivative calculations.
                print("iat: %2d, dir: %s, delta: %s" % (iat + 1, cart[ii], delta[jj]))
                if match:
                    dipole[jj][0] = float(match.group("dipoleX"))
                    dipole[jj][1] = float(match.group("dipoleY"))
                    dipole[jj][2] = float(match.group("dipoleZ"))
                    print("dipole:", dipole[jj][0], dipole[jj][1], dipole[jj][2])
                else:
                    raise ValueError(f"No dipole match found in {TAGGED_RESULTS}!")
            for jj in range(3):
                born[iat][ii][jj] = (dipole[0][jj] - dipole[1][jj]) / (2.0 * disp)
    return born


def calculate_charges(binary, disp, coords, specienames, species, origin,
                     latvecs, charge_restart):
    """Calculate derivatives of the atomic charges by finite differences.

    Every atom is displaced along x, y and z in both directions and the
    gross atomic charges are read from the tagged results file after each
    run.

    Args:
        binary: Program to execute for every displaced geometry.
        disp: Displacement added to / subtracted from a coordinate.
        coords: Atomic coordinates; copied before every displacement.
        specienames, species, origin, latvecs: Remaining geometry data,
            passed unchanged to ``writegen``.
        charge_restart: Object whose ``name`` attribute is the path of the
            charge backup to copy in before every run, or a false value.

    Returns:
        Array of shape (len(coords), 3, len(coords)); element [iat][ii][:]
        is (q(-disp) - q(+disp)) / (2 * disp) for all atomic charges with
        atom iat displaced along direction ii.

    Raises:
        ValueError: If no gross charges are found in the tagged results file.
    """

    cart = ('x', 'y', 'z')
    delta = ('-', '+')
    gross = np.empty((2, len(coords)), dtype=float)
    charges = np.empty((len(coords), 3, len(coords)), dtype=float)
    for iat in range(len(coords)):
        for ii in range(3):
            for jj in range(2):
                # jj = 0 displaces in the negative, jj = 1 in the positive
                # direction.
                newcoords = np.array(coords)
                newcoords[iat][ii] += float(2 * jj - 1) * disp
                writegen("geo.gen", (specienames, species, newcoords, origin,
                                     latvecs))
                if charge_restart:
                    shutil.copy2(charge_restart.name, CHARGES_FILE)
                subprocess.call([binary])
                # The file name must be passed as a plain string; wrapping it
                # in braces created a set and made open() raise a TypeError.
                with open(TAGGED_RESULTS, "r") as fp:
                    txt = fp.read()
                match = CHARGE_PATTERN.search(txt)
                print("iat: %2d, dir: %s, delta: %s" % (iat + 1, cart[ii], delta[jj]))
                if match:
                    tmp = np.fromstring(match.group("values"), count=-1,
                                        dtype=float, sep=" ")
                    gross[jj][:] = tmp
                else:
                    raise ValueError(f"No gross charges match found in {TAGGED_RESULTS}!")
            charges[iat][ii][:] = (gross[0] - gross[1]) / (2.0 * disp)
    return charges


def print_forces(forces, forces0, directions, atoms):
    """Print the finite-difference forces, optionally against a reference.

    Args:
        forces: Calculated forces, one row per entry of `atoms` and one
            column per entry of `directions`.
        forces0: Reference forces indexed as [atom, component], or None to
            skip the comparison.
        directions: Indices (0, 1, 2) of the calculated components.
        atoms: Zero-based indices of the calculated atoms; printed one-based.
    """

    axis_names = ('x', 'y', 'z')
    ncomp = len(directions)
    label = "".join(axis_names[idir] for idir in directions)
    row_format = "%25.12E" * ncomp

    def show_rows(table):
        # One line per row, labelled with the one-based atom index.
        for irow, values in enumerate(table):
            print("%3d:" % (atoms[irow] + 1), row_format % tuple(values))

    print(f"Forces by finite differences ({label}):")
    show_rows(forces)
    if forces0 is None:
        return

    # Pick the reference entries belonging to the calculated atoms/components.
    forces1 = np.empty((forces.shape))
    for irow, atom in enumerate(atoms):
        for icol, coord in enumerate(directions):
            forces1[irow, icol] = forces0[atom, coord]

    print("Reference forces:")
    show_rows(forces1)
    print("Difference between obtained and reference forces:")
    diff = forces - forces1
    show_rows(diff)
    print("Max diff in any force component:")
    print("%25.12E" % (abs(diff).max(),))


def print_latderivs(latderivs, latderivs0, directions):
    """Print the lattice derivatives, optionally against a reference.

    Args:
        latderivs: Calculated derivatives; the first three rows are printed,
            with one column per entry of `directions`.
        latderivs0: Reference derivatives indexed as [:, component], or None
            to skip the comparison.
        directions: Indices (0, 1, 2) of the calculated components.
    """

    axis_names = ('x', 'y', 'z')
    ncomp = len(directions)
    label = "".join(axis_names[idir] for idir in directions)
    row_format = "%25.12E" * ncomp

    def show_rows(table):
        # Exactly three rows, one per lattice vector.
        for ivec in range(3):
            print(row_format % tuple(table[ivec]))

    print(f"Lattice derivatives by finite differences ({label}):")
    show_rows(latderivs)
    if latderivs0 is None:
        return

    # Pick the reference columns belonging to the calculated components.
    latderivs1 = np.empty((latderivs.shape))
    for icol, coord in enumerate(directions):
        latderivs1[:, icol] = latderivs0[:, coord]

    print("Reference lattice derivatives:")
    show_rows(latderivs1)
    print("Difference between obtained and reference lattice derivatives:")
    diff = latderivs - latderivs1
    show_rows(diff)
    print("Max diff in any lattice derivatives component:")
    print("%25.12E" % (abs(diff).max(), ))


def print_born(born):
    """Print the effective Born charges, one 3x3 block per atom.

    Every block is followed by an empty line.
    """

    row_format = "%16.8F" * 3
    print("Effective Born charges by finite differences of dipole moments:")
    for atom_block in born:
        for axis in (0, 1, 2):
            print(row_format % tuple(atom_block[axis]))
        print("")

def print_charges(charge):
    """Print the derivatives of the atomic charges wrt. atomic positions.

    For every entry of `charge` the transposed block is printed row by row
    (three values per line), followed by an empty line.
    """

    # NOTE(review): the magnitude looks like the Angstrom/Bohr conversion
    # factor combined with a sign flip -- confirm the intended units and sign
    # convention.
    scale = -1.88972613392125187641
    row_format = "%16.8F" * 3
    print("Atomic charge derivatives by finite differences:")
    for displaced_atom in charge:
        for row in displaced_atom.T:
            print(row_format % tuple(scale * row))
        print("")

def directions_from_args(coords):
    """Translate a direction string into sorted component indices.

    Args:
        coords: Iterable of the characters 'x', 'y' and 'z' (e.g. "xz").
            Repeated characters are ignored and the input order is not kept.

    Returns:
        Sorted list of the unique indices (x -> 0, y -> 1, z -> 2).

    Raises:
        ValueError: If a character is not one of 'x', 'y' or 'z'.
    """
    dirstrs = {'x': 0, 'y': 1, 'z': 2}
    directions = set()
    for coord in coords:
        direction = dirstrs.get(coord)
        if direction is None:
            # Plain ASCII quotes on both sides; the closing quote used to be
            # a typographic apostrophe.
            raise ValueError(f"'{coord}' is no valid direction!")
        directions.add(direction)
    return sorted(directions)


def atoms_from_args(inp, maxlen):
    """Translate the atom selection string into zero-based atom indices.

    Args:
        inp: Comma separated list of one-based atom numbers and inclusive
            'start:stop' ranges, or None to select all atoms. If a range item
            contains a '-', `maxlen` is added to its stop value.
        maxlen: Number of atoms in the system.

    Returns:
        Sorted list of zero-based atom indices.
    """
    if inp is None:
        return list(range(maxlen))

    selected = []
    for token in inp.split(","):
        if ":" not in token:
            selected.append(int(token) - 1)
            continue
        first, last = token.split(":")
        if "-" in token:
            # Negative stop counts from the end of the atom list.
            upper = int(last) + maxlen
            lower = int(first)
        else:
            lower = int(first)
            upper = int(last)
        # One-based inclusive range [lower, upper] as zero-based indices.
        selected.extend(range(lower - 1, upper))
    selected.sort()
    return selected

# Run the workflow only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
