#!/usr/bin/env python3

"""Simple script to sort module imports in Fortran files"""

import argparse
import fnmatch
import os
import re
import sys

from typing import Dict, List, Tuple

# Program description handed to argparse when the command line parser is built.
_DESCRIPTION = "Sorts the 'use' statements in DFTB+ Fortran file(s)"

# Matches one complete Fortran 'use' statement, including any '&' continuation
# lines, and exposes its parts as named groups:
#   indent    - leading blanks of the statement
#   attrib    - optional ', intrinsic' attribute
#   separator - either '::' (with optional blanks) or at least one blank
#   name      - the module name
#   rest      - everything after the module name up to (not including) the
#               final newline, e.g. the 'only' list or rename list
# The separator must not be empty: otherwise ordinary statements starting with
# an identifier that merely begins with 'use' (e.g. 'used = 1') would be
# treated as module imports and get lowercased/reordered.
_PAT_USE_MODULE = re.compile(
    r"""^(?P<indent>[ ]*)use
    (?P<attrib>(?:\s*,\s*intrinsic)?)
    (?P<separator>[ ]*::[ ]*|[ ]+)
    (?P<name>\w+)
    (?P<rest>.*?(?:&[ ]*\n(?:[ ]*&)?.*?)*)\n
    """,
    re.VERBOSE | re.MULTILINE | re.IGNORECASE,
)


def main():
    """Main script driver.

    Collects the files named on the command line (positional arguments,
    '--files' and the Fortran files found below '--folders'), sorts the 'use'
    statements of each one and writes the result back in place. A file is only
    rewritten when its content actually changes, so untouched sources keep
    their modification time. Files containing more than one block of 'use'
    statements are reported on standard output.
    """

    args = _parse_arguments()
    filenames = []
    if args.folders:
        filenames += _get_files(args.folders)
    if args.files:
        filenames += args.files
    if args.file:
        filenames += args.file

    for fname in filenames:
        with open(fname, "r", encoding="utf-8") as file:
            txt = file.read()
        # Process after the read handle is closed; nothing below needs it.
        blocks, output = _process_file_content(txt, fname)
        new_txt = "\n".join(output) + "\n"
        # Skip the write when nothing changed to avoid touching the file
        # (and thereby triggering needless rebuilds).
        if new_txt != txt:
            with open(fname, "w", encoding="utf-8") as file:
                file.write(new_txt)
        if blocks > 1:
            print(f"{fname}: multiple blocks found!")


def _parse_arguments() -> argparse.Namespace:
    """Builds the command line parser and returns the parsed arguments.

    Files can be given positionally or via '--files', directories via
    '--folders'. If none of the three is supplied, the parser reports an
    error (which terminates the program).
    """
    file_help = "File(s) to process"
    folder_help = "Folder(s) to process"

    parser = argparse.ArgumentParser(description=_DESCRIPTION)
    parser.add_argument("file", nargs="*", metavar="FILE", default=None, help=file_help)
    parser.add_argument("--files", nargs="+", metavar="FILE", help=file_help)
    parser.add_argument("--folders", nargs="+", metavar="FOLDER", help=folder_help)

    parsed = parser.parse_args()
    if not (parsed.file or parsed.folders or parsed.files):
        parser.error("No Files/Folders specified!")
    return parsed


def _get_files(folders: List[str]) -> List[str]:
    """Recursively collects the paths of all '*.f90' / '*.F90' files.

    Every folder is walked with os.walk() and each file name matching the
    pattern '*.[fF]90' is returned joined with the directory it was found in.
    """

    found: List[str] = []
    for top in folders:
        for dirpath, _, entries in os.walk(top):
            found.extend(
                os.path.join(dirpath, entry)
                for entry in entries
                if fnmatch.fnmatch(entry, "*.[fF]90")
            )
    return found


def _process_file_content(txt: str, fname: str) -> Tuple[int, List[str]]:
    """Sorts the 'use' statements found in the text of one file.

    Consecutive 'use' statements (each match starting exactly where the
    previous one ended) form a block that is replaced by its sorted
    representation. Text between blocks is copied with trailing whitespace
    stripped. A module appearing twice within one block is a fatal error.

    Returns the number of blocks found and the list of output chunks, which
    the caller joins with newlines.
    """

    chunks: List[str] = []
    pending: Dict[str, re.Match] = {}
    n_blocks = 0
    cursor = 0

    for use_match in _PAT_USE_MODULE.finditer(txt):
        module = use_match.group("name").lower()
        begin = use_match.start()
        if begin != cursor:
            # Other text sits between the previous match and this one: the
            # current block (if any) is finished, emit it before that text.
            if pending:
                chunks.extend(_get_sorted_modules(pending))
                n_blocks += 1
                pending = {}
            chunks.append(txt[cursor:begin].rstrip())
        if module in pending:
            _fatal_error(f"{fname}: multiple occurrences of module '{module}'!")
        pending[module] = use_match
        cursor = use_match.end()

    # Emit the block still being collected when the matches ran out.
    if pending:
        chunks.extend(_get_sorted_modules(pending))
        n_blocks += 1
    chunks.append(txt[cursor:].rstrip())
    return n_blocks, chunks


