#!/usr/bin/env python3

"""Simple script to check usage of imports in Fortran files"""

import argparse
import fnmatch
import os
import re


# Help text shown by the command line parser (kept verbatim by the raw-text help formatter).
_DESCRIPTION = """Checks the imports in DFTB+ Fortran file(s) for usage
WARNING: This script uses a simple heuristic to check import usage. It only checks whether the
names of imported functions/methods are present in the file or not."""

# Matches a Fortran "use" statement that starts a line (after optional blanks), together with
# its continuation lines (a trailing "&", optionally followed by a leading "&" on the next line).
# Named groups: indent, attrib (optional ", intrinsic"), separator, name (the module name) and
# rest (everything between the module name and the final newline, e.g. the "only" list).
# The separator must be a "::" or at least one blank: if it were allowed to be empty, ordinary
# identifiers that merely start with "use" (e.g. "useful = 1") would be taken for imports.
_PAT_USE_MODULE = re.compile(
    r"""^(?P<indent>[ ]*)use
    (?P<attrib>(?:\s*,\s*intrinsic)?)
    (?P<separator>[ ]*::[ ]*|[ ]+)
    (?P<name>\w+)
    (?P<rest>.*?(?:&[ ]*\n(?:[ ]*&)?.*?)*)\n
    """,
    re.VERBOSE | re.MULTILINE | re.IGNORECASE,
)

# Matches every maximal run of word characters (letters, digits, underscore); used to split
# the file text into the set of words that imported names are looked up in.
WORD_PATTERN = re.compile(r"""\w+""")

# Substrings that are treated as the start of a comment when comments are stripped from a line.
# NOTE(review): a "!" directly followed by other text (e.g. "!comment") is not in this list.
COMMENT_SYMBOLS = ["!!", "!>", "! "]


def main():
    """Main script driver.

    Gathers the Fortran files found below the ``--folders`` arguments and the files given
    explicitly (``--files`` and positional arguments), then prints for every file either the
    set of imported names that never occur outside of the ``use`` statements or "<file> OK".
    """
    args = parse_arguments()
    filenames = []
    if args.folders:
        filenames = get_files(args.folders)
    if args.files:
        filenames += args.files
    if args.file:
        filenames += args.file

    for fname in filenames:
        # Only the read needs the file handle; close it before analysing the text
        with open(fname, "r", encoding="utf-8") as file:
            file_txt = file.read()

        unused_imports = _find_unused_imports(file_txt, args.case_sensitive)
        if unused_imports:
            print(f"{fname} contains the following unused imports: ", unused_imports)
        else:
            print(f"{fname} OK")


def _find_unused_imports(file_txt, case_sensitive):
    """Returns the set of imported names that do not occur in the remaining file text."""
    matches = list(_PAT_USE_MODULE.finditer(file_txt))
    rest_list = [match["rest"] for match in matches if match["rest"]]

    # Cut the use statements out back to front, so that the offsets of earlier matches stay valid
    for match in reversed(matches):
        file_txt = file_txt[:match.start()] + file_txt[match.end():]
    # Drop everything before the first import; a file without imports is left untouched
    # (the loop variable must not be used here: it is undefined/stale if there was no match)
    if matches:
        file_txt = file_txt[matches[0].start():]

    file_txt = _strip_comments(file_txt)

    words = (found.group() for found in WORD_PATTERN.finditer(file_txt))
    if case_sensitive:
        word_set = set(words)
    else:
        word_set = {word.lower() for word in words}

    return get_import_set(rest_list, case_sensitive) - word_set


def _strip_comments(text):
    """Cuts every line at the earliest occurrence of any of the comment symbols."""
    kept = []
    for line in text.splitlines():
        # Use the left-most symbol, not the first one in list order, so that nothing of a
        # comment like "! text !! more" survives
        starts = [pos for pos in map(line.find, COMMENT_SYMBOLS) if pos >= 0]
        kept.append(line[:min(starts)] if starts else line)
    return "".join(f"{line}\n" for line in kept)


def parse_arguments():
    """Builds the command line parser and returns the parsed arguments.

    Files can be passed positionally or via ``--files``, directories via ``--folders``;
    ``-c`` switches on case sensitive checking. The parser exits with an error message if
    neither a file nor a folder was given.
    """
    parser = argparse.ArgumentParser(
        description=_DESCRIPTION, formatter_class=argparse.RawTextHelpFormatter
    )

    files_help = "File(s) to process"
    parser.add_argument("file", nargs="*", metavar="FILE", default=None, help=files_help)
    parser.add_argument("--files", nargs="+", metavar="FILE", help=files_help)
    parser.add_argument("--folders", nargs="+", metavar="FOLDER", help="Folder(s) to process")
    parser.add_argument(
        "-c", dest="case_sensitive", action="store_true", help="Case sensitive checking"
    )

    parsed = parser.parse_args()
    # At least one of the three input sources has to be non-empty
    if not (parsed.file or parsed.folders or parsed.files):
        parser.error("No Files/Folders specified!")
    return parsed


def get_files(folders):
    """Recursively collects the paths of all files ending in '.f90' or '.F90' below folders"""
    pattern = "*.[fF]90"
    found = []
    for top in folders:
        for dirpath, _, names in os.walk(top):
            found.extend(
                os.path.join(dirpath, name) for name in names if fnmatch.fnmatch(name, pattern)
            )
    return found


def get_import_set(rest_list, case_sensitive):
    """Creates set of imported methods and functions

    Args:
        rest_list: Strings holding the part of each ``use`` statement that follows the module
            name (may span several lines joined by ``&`` continuations).
        case_sensitive: If false, all names are returned in lower case.

    Returns:
        Set with the local names of all entities listed in the ``only`` clauses. Statements
        without an ``only`` clause contribute nothing; ``operator(...)`` and
        ``assignment(...)`` entries are skipped.
    """
    import_set = set()
    for rest in rest_list:
        # A use statement contains no string literals, so any "!" starts a comment
        rest = re.sub(r"!.*", "", rest)
        # Fortran keywords are case insensitive; no match simply means no only-list
        parts = re.split(r"only\s*:", rest, maxsplit=1, flags=re.IGNORECASE)
        if len(parts) != 2:
            continue
        for imp in parts[1].split(","):
            # Drop line continuations, with or without the optional leading "&" on the next line
            imp = re.sub(r"&\s*&?", "", imp).strip()
            # Only real operator/assignment specifiers are skipped, not names containing the words
            if re.match(r"(?:operator|assignment)\s*\(", imp, re.IGNORECASE):
                continue
            # For "local => remote" the local name is the one used in the file
            name = imp.split("=>")[0].strip()
            if not name:
                continue
            import_set.add(name if case_sensitive else name.lower())
    return import_set


# Run the checker only when the file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()
