# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import math
import re

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import clear_output, display
from ipywidgets import Dropdown, IntSlider, Text, interact, interact_manual, widgets
from plotly.subplots import make_subplots


def get_stats_cols(df):
    """Pick the names of the quartile, median and std-dev columns of ``df``.

    Different inputs label these statistics differently, so each statistic is
    looked up through a list of known names and falls back to its
    ``"... (approx)"`` name when none of them is present.

    Args:
        df: Object exposing a ``columns`` iterable of column names.

    Returns:
        A ``(q1, median, q3, std)`` tuple of column names. The fallback names
        are returned without checking that they exist in ``df``.
    """
    available = set(df.columns)

    def first_present(candidates, fallback):
        # Candidates are tried in the given order; the first hit wins.
        return next((name for name in candidates if name in available), fallback)

    return (
        first_present(("Q1",), "Q1 (approx)"),
        first_present(("Med", "Median"), "Median (approx)"),
        first_present(("Q3",), "Q3 (approx)"),
        first_present(("StdDev", "Std"), "Std (approx)"),
    )


def display_box(df, x=None, **layout_args):
    """Show a box plot built from precomputed statistics columns.

    Args:
        df: DataFrame holding ``"Min"`` and ``"Max"`` columns plus the
            quartile, median and std-dev columns resolved by
            ``get_stats_cols``.
        x: Values for the x axis. Defaults to ``df.index``.
        **layout_args: Forwarded to ``Figure.update_layout``.
    """
    q1_col, median_col, q3_col, std_col = get_stats_cols(df)

    box_trace = go.Box(
        x=df.index if x is None else x,
        q1=df[q1_col].tolist(),
        median=df[median_col].tolist(),
        q3=df[q3_col].tolist(),
        # The whiskers are pinned to the recorded extremes.
        lowerfence=df["Min"].tolist(),
        upperfence=df["Max"].tolist(),
        sd=df[std_col].tolist(),
    )

    figure = go.Figure()
    figure.add_trace(box_trace)
    figure.update_layout(**layout_args)
    figure.show()


def display_stats_scatter(df, x=None, **layout_args):
    """Show the Q1, median, Q3, Min and Max columns as scatter traces.

    Args:
        df: DataFrame holding ``"Min"`` and ``"Max"`` columns plus the
            quartile and median columns resolved by ``get_stats_cols``.
        x: Values for the x axis. Defaults to ``df.index``.
        **layout_args: Forwarded to ``Figure.update_layout``.
    """
    x_values = df.index if x is None else x
    q1_col, median_col, q3_col, _ = get_stats_cols(df)

    figure = go.Figure()
    # One trace per statistic, each named after its column.
    for column in (q1_col, median_col, q3_col, "Min", "Max"):
        figure.add_trace(
            go.Scatter(x=x_values, y=df[column].tolist(), name=column)
        )

    figure.update_layout(**layout_args)
    figure.show()


def display_table_per_rank(df):
    """Show the rows of ``df`` for one rank, chosen through a dropdown.

    An empty ``df`` is displayed as is, without any widget.

    Args:
        df: DataFrame with a ``"Rank"`` column. That column is dropped from
            the table that gets displayed.
    """
    if df.empty:
        display(df)
        return

    by_rank = df.groupby("Rank")

    def show_rank(name):
        display(by_rank.get_group(name).drop(columns=["Rank"]))

    rank_selector = Dropdown(
        description="rank",
        options=by_rank.groups.keys(),
        layout={"width": "max-content"},
    )

    interact(show_rank, name=rank_selector)


def display_stats_per_operation(
    df, x=None, box=True, scatter=True, table=True, **layout_args
):
    """Show per-rank statistics for one operation, chosen through a dropdown.

    An empty ``df`` is displayed as is, without any widget.

    Args:
        df: DataFrame with a ``"Rank"`` column and the statistics columns
            expected by ``display_box`` / ``display_stats_scatter``.
        x: Grouping key passed to ``df.groupby``. Defaults to ``df.index``.
        box: Whether to show the box plot.
        scatter: Whether to show the scatter plot.
        table: Whether to show the table, indexed by ``"Rank"``.
        **layout_args: Forwarded to the plotting helpers.
    """
    if df.empty:
        display(df)
        return

    grouped = df.groupby(df.index if x is None else x)

    def display_graphs(name):
        selected = grouped.get_group(name)
        if table:
            display(selected.reset_index(drop=True).set_index("Rank"))
        if box:
            display_box(selected, x=selected["Rank"], **layout_args)
        if scatter:
            display_stats_scatter(selected, x=selected["Rank"], **layout_args)

    names = list(grouped.groups.keys())
    if not names:
        return

    # A single operation needs no dropdown at all.
    if len(names) == 1:
        display_graphs(names[0])
        return

    # The plot is not displayed for the dropdown's default value, so the
    # dropdown starts on the second entry and is switched to the first one
    # right after it has been wired up.
    selector = Dropdown(
        options=names, layout={"width": "max-content"}, value=names[1]
    )
    interact(display_graphs, name=selector)
    selector.value = names[0]


