#!/usr/bin/env python3
#------------------------------------------------------------------------------#
#  DFTB+: general package for performing fast atomistic simulations            #
#  Copyright (C) 2006 - 2025  DFTB+ developers group                           #
#                                                                              #
#  See the LICENSE file for terms of usage and distribution.                   #
#------------------------------------------------------------------------------#
#
"""
Calculates the elastic stiffness tensor by finite differences of the stress tensor.
"""

import argparse
from derivtools import readgen, writegen, cart2frac, frac2cart, stress2latderivs, AA__BOHR, exists, create_temporary_copy
import numpy as np
import os
import re
import subprocess
import shutil

# Conversion factor used when printing the elastic constants in GPa.
# NOTE(review): the name is misspelled ("PRESURE"); it is kept because other
# code in this file refers to it under this name.
PRESURE_AU = 0.339893208050290E-13 # 1 Pa in atomic units of pressure
# File from which the stress tensor is read after each run of the binary.
TAGGED_RESULTS = "results.tag" # Tagged data file

# Help text shown by argparse. This must be a raw string: in a normal string
# literal the sequences "\a" and "\b" in "C_\alpha\beta" are the escape codes
# for BEL and backspace and would end up as control characters in the output.
DESCRIPTION = r"""
Calculates Cijkl tensor in the Voigt form, C_\alpha\beta, using the
specified DFTB+ binary and applying finite difference strain
displacements to the reference geometry. The geometry of the
configuration must be specified in a file called
'geo.gen.template'. The DFTB+ input file should include the geometry
from the file 'geo.gen'. The input file must specify the option for
writing a results tag file and the option for calculating the forces, as
this also outputs the stress tensor. The lattice vectors should be
fixed in the dftb_in.hsd file, but the atom positions (internal
coordinates) can either be clamped or relaxed depending on which is
required.
"""

# Locates the 3x3 "stress" entry in the tagged results text: a header line
# ending in ":3,3" followed by three lines of three whitespace separated
# tokens. Group "valueXY" is the X-th token on the Y-th line.
STRESS_PATTERN = re.compile(
    r"stress\s*:[^:]*:[^:]*:3,3\n"
    r"\s*(?P<value11>\S+)\s+(?P<value21>\S+)\s+(?P<value31>\S+)\n"
    r"\s*(?P<value12>\S+)\s+(?P<value22>\S+)\s+(?P<value32>\S+)\n"
    r"\s*(?P<value13>\S+)\s+(?P<value23>\S+)\s+(?P<value33>\S+)\n"
)

# Charge file which is saved before, and put back around, the runs of the
# binary if it exists in the working directory.
CHARGES_FILE = "charges.bin"

def main():
    """Main routine.

    Reads the reference geometry from 'geo.gen.template', calculates the
    elastic constants by finite differences and prints them. If a charge
    file is present in the working directory, a copy of it is taken before
    the calculations and put back afterwards, also when a calculation fails.

    Raises:
        ValueError: If the reference geometry has no lattice vectors.
    """

    args = parse_arguments()
    specienames, species, coords, origin, latvecs = readgen("geo.gen.template")
    if latvecs is None:
        raise ValueError("Structure must be periodic for elastic constants!")

    # Keep a copy of the initial charges, as the runs of the binary may
    # overwrite the file in the working directory.
    if exists(CHARGES_FILE):
        chargeRestart = create_temporary_copy(CHARGES_FILE)
    else:
        chargeRestart = None

    disp = args.disp * AA__BOHR
    binary = args.binary
    print(f"DFTB+ BINARY: {binary}")

    # Make sure no results of an earlier calculation are lying around
    if exists(TAGGED_RESULTS):
        os.remove(TAGGED_RESULTS)

    try:
        elastic = calculate_elastic(binary, disp, coords, latvecs,
                                    specienames, species, origin,
                                    chargeRestart)
        print_elastic(elastic)
    finally:
        # Put the initial charges back even if a calculation failed, so that
        # the working directory is left with the charge file the user supplied.
        if chargeRestart:
            shutil.copy2(chargeRestart.name, CHARGES_FILE)

def parse_arguments():
    """Parses command line arguments.

    Returns:
        Namespace with the attributes ``disp`` (float, default 1e-5) and
        ``binary`` (positional argument).
    """

    argparser = argparse.ArgumentParser(description=DESCRIPTION)
    argparser.add_argument(
        "-d", "--displacement",
        dest="disp",
        type=float,
        default=1e-5,
        help="Specify the displacement of the atoms (unit: ANGSTROM)",
    )
    argparser.add_argument("binary", help="DFTB+ binary")
    return argparser.parse_args()


