# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import pandas as pd

from nsys_recipe.lib import overlap, summary


def find_column_name(df, candidate_columns):
    """Return the first candidate column name that exists in the dataframe.

    Parameters
    ----------
    df : dataframe
        Dataframe whose columns are searched.
    candidate_columns : list of str
        Column names to look for, in order of preference.

    Returns
    -------
    column_name : str
        The first entry of 'candidate_columns' found in 'df.columns'.

    Raises
    ------
    ValueError
        If none of the candidate columns is present in the dataframe.
    """
    # A private sentinel distinguishes "nothing matched" from any value a
    # candidate could legitimately have.
    not_found = object()
    available_columns = df.columns
    match = next(
        (name for name in candidate_columns if name in available_columns),
        not_found,
    )

    if match is not_found:
        raise ValueError(
            f"None of the columns {candidate_columns} found in the dataframe."
        )

    return match


def get_sm_active_name(df):
    """Return the name of the "SM Active" metric column found in the dataframe.

    The known spellings of the column are tried in order through
    'find_column_name', which raises a ValueError if none is present.
    """
    sm_active_candidates = [
        "SM Active [Throughput %]",
        "SMs Active [Throughput %]",
        "SM Active",
        "SMs Active",
    ]
    return find_column_name(df, sm_active_candidates)


def get_sm_issue_name(df):
    """Return the name of the "SM Issue" metric column found in the dataframe.

    The known spellings of the column are tried in order through
    'find_column_name', which raises a ValueError if none is present.
    """
    sm_issue_candidates = ["SM Issue [Throughput %]", "SM Issue"]
    return find_column_name(df, sm_issue_candidates)


def get_tensor_active_name(df):
    """Return the name of the "Tensor Active" metric column found in the dataframe.

    The known spellings of the column are tried in order through
    'find_column_name', which raises a ValueError if none is present.
    """
    tensor_active_candidates = [
        "Tensor Active [Throughput %]",
        "Tensor Active",
        # For TU102/4/6 GPUs.
        "Tensor Active / FP16 Active",
        "Tensor Active / FP16 Active [Throughput %]",
    ]
    return find_column_name(df, tensor_active_candidates)


def get_unallocated_warps_name(df):
    """Return the name of the "Unallocated Warps in Active SMs" metric column.

    The known spellings of the column are tried in order through
    'find_column_name', which raises a ValueError if none is present.
    """
    unallocated_warps_candidates = [
        "Unallocated Warps in Active SMs [Throughput %]",
        "Unallocated Warps in Active SMs",
    ]
    return find_column_name(df, unallocated_warps_candidates)


def _get_top_n_ranges(range_df, group_by_column, longest_n):
    """Keep only the rows of the N groups with the largest total duration.

    Parameters
    ----------
    range_df : dataframe
        Dataframe containing the 'start' and 'end' columns of the ranges, as
        well as the column used for grouping.
    group_by_column : str
        Name of the column whose values define the groups.
    longest_n : int
        Number of groups to keep.

    Returns
    -------
    top_range_df : dataframe
        Rows of 'range_df' that belong to one of the selected groups. No
        column is added to or removed from the input.
    """
    # Sum the duration (end - start) of every row within each group. The
    # group keys end up as the index of the resulting series.
    with_duration_df = range_df.assign(duration=range_df["end"] - range_df["start"])
    total_by_group = with_duration_df.groupby(group_by_column)["duration"].sum()

    longest_groups = set(total_by_group.nlargest(longest_n).index)

    keep_mask = range_df[group_by_column].isin(longest_groups)
    return range_df[keep_mask]


def map_ranges_to_metrics(range_df, metrics_df):
    """Map each range to the metric samples located within it.

    Parameters
    ----------
    range_df : dataframe
        Dataframe containing the ranges.
    metrics_df : dataframe
        Dataframe containing the metric samples. It must have a 'timestamp'
        column.

    Returns
    -------
    The result of 'overlap.map_overlapping_ranges' called with
    'key_df="df1"' and 'fully_contained=True'. Callers in this file use it
    as a mapping from a 'range_df' index label to the 'metrics_df' index
    labels of the matching samples (TODO: confirm against the 'overlap'
    module).
    """
    # A metric sample is a point in time, whereas 'map_overlapping_ranges'
    # works on ranges. Each sample is therefore turned into a zero-length
    # range that starts and ends at its timestamp.
    sample_times = metrics_df["timestamp"]
    point_ranges_df = metrics_df.assign(start=sample_times, end=sample_times)

    # TODO(DTSP-18387): Ensure all ranges include GPU metrics data after
    # transitioning from binary inclusion to percentage overlap.
    return overlap.map_overlapping_ranges(
        range_df, point_ranges_df, key_df="df1", fully_contained=True
    )