def display_stats_per_operation_device(
    df, device_df, box=True, scatter=True, table=True, **layout_args
):
    """Show per-rank statistics for an operation, optionally for one device.

    Two dropdowns are created: one for the operation name and one for the
    device. The device entry ``"all"`` shows the rows of ``df``; any other
    device shows the rows of ``device_df`` for that (name, device) pair.

    An empty ``df`` is displayed as is, without any widget. When
    ``device_df`` has no groups, the names are taken from ``df`` and
    ``"all"`` is the only device offered.

    Args:
        df: DataFrame indexed by operation name, with a ``"Rank"`` column and
            the statistics columns expected by the plotting helpers.
        device_df: DataFrame whose index groups into (name, device) pairs,
            with the same columns as ``df``.
        box: Whether to show the box plot.
        scatter: Whether to show the scatter plot.
        table: Whether to show the table, indexed by ``"Rank"``.
        **layout_args: Forwarded to the plotting helpers.
    """
    if df.empty:
        display(df)
        return

    op_groups = df.groupby(df.index)

    if device_df.empty:
        op_device_groups = None
        pair_keys = []
    else:
        op_device_groups = device_df.groupby(device_df.index)
        pair_keys = list(op_device_groups.groups.keys())

    def display_graphs(name, device):
        if device == "all":
            display_df = op_groups.get_group(name)
        elif (
            op_device_groups is not None
            and (name, device) in op_device_groups.groups
        ):
            display_df = op_device_groups.get_group((name, device))
        else:
            # If the pair of name and device does not exist, silently return.
            # This could occur when we select a new name (e.g., CUDA
            # operation, network metric) and the selected device does not
            # have data for that name. The selected device is updated
            # afterwards, when the device dropdown gets its new options.
            return

        if table:
            display(display_df.reset_index(drop=True).set_index("Rank"))
        if box:
            display_box(display_df, x=display_df["Rank"], **layout_args)
        if scatter:
            display_stats_scatter(display_df, x=display_df["Rank"], **layout_args)

    if pair_keys:
        pair_names, pair_devices = zip(*pair_keys)
        # dict.fromkeys removes duplicates while keeping the order of the
        # group keys, so the dropdown order is the same on every run and the
        # two default selections come from the same (name, device) pair.
        operations = list(dict.fromkeys(pair_names))
        devices = list(dict.fromkeys(pair_devices)) + ["all"]
    else:
        operations = list(op_groups.groups.keys())
        devices = ["all"]

    name_dropdown = Dropdown(options=operations, layout={"width": "max-content"})
    device_dropdown = Dropdown(options=devices, layout={"width": "max-content"})

    # Update function for name dropdown menu.
    def update_devices(*args):
        # We cache the previous value and assign it back later because when
        # the device dropdown options get updated, the selected device also
        # gets updated. The selected device changes even if it exists both in
        # the previous and the updated device list.
        previous_device = device_dropdown.value

        # Restrict the device options to the devices that have data for the
        # selected operation/name (CUDA call, network metric, etc.).
        device_dropdown.options = [
            device_key
            for name_key, device_key in pair_keys
            if name_dropdown.value == name_key
        ] + ["all"]

        if previous_device in device_dropdown.options:
            device_dropdown.value = previous_device

    name_dropdown.observe(update_devices, "value")

    interact(display_graphs, name=name_dropdown, device=device_dropdown)