def _get_sorted_modules(modules: Dict[str, re.Match]) -> List[str]:
    """Returns the 'use' statements of one block in their sorted order.

    Modules are emitted in three groups, each sorted alphabetically: first
    those named 'iso_*' or 'mpi', then all remaining ones, finally those
    named 'dftbp_*'. Statements with trailing content after the module name
    are delegated to _sort_rest(); the others are rebuilt with the module
    name in lower case.
    """

    def group_rank(module_name: str) -> int:
        # 0: intrinsic-like, 1: third party, 2: DFTB+ own modules
        if module_name.startswith("dftbp_"):
            return 2
        if module_name.startswith("iso_") or module_name == "mpi":
            return 0
        return 1

    ordered = sorted(modules, key=lambda module_name: (group_rank(module_name), module_name))

    statements = []
    for module_name in ordered:
        match = modules[module_name]
        if match["rest"]:
            statements.append(_sort_rest(match))
            continue
        statements.append(
            "{indent}use{attrib}{separator}{module}".format(
                indent=match["indent"],
                attrib=match["attrib"],
                separator=match["separator"],
                module=match["name"].lower(),
            )
        )
    return statements


def _sort_rest(fields: re.Match, max_line_length: int = 100) -> str:
    """Sorts the entities imported in the 'only' clause of a 'use' statement.

    Args:
        fields: Match of a 'use' statement providing the groups 'indent',
            'attrib', 'separator', 'name' and 'rest'.
        max_line_length: Length limit used when deciding whether the next
            entity still fits on the current line.

    Returns:
        The rebuilt statement with the module name in lower case. If 'rest'
        contains no 'only :' clause, it is appended unchanged. Otherwise the
        imported entities are sorted case-insensitively (renames
        'local => name' are sorted by the name right of '=>'), exact
        duplicates are dropped and the list is wrapped with '&' continuation
        lines indented by four extra blanks.
    """
    prefix = (
        f"{fields['indent']}use{fields['attrib']}"
        f"{fields['separator']}{fields['name'].lower()}"
    )
    rest = fields["rest"]

    # Fortran is case-insensitive, so 'ONLY :' must be recognised as well. A
    # mere substring 'only' (e.g. inside a renamed entity) is not a clause.
    only_clause = re.search(r"\bonly\s*:", rest, flags=re.IGNORECASE)
    if only_clause is None:
        return prefix + rest

    entities = []
    seen = set()
    for raw_item in rest[only_clause.end() :].split(","):
        # Drop continuation markers; the '&' opening the next line is optional.
        item = re.sub(r"&\s*&?", "", raw_item).strip()
        if not item:
            continue
        local, arrow, target = item.partition("=>")
        if arrow and "=>" not in target:
            entity = (target.strip(), local.strip())
        else:
            entity = (item, None)
        # Keep distinct renames of the same entity; drop only exact repeats.
        if entity not in seen:
            seen.add(entity)
            entities.append(entity)
    entities.sort(key=lambda entity: entity[0].lower())

    header = prefix + ", only :"
    continuation = fields["indent"] + " " * 4 + "&"
    pieces = [header]
    line_length = len(header)
    for name, local in entities:
        if local is not None:
            text = f" {local} => {name},"
        else:
            text = f" {name},"
        if line_length + len(text) <= max_line_length:
            line_length += len(text)
        else:
            # Does not fit any more: close this line and start a new one.
            pieces.append("&\n")
            text = continuation + text
            line_length = len(text)
        pieces.append(text)

    statement = "".join(pieces)
    # Every entity carries a trailing comma; remove the one after the last.
    return statement[:-1] if entities else statement


def _fatal_error(msg: str):
    """Reports an error on standard error and exits with status 1."""

    print(f"Error: {msg}", file=sys.stderr)
    raise SystemExit(1)


# Run the driver only when the file is executed as a script (not on import).
if __name__ == "__main__":
    main()