def _collect_range_metric_indices(range_id_df, range_metrics_map):
    """Flatten the range-to-metrics mapping into three parallel lists.

    Parameters
    ----------
    range_id_df : dataframe
        Dataframe containing the ranges of a single GPU, with a 'name'
        column.
    range_metrics_map : mapping
        Mapping from a range index label to the index labels of the metric
        samples associated with that range.

    Returns
    -------
    metrics_indices : list
        Index labels of the metric samples, one entry per (range, sample)
        pair.
    range_indices : list
        Index label of the range owning the sample at the same position.
    names : list
        Name of the range owning the sample at the same position.
    """
    metrics_indices = []
    range_indices = []
    names = []

    # Ranges are visited name by name (in groupby order), then in index order
    # within each name.
    for name, group in range_id_df.groupby("name"):
        for index in group.index.tolist():
            # Ranges without any associated metric sample are absent from
            # the map and contribute nothing.
            if index not in range_metrics_map:
                continue

            current_metrics_indices = range_metrics_map[index]
            sample_count = len(current_metrics_indices)

            metrics_indices.extend(current_metrics_indices)
            range_indices.extend([index] * sample_count)
            names.extend([name] * sample_count)

    return metrics_indices, range_indices, names


def calculate_stats_by_range(metrics_df, range_df, longest_n=None):
    """Calculate metrics statistics for each unique name of ranges.

    The data is processed one GPU at a time. GPUs that have metrics but no
    ranges are skipped. For each remaining GPU, the metric samples are
    mapped to the ranges with 'map_ranges_to_metrics', statistics are
    computed for each range instance, and the per-instance statistics are
    finally aggregated by range name.

    Parameters
    ----------
    metrics_df : dataframe
        Dataframe containing the timestamps and the values of the metrics.
        Besides the metric columns, it must have the 'gpuId', 'timestamp'
        and 'typeId' columns.
    range_df : dataframe
        Dataframe containing the start/end timestamps and the names of the
        ranges, for which statistics will be calculated. It must have the
        'gpuId' and 'name' columns.
    longest_n : int or None, optional
        Number of range names to be considered for each GPU, selected by
        the largest total duration of all the ranges sharing the name. If
        None, all ranges will be considered.

    Returns
    -------
    stats_df : dataframe or None
        Dataframe containing the statistics, or None if no data is available.
    """
    dfs = []
    range_grouped = range_df.groupby("gpuId")

    for gpu_id, metrics_id_df in metrics_df.groupby("gpuId"):
        if gpu_id not in range_grouped.groups:
            continue

        range_id_df = range_grouped.get_group(gpu_id)
        if longest_n is not None:
            range_id_df = _get_top_n_ranges(range_id_df, "name", longest_n)

        range_metrics_map = map_ranges_to_metrics(range_id_df, metrics_id_df)
        metrics_indices, range_indices, names = _collect_range_metric_indices(
            range_id_df, range_metrics_map
        )

        # A metric sample appears once for every range it was mapped to.
        contained_metrics_df = metrics_id_df.loc[metrics_indices]
        # We remove all columns that are not metric names.
        contained_metrics_df = contained_metrics_df.drop(
            columns=["timestamp", "typeId", "gpuId"]
        )
        contained_metrics_df["Group"] = range_indices
        contained_metrics_df["Name"] = names
        contained_metrics_grouped = contained_metrics_df.groupby(["Group", "Name"])

        # Generate statistics for all metrics based on the data grouped by
        # the same range instances.
        stats_df = summary.describe_columns(contained_metrics_grouped, "Metric Name")
        if stats_df is None:
            continue

        # The range instance identifier is only needed to compute the
        # per-instance statistics and is not part of the final index.
        stats_df = stats_df.reset_index().drop(columns=["Group"])
        stats_df["GPU ID"] = gpu_id
        stats_df = stats_df.set_index(["Name", "GPU ID", "Metric Name"])

        dfs.append(stats_df)

    if not dfs:
        return None

    stats_df = pd.concat(dfs)
    # Aggregate the statistics from each range instance to get the final
    # statistics for each unique range name.
    return summary.aggregate_stats_df(stats_df)


def calculate_stats(metrics_df):
    """Calculate metrics statistics for each GPU.

    Parameters
    ----------
    metrics_df : dataframe
        Dataframe containing the timestamps and the values of the metrics.
        Besides the metric columns, it must have the 'gpuId', 'timestamp'
        and 'typeId' columns.

    Returns
    -------
    stats_df : dataframe or None
        Dataframe containing the statistics rounded to one decimal and
        indexed by 'GPU ID' and 'Metric Name', or None if no data is
        available.
    """
    # Columns describing the sample rather than holding a metric value.
    non_metric_columns = ["timestamp", "typeId", "gpuId"]
    per_gpu_stats = []

    for gpu_id, gpu_metrics_df in metrics_df.groupby("gpuId"):
        values_df = gpu_metrics_df.drop(columns=non_metric_columns)

        gpu_stats_df = summary.describe_df(values_df, "Metric Name").reset_index()
        gpu_stats_df["GPU ID"] = gpu_id

        per_gpu_stats.append(gpu_stats_df.set_index(["GPU ID", "Metric Name"]))

    if per_gpu_stats:
        return pd.concat(per_gpu_stats).round(1)

    return None