def display_summary_graph(df, value_col, **layout_args):
    """Show the mean of one or more columns per ``"Duration"`` as line traces.

    Args:
        df: DataFrame with a ``"Duration"`` column and the value column(s).
        value_col: A column name, or a list of column names. One line trace
            is added per column and named after it.
        **layout_args: Forwarded to ``Figure.update_layout``.
    """
    mean_by_duration = df.groupby("Duration")[value_col].mean().reset_index()
    columns = [value_col] if isinstance(value_col, str) else value_col

    figure = go.Figure()
    for column in columns:
        line = go.Scatter(
            x=mean_by_duration["Duration"].tolist(),
            y=mean_by_duration[column].tolist(),
            mode="lines",
            name=column,
        )
        figure.add_trace(line)

    figure.update_layout(**layout_args)
    figure.show()


def _get_heatmap_height(name_count, plot_count=1):
    """Return the figure height for ``plot_count`` stacked heatmaps.

    Each plot is given 27 units per name plus 110 units of padding, and is
    sized for at least 9 names.
    """
    min_rows = 9
    row_height = 27
    padding = 110

    single_plot_height = max(name_count, min_rows) * row_height + padding
    return single_plot_height * plot_count


def _split_numeric(input_string):
    """Split a string into alternating text and integer parts.

    Meant to be used as a sorting key so that strings are ordered by the
    numerical value of the numbers they contain (``"r9"`` before ``"r10"``).

    Args:
        input_string: The string to split.

    Returns:
        A list in which every run of decimal digits is an ``int`` and every
        other run of characters is a ``str``.
    """
    # Extract numeric and non-numeric parts.
    parts = re.findall(r"\d+|\D+", input_string)
    # isdecimal() is used rather than isdigit(): isdigit() is also true for
    # characters such as superscript digits, which "\d" does not match and
    # int() cannot convert.
    return [int(part) if part.isdecimal() else part for part in parts]


def display_heatmaps(
    df, types, xaxis_title, yaxis_title, zaxis_title, zmax=100, **layout_args
):
    """Show one heatmap per column in ``types``, stacked in a single column.

    Every heatmap uses ``"Duration"`` on x, ``"Name"`` on y and one of the
    ``types`` columns as z. The y categories are ordered with
    ``_split_numeric`` as the sorting key, in reverse.

    Args:
        df: DataFrame with ``"Duration"``, ``"Name"`` and the ``types``
            columns.
        types: Column names to plot; also used as the subplot titles.
        xaxis_title: Title of the x axes, also used in the hover text.
        yaxis_title: Title of the y axes, also used in the hover text.
        zaxis_title: Title of the color bar, also used in the hover text.
        zmax: Upper bound of the color scale.
        **layout_args: Forwarded to ``Figure.update_layout``.
    """
    name_count = df["Name"].nunique()
    total_height = _get_heatmap_height(name_count, len(types))
    hover_text = f"{xaxis_title}: %{{x}}<br>{yaxis_title}: %{{y}}<br>{zaxis_title}: %{{z}}<extra></extra>"

    figure = make_subplots(
        rows=len(types),
        cols=1,
        subplot_titles=types,
        vertical_spacing=150 / total_height,
    )

    for row_number, value_col in enumerate(types, start=1):
        heatmap = go.Heatmap(
            x=df["Duration"].tolist(),
            y=df["Name"].tolist(),
            z=df[value_col].tolist(),
            showscale=False,
            zmax=zmax,
            zauto=False,
        )
        figure.add_trace(heatmap, row=row_number, col=1)

    figure.update_layout(height=total_height, **layout_args)
    figure.update_xaxes(title=xaxis_title)
    figure.update_yaxes(
        title=yaxis_title,
        categoryorder="array",
        categoryarray=sorted(df["Name"], key=_split_numeric, reverse=True),
        nticks=name_count,
    )
    # The color scale is hidden on every trace above and switched back on
    # only for the traces selected here.
    # NOTE(review): row=0 is outside the 1-based rows used when the traces
    # were added - confirm which trace(s) this selects.
    figure.update_traces(
        {"colorbar": {"title_text": zaxis_title}},
        showscale=True,
        row=0,
    )
    figure.update_traces(hovertemplate=hover_text)
    figure.show()