def _read_voigt_stress():
    """Returns the stress stored in the tagged results file as Voigt 6-vector.

    Off-diagonal components are the average of the two symmetry related
    entries of the 3x3 tensor.

    Raises:
        ValueError: If the file contains no stress tensor entry.
    """
    with open(TAGGED_RESULTS, "r") as fp:
        match = STRESS_PATTERN.search(fp.read())
    if not match:
        raise ValueError(f"Stress tensor missing from {TAGGED_RESULTS}!")
    return np.array([
        float(match.group("value11")),
        float(match.group("value22")),
        float(match.group("value33")),
        0.5 * float(match.group("value23")) + 0.5 * float(match.group("value32")),
        0.5 * float(match.group("value13")) + 0.5 * float(match.group("value31")),
        0.5 * float(match.group("value12")) + 0.5 * float(match.group("value21")),
    ])


def calculate_elastic(binary, disp, coords, latvecs, specienames, species,
                        origin, chargeRestart):
    """Calculates elastic constants by finite differences.

    For each of the six Voigt components the lattice vectors are strained by
    -disp and +disp (the atoms keep their fractional coordinates), the
    strained geometry is written to 'geo.gen', the binary is executed and the
    stress is read from the tagged results file.

    Args:
        binary: Executable to run for each strained geometry.
        disp: Magnitude of the applied Voigt strain component.
        coords: Coordinates of the reference geometry (converted with
            cart2frac using latvecs).
        latvecs: Lattice vectors of the reference geometry.
        specienames: Passed on unchanged to writegen.
        species: Passed on unchanged to writegen.
        origin: Passed on unchanged to writegen.
        chargeRestart: Object whose ``name`` attribute is the path of the
            charge file to put in place before every run, or a false value
            if no charge file should be copied.

    Returns:
        6x6 array, row ii containing (stress(-disp) - stress(+disp)) /
        (2 * disp) for the ii-th Voigt strain component, in the units of the
        stress found in the tagged results file.

    Raises:
        RuntimeError: If the binary finishes with a non-zero exit code.
        ValueError: If the tagged results file contains no stress tensor.
    """

    # Voigt convention for mapping one index to the two tensor indices
    voigt = [(0, 0), (1, 1), (2, 2), (1, 2), (0, 2), (0, 1)]

    # to hold resulting elastic tensor
    elastic = np.empty((6, 6), dtype=float)

    for ii, (aa, bb) in enumerate(voigt):
        # to hold the stress (Voigt form) at negative and positive strain
        stress = np.empty((6, 2), dtype=float)
        for kk, sign in enumerate((-1.0, 1.0)):
            # Diagonal components receive both halves, off-diagonal ones
            # half of the strain on each of the two symmetric elements.
            strain = np.identity(3, dtype=float)
            strain[aa][bb] += 0.5 * sign * disp
            strain[bb][aa] += 0.5 * sign * disp

            newcoords = cart2frac(latvecs, np.array(coords))
            newvecs = np.dot(np.array(latvecs), strain)
            newcoords = frac2cart(newvecs, newcoords)
            writegen("geo.gen",
                     (specienames, species, newcoords, origin, newvecs))
            if chargeRestart:
                shutil.copy2(chargeRestart.name, CHARGES_FILE)

            # Remove the results of the previous step, so that a failing run
            # can never be mistaken for a successful one with stale data.
            try:
                os.remove(TAGGED_RESULTS)
            except FileNotFoundError:
                pass

            returncode = subprocess.call([binary])
            if returncode != 0:
                raise RuntimeError(
                    f"DFTB+ binary '{binary}' failed with exit code "
                    f"{returncode}")

            stress[:, kk] = _read_voigt_stress()

        # Central difference, formed only once both strained stresses exist
        elastic[ii, :] = (stress[:, 0] - stress[:, 1]) / (2.0 * disp)

    return elastic


def print_elastic(elastic):
    """Prints calculated elastic constants.

    Args:
        elastic: 6x6 elastic constants in atomic units of pressure. The
            passed data is not modified, the unit conversion is done on a
            copy, so the routine can be called repeatedly on the same data.
    """

    # PRESURE_AU is 1 Pa in atomic units, hence the extra factor for GPa
    elastic_gpa = np.array(elastic, dtype=float) / (PRESURE_AU * 1.0E9)
    print("\nElastic constants (GPa):")
    for row in elastic_gpa:
        print(("%12.4f" * 6) % tuple(row))


# Run the calculation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