def display_heatmap(
    df,
    value_col,
    xaxis_title,
    yaxis_title,
    zaxis_title,
    zmax=100,
    zaxis_tickformat=None,
    **layout_args,
):
    """Show a single heatmap of ``value_col`` over ``"Duration"`` and ``"Name"``.

    The y categories are ordered with ``_split_numeric`` as the sorting key,
    in reverse.

    Args:
        df: DataFrame with ``"Duration"``, ``"Name"`` and ``value_col``
            columns.
        value_col: Column used as z; also used as the subplot title.
        xaxis_title: Title of the x axis, also used in the hover text.
        yaxis_title: Title of the y axis, also used in the hover text.
        zaxis_title: Title of the color bar, also used in the hover text.
        zmax: Upper bound of the color scale.
        zaxis_tickformat: Tick format of the color bar.
        **layout_args: Forwarded to ``Figure.update_layout``.
    """
    name_count = df["Name"].nunique()
    hover_text = f"{xaxis_title}: %{{x}}<br>{yaxis_title}: %{{y}}<br>{zaxis_title}: %{{z}}<extra></extra>"

    heatmap = go.Heatmap(
        x=df["Duration"].tolist(),
        y=df["Name"].tolist(),
        z=df[value_col].tolist(),
        zmax=zmax,
        zauto=False,
        colorbar={"title": zaxis_title, "tickformat": zaxis_tickformat},
    )

    figure = make_subplots(1, 1, subplot_titles=[value_col], vertical_spacing=0.1)
    figure.add_trace(heatmap)

    figure.update_layout(
        height=_get_heatmap_height(name_count),
        **layout_args,
    )
    figure.update_xaxes(title=xaxis_title)
    figure.update_yaxes(
        title=yaxis_title,
        categoryorder="array",
        categoryarray=sorted(df["Name"], key=_split_numeric, reverse=True),
        nticks=name_count,
    )
    figure.update_traces(hovertemplate=hover_text)

    figure.show()


def freedman_diaconis(df, column):
    """Calculate the width of the bins to be used in a histogram.

    The width is ``2 * IQR / n ** (1 / 3)``, where ``IQR`` is the
    interquartile range of ``df[column]`` and ``n`` its number of values,
    truncated to an integer.

    Args:
        df: DataFrame holding the data.
        column: Name of the column to compute the width for.

    Returns:
        The bin width as an ``int``, never smaller than 1. A width of 1 is
        returned when the column is empty, when the interquartile range is
        zero or not a number, and when the computed width is below 1.
    """
    values = df[column]
    n = values.size
    if n == 0:
        return 1

    q1, q3 = np.quantile(values, [0.25, 0.75])
    width = 2 * ((q3 - q1) / n ** (1 / 3))

    # Callers divide by the returned width, so it must never truncate to 0.
    # The comparison is written this way so that a NaN width (NaN values in
    # the data) also takes this branch instead of failing in int().
    if not width >= 1:
        return 1

    return int(width)


def display_histogram(
    df,
    xvalue_col,
    yvalue_col,
    xaxis_title,
    yaxis_title,
    event_type,
    **layout_args,
):
    """Show a duration histogram with a name dropdown and a bin-count field.

    The rows of ``df`` (all of them, or those of the selected name) are
    counted per ``"Duration"`` value, and the counts are plotted as a
    probability-normalized histogram. The graph is (re)generated when the
    "Generate graph" button is pressed. Nothing is shown for an empty ``df``.

    Args:
        df: DataFrame with ``"Name"`` and ``"Duration"`` columns.
        xvalue_col: Column of the per-duration count table used as x.
        yvalue_col: Column of the per-duration count table used as y. The
            table has the columns ``"Duration"`` and ``"Count"``.
        xaxis_title: Title of the x axis, also used in the hover text.
        yaxis_title: Title of the y axis, also used in the hover text.
        event_type: Label used for the "all" dropdown entry and its title.
        **layout_args: Forwarded to ``Figure.update_layout``.
    """
    if df.empty:
        return

    name_groups = df.groupby("Name")
    names = [f"all {event_type}"] + list(name_groups.groups.keys())

    dropdown = Dropdown(
        options=names,
        layout={"width": "max-content"},
        description="Name:",
    )

    bins_input = Text(
        placeholder="Enter a custom number of bins",
        description="Bins:",
        layout={"width": "500px"},
        disabled=False,
    )

    output = widgets.Output()

    # Upper bound on the number of bins that is chosen automatically.
    max_bins = 1000

    def update_histogram(name, bins):
        with output:
            output.clear_output(wait=True)

        if name == f"all {event_type}":
            duration_df = df.groupby(["Duration"]).size().reset_index(name="Count")
            title = f"all {event_type} Duration Summary"
        else:
            duration_df = df.query("Name == @name")
            duration_df = (
                duration_df.groupby(["Duration"]).size().reset_index(name="Count")
            )
            title = f"{name} Duration Summary"

        data_max = duration_df["Duration"].max()
        data_min = duration_df["Duration"].min()

        # Plotly go requires a range of values to graph a bar, so if the
        # selection has one data point, we handle this case by slightly
        # offsetting the max and min values.
        if data_max == data_min:
            data_min = data_min - 0.01
            data_max = data_max + 0.01

        data_range = data_max - data_min

        if bins == "":
            # No custom value: derive the bin width from the data. The width
            # is used as a divisor, so it is kept at 1 or above.
            width = max(freedman_diaconis(duration_df, "Duration"), 1)
            bin_size = math.ceil(data_range / width)
            if bin_size > max_bins:
                # The width has to follow the capped bin count; otherwise the
                # plot would still contain more bins than the title reports.
                bin_size = max_bins
                width = math.ceil(data_range / bin_size)
        else:
            try:
                bin_size = int(bins)
                if bin_size <= 0:
                    raise ValueError
            except ValueError:
                # Reject the input and ask again through the text field.
                bins_input.value = ""
                bins_input.layout.border = "1px solid red"
                bins_input.placeholder = "Please enter a valid non-zero integer"
                return
            width = math.ceil(data_range / bin_size)

        # The input was accepted: restore the regular look of the text field.
        bins_input.layout.border = None
        bins_input.placeholder = "Enter a custom number of bins"

        fig = go.Figure(
            data=[
                go.Histogram(
                    x=duration_df[xvalue_col].tolist(),
                    y=duration_df[yvalue_col].tolist(),
                    autobinx=False,
                    xbins=dict(start=data_min, end=data_max, size=width),
                    histnorm="probability",
                    histfunc="sum",
                )
            ]
        )

        fig.update_layout(
            title_text=f"{title}, bins={bin_size}",
            xaxis_title_text=xaxis_title,
            yaxis_title_text=yaxis_title,
            **layout_args,
        )

        fig.update_traces(
            hovertemplate=f"{xaxis_title}: %{{x}}<br>{yaxis_title}: %{{y}}<extra></extra>"
        )
        fig.show()

    generate = interact_manual.options(manual_name="Generate graph")
    generate(update_histogram, name=dropdown, bins=bins_input)
    display(output)


def get_unique_index_values(grouped, preserve_order, partial_key=None):
    """Return the unique values of every level of the keys of ``grouped``.

    Args:
        grouped: Object whose ``groups`` mapping provides the group keys.
        preserve_order: If true, the unique values keep the order of their
            first appearance; otherwise they are sorted.
        partial_key: Optional sequence of leading key values. When given,
            only the keys starting with these values are considered.

    Returns:
        A list with one list of unique values per key level (a single inner
        list when the keys are not tuples), or an empty list when no key is
        left.
    """
    keys = list(grouped.groups.keys())

    if partial_key is not None:
        prefix_length = len(partial_key)
        # The group keys are tuples, so the prefix has to be one as well.
        prefix = tuple(partial_key)
        keys = [key for key in keys if key[:prefix_length] == prefix]

    if not keys:
        return []

    def dedupe(values):
        if preserve_order:
            # A dict keeps insertion order, i.e. the first appearances.
            return list(dict.fromkeys(values))
        return sorted(set(values))

    # Tuple keys are transposed so that each level is handled on its own.
    levels = zip(*keys) if isinstance(keys[0], tuple) else [keys]
    return [dedupe(level) for level in levels]


def _get_description_width(group_by_columns):
    """Approximate the pixel width needed for the longest column label.

    The length of the longest label (as a string) is multiplied by 7.5 to
    approximate its width in pixels, and 10 is added as padding.
    """
    longest_label = max(map(len, map(str, group_by_columns)))
    return (longest_label * 7.5) + 10


def _initialize_dropdowns_per_column(
    index_grouped, group_by_columns, preserve_order=False
):
    """Create one Dropdown per grouping column and link them together.

    Args:
        index_grouped: Grouped object handed to ``get_unique_index_values``
            to obtain the options of every dropdown.
        group_by_columns: Column names. Each one becomes the description of
            a dropdown and its key in the returned dict.
        preserve_order: Forwarded to ``get_unique_index_values``.

    Returns:
        A ``(dropdowns, description_width)`` tuple: the dict mapping each
        column name to its Dropdown, and the description width (in pixels)
        applied to every dropdown.
    """
    # When a dropdown value is updated, the options for subsequent dropdowns
    # are dynamically updated to ensure valid selections.
    def update_dropdowns(change):
        # The description of a dropdown is its column name, which gives the
        # position of the changed dropdown among the columns.
        changed_key = change["owner"].description
        current_index = group_by_columns.index(changed_key)

        # Current selections of all dropdowns up to and including the
        # changed one; used as the partial key for the lookup below.
        selected_values = [
            dropdowns[group_by_columns[i]].value for i in range(current_index + 1)
        ]

        unique_values = get_unique_index_values(
            index_grouped, preserve_order, selected_values
        )
        # This won't happen from user selection, but when we manually set the
        # dropdown value to something (ex. None) that is not part of the group
        # keys.
        if not unique_values:
            return

        # Refresh the options of every dropdown after the changed one.
        for i in range(current_index + 1, len(group_by_columns)):
            dropdown = dropdowns[group_by_columns[i]]
            dropdown.options = unique_values[i]

            # NOTE(review): change.new is the value just selected in the
            # *changed* dropdown, not the previous value of this dropdown -
            # confirm that carrying it over to the later dropdowns is
            # intended.
            if change.new in dropdown.options:
                dropdown.value = change.new

    # Initial options: the unique values of every level over all group keys.
    unique_option_values = get_unique_index_values(index_grouped, preserve_order)
    description_width = _get_description_width(group_by_columns)

    dropdowns = {
        group_by_columns[i]: Dropdown(
            options=unique_option_values[i],
            style={"description_width": f"{description_width}px"},
            layout={"width": "max-content"},
            description=group_by_columns[i],
        )
        for i in range(len(group_by_columns))
    }

    # The last dropdown has nothing after it to update, so it gets no
    # observer.
    for key in group_by_columns[:-1]:
        dropdowns[key].observe(update_dropdowns, "value")

    # We want to update the dropdowns based on the value of the first dropdown.
    # Since ipywidgets doesn't trigger the observer if the value is the same,
    # we temporarily set the value to None and then restore the original value.
    # NOTE(review): this only runs for a column literally named "Rank",
    # wherever it sits in group_by_columns - confirm it is meant to be the
    # first dropdown.
    if "Rank" in dropdowns:
        original_value = dropdowns["Rank"].value
        dropdowns["Rank"].value = None
        dropdowns["Rank"].value = original_value

    return dropdowns, description_width


def display_stats_per_column(
    df, group_by_columns, post_process_fn=None, preserve_order=False
):
    """Show the rows of ``df`` for one group, chosen through dropdowns.

    One dropdown is created per entry of ``group_by_columns``. An empty
    ``df`` is displayed as is, without any widget.

    Args:
        df: DataFrame to display. Its index is reset before it is re-indexed
            by ``group_by_columns``.
        group_by_columns: Names of the columns to group by.
        post_process_fn: Optional callable applied to the selected rows. Its
            result is displayed instead of the rows; a list result is
            displayed item by item.
        preserve_order: If true, the groups are not sorted and the dropdown
            options keep the order of first appearance.
    """
    if df.empty:
        display(df)
        return

    index_df = df.reset_index().set_index(group_by_columns)
    index_grouped = index_df.groupby(index_df.index.names, sort=(not preserve_order))

    def select_group(selected_values):
        # Returns the rows of the selected group, or None when the selected
        # combination of values has no group.
        if len(group_by_columns) > 1:
            key = tuple(selected_values[column] for column in group_by_columns)
            if key not in index_grouped.groups:
                return None
            return index_grouped.get_group(key)

        key = selected_values[group_by_columns[0]]
        try:
            # pandas >= 2.2 warns ("FutureWarning: When grouping with a
            # length-1 list-like, you will need to pass a length-1 tuple to
            # get_group in a future version of pandas.") unless the single
            # key is wrapped in a tuple.
            return index_grouped.get_group((key,))
        except KeyError:
            # pandas < 2.2 does not find the wrapped key, so the key is
            # passed directly.
            return index_grouped.get_group(key)

    def display_table(**selected_values):
        display_df = select_group(selected_values)
        if display_df is None:
            return

        display_df.reset_index(drop=True, inplace=True)

        if post_process_fn is None:
            to_display = [display_df]
        else:
            processed = post_process_fn(display_df)
            to_display = processed if isinstance(processed, list) else [processed]

        for item in to_display:
            display(item)

    dropdowns, _ = _initialize_dropdowns_per_column(
        index_grouped, group_by_columns, preserve_order
    )
    interact(display_table, **dropdowns)


def display_bar_chart(
    df, x, y, xaxis_title, yaxis_title, title, color="blue", categoryorder="trace"
):
    """Show a grouped bar chart of ``y`` per ``x``, drawn with a single color.

    Args:
        df: DataFrame holding the ``x`` and ``y`` columns.
        x: Column used for the x axis.
        y: Column used for the y axis.
        xaxis_title: Title of the x axis, also used in the hover text.
        yaxis_title: Title of the y axis, also used in the hover text.
        title: Title of the figure.
        color: Color of the bars.
        categoryorder: Category order applied to the x axis.
    """
    hover_text = f"{xaxis_title}: %{{x}}<br>{yaxis_title}: %{{y}}<extra></extra>"

    figure = px.histogram(
        df,
        x=x,
        y=y,
        barmode="group",
        color_discrete_sequence=[color],
    )

    figure.update_layout(title=title)
    figure.update_xaxes(title=xaxis_title, categoryorder=categoryorder)
    figure.update_yaxes(title=yaxis_title)
    figure.update_traces(hovertemplate=hover_text)

    figure.show()


def display_top_bottom_n(
    df, group_by_columns, x, y, xaxis_title, yaxis_title, **layout_args
):
    """Show bar charts of the N largest and N smallest ``y`` values of a group.

    One dropdown is created per entry of ``group_by_columns``, plus a slider
    for N. The upper bound of the slider follows the number of distinct ``x``
    values of the selected group, up to 10. An empty ``df`` is displayed as
    is, without any widget.

    Args:
        df: DataFrame to chart. Its index is reset before it is re-indexed by
            ``group_by_columns``.
        group_by_columns: Names of the columns to group by.
        x: Column used for the x axis of the charts.
        y: Column used for the y axis and for ranking the rows.
        xaxis_title: Title of the x axis.
        yaxis_title: Title of the y axis.
        **layout_args: Accepted for interface compatibility; not used by
            this function.
    """
    if df.empty:
        display(df)
        return

    # Upper bound of the N slider for groups with at least this many values.
    default_max_n = 10

    index_df = df.reset_index().set_index(group_by_columns)
    index_grouped = index_df.groupby(index_df.index.names)

    def display_charts(**selected_values):
        n = selected_values.pop("n")

        if len(group_by_columns) > 1:
            group_key = tuple(selected_values[key] for key in group_by_columns)
            if group_key not in index_grouped.groups:
                return
            display_df = index_grouped.get_group(group_key)
        else:
            group_key = selected_values[group_by_columns[0]]
            try:
                # pandas >= 2.2 warns unless a single key is wrapped in a
                # length-1 tuple.
                display_df = index_grouped.get_group((group_key,))
            except KeyError:
                # pandas < 2.2 does not find the wrapped key, so the key is
                # passed directly.
                display_df = index_grouped.get_group(group_key)

        num_values = display_df[x].nunique()

        # The slider bound follows the selected group in both directions, so
        # that a small group does not keep N capped for larger groups that
        # are selected afterwards. It never drops below the slider minimum.
        allowed_max = max(min(num_values, default_max_n), slider.min)
        if allowed_max < slider.max:
            slider.max = allowed_max
            # Lowering the bound may change the slider value, so anything
            # drawn in the meantime is cleared and the current value is read
            # back.
            clear_output(wait=True)
            n = slider.value
        elif allowed_max > slider.max:
            slider.max = allowed_max

        smallest_rows = display_df.nsmallest(n, y)
        largest_rows = display_df.nlargest(n, y)

        display_bar_chart(
            largest_rows,
            x,
            y,
            xaxis_title,
            yaxis_title,
            "Top Ranges",
            "blue",
            "total descending",
        )
        display_bar_chart(
            smallest_rows,
            x,
            y,
            xaxis_title,
            yaxis_title,
            "Bottom Ranges",
            "red",
            "total ascending",
        )

    dropdowns, description_width = _initialize_dropdowns_per_column(
        index_grouped, group_by_columns
    )
    slider = IntSlider(
        description="N",
        min=1,
        max=default_max_n,
        step=1,
        value=5,
        style={"description_width": f"{description_width}px"},
    )

    interact(display_charts, **dropdowns, n=slider)
