# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import contextlib
from enum import Enum

import pandas as pd

from nsys_recipe.lib import cuda, data_utils, gpu_metrics, nvtx


class CompositeTable(Enum):
    """Identifiers for the composite tables this module knows how to build.

    Each member is a key into the dispatch maps used by
    `get_table_column_dict` (which source tables/columns to read) and
    `get_refine_func` (which function combines those tables into a single
    dataframe).
    """

    # Kernel, memcpy and memset activity (plus graph trace, if present)
    # concatenated into one table.
    CUDA_GPU = 0
    # Like CUDA_GPU, but also reading graph node ids and the graph trace table.
    CUDA_GPU_GRAPH = 1
    # GPU activity combined with CUDA runtime API calls, with resolved names.
    CUDA_COMBINED = 2
    # Kernel activity (with resolved name strings) combined with runtime calls.
    CUDA_COMBINED_KERNEL = 3
    # Kernel activity with "*Name" ids replaced by their strings.
    CUDA_KERNEL = 4
    # NVTX events with "text"/"textId" merged into one text field.
    NVTX = 5
    # NVTX events restricted to the "NCCL" domain.
    NCCL = 6
    # NIC network metrics annotated with NIC and metric names.
    NIC = 7
    # IB switch network metrics annotated with metric names.
    IB_SWITCH = 8
    # MPI event tables concatenated and sorted by start time.
    MPI = 9
    # UCX (UCP_*) event tables concatenated and sorted by start time.
    UCX = 10
    # GPU metrics pivoted to one column per metric name.
    GPU_METRICS = 11
    # SoC/CPU raw perf events annotated with component and event names.
    PERF_EVENTS = 12


def get_cuda_gpu_dict():
    """Return the tables and columns needed to build the CUDA GPU table.

    The kernel, memcpy and memset activity tables share the same column set.
    Every call builds a new dictionary and every key gets its own new list,
    so callers may extend one table's columns in place without affecting the
    others (see `get_cuda_gpu_graph_dict`).
    """
    shared_columns = (
        "correlationId",
        "end",
        "globalPid",
        "start",
        "deviceId",
        "contextId",
        "greenContextId",
        "streamId",
    )
    activity_tables = (
        "CUPTI_ACTIVITY_KIND_KERNEL",
        "CUPTI_ACTIVITY_KIND_MEMCPY",
        "CUPTI_ACTIVITY_KIND_MEMSET",
    )
    # list(...) per key: the lists must not be shared between tables.
    return {table: list(shared_columns) for table in activity_tables}


def process_to_cuda_gpu_table(df_dict):
    """Generate CUDA GPU table.

    Concatenates whichever of the kernel, memcpy, memset and graph trace
    dataframes are present in `df_dict` (in that order) into a single table
    with a fresh 0..n-1 index.

    Parameters
    ----------
    df_dict : dict
        Dictionary mapping table names to dataframes. At least one of the
        following must be present:
        - "CUPTI_ACTIVITY_KIND_KERNEL"
        - "CUPTI_ACTIVITY_KIND_MEMCPY"
        - "CUPTI_ACTIVITY_KIND_MEMSET"
        - "CUPTI_ACTIVITY_KIND_GRAPH_TRACE"

    Raises
    ------
    ValueError
        If none of the tables above is present.
    """
    present = []
    for table_name in (
        "CUPTI_ACTIVITY_KIND_KERNEL",
        "CUPTI_ACTIVITY_KIND_MEMCPY",
        "CUPTI_ACTIVITY_KIND_MEMSET",
        "CUPTI_ACTIVITY_KIND_GRAPH_TRACE",
    ):
        if table_name in df_dict:
            present.append(df_dict[table_name])

    if len(present) == 0:
        raise ValueError("No CUDA GPU tables found in the df dictionary.")
    return pd.concat(present, ignore_index=True)


def get_cuda_gpu_graph_dict():
    """Return the CUDA GPU table columns extended with CUDA graph information.

    Starts from `get_cuda_gpu_dict()`, adds "graphNodeId" to the kernel,
    memcpy and memset column lists, and adds the
    "CUPTI_ACTIVITY_KIND_GRAPH_TRACE" table (which carries "graphId").
    """
    column_dict = get_cuda_gpu_dict()

    graph_node_tables = (
        "CUPTI_ACTIVITY_KIND_KERNEL",
        "CUPTI_ACTIVITY_KIND_MEMCPY",
        "CUPTI_ACTIVITY_KIND_MEMSET",
    )
    for activity_table in graph_node_tables:
        column_dict[activity_table].append("graphNodeId")

    graph_trace_columns = [
        "correlationId",
        "end",
        "globalPid",
        "start",
        "deviceId",
        "contextId",
        "greenContextId",
        "streamId",
        "graphId",
    ]
    column_dict["CUPTI_ACTIVITY_KIND_GRAPH_TRACE"] = graph_trace_columns

    return column_dict


def get_cuda_combined_dict():
    """Return the tables and columns needed to build the combined CUDA table.

    Extends `get_cuda_gpu_dict()` with the CUDA runtime API table and the
    string table used to resolve the runtime "nameId" column.
    """
    combined = get_cuda_gpu_dict()
    combined["CUPTI_ACTIVITY_KIND_RUNTIME"] = [
        "correlationId",
        "end",
        "globalTid",
        "start",
        "nameId",
    ]
    combined["StringIds"] = ["id", "value"]
    return combined


def process_to_cuda_combined_table(df_dict):
    """Generate CUDA table containing runtime and GPU data.

    Steps:
    1. Build the GPU table with `process_to_cuda_gpu_table`.
    2. Combine it with the runtime table via `cuda.combine_runtime_gpu_dfs`.
       Per this module's contract, the GPU "start"/"end" columns come back
       as "gpu_start"/"gpu_end".
    3. Replace the runtime "nameId" column with its string value, stored in
       a column named "name".

    Parameters
    ----------
    df_dict : dict
        Dictionary mapping table names to dataframes. At least one of:
        - "CUPTI_ACTIVITY_KIND_KERNEL": ["start", "end", "globalPid"]
        - "CUPTI_ACTIVITY_KIND_MEMCPY": ["start", "end", "globalPid"]
        - "CUPTI_ACTIVITY_KIND_MEMSET": ["start", "end", "globalPid"]
        - "CUPTI_ACTIVITY_KIND_GRAPH_TRACE": ["start", "end", "globalPid"]
        and additionally:
        - "CUPTI_ACTIVITY_KIND_RUNTIME": ["start", "end", "globalTid", "nameId"]
        - "StringIds": ["id", "value"]
    """
    # The GPU table is built first, so a dict with no GPU tables fails with
    # the ValueError from process_to_cuda_gpu_table before the runtime
    # table is looked up.
    gpu_activity_df = process_to_cuda_gpu_table(df_dict)
    runtime_df = df_dict["CUPTI_ACTIVITY_KIND_RUNTIME"]
    runtime_gpu_df = cuda.combine_runtime_gpu_dfs(runtime_df, gpu_activity_df)

    string_ids_df = df_dict["StringIds"]
    return data_utils.replace_id_with_value(
        runtime_gpu_df, string_ids_df, "nameId", "name"
    )


def get_cuda_combined_kernel_dict():
    """Return the tables and columns for the combined runtime + kernel table.

    The kernel "*Name" columns hold string ids, so "StringIds" is included
    to resolve them.
    """
    runtime_columns = ["correlationId", "end", "globalTid", "start"]
    kernel_columns = [
        "correlationId",
        "globalPid",
        "start",
        "end",
        "deviceId",
        "shortName",
        "mangledName",
        "demangledName",
    ]
    string_columns = ["id", "value"]

    return {
        "CUPTI_ACTIVITY_KIND_RUNTIME": runtime_columns,
        "CUPTI_ACTIVITY_KIND_KERNEL": kernel_columns,
        "StringIds": string_columns,
    }


def process_to_cuda_combined_kernel_table(df_dict):
    """Generate refined CUDA table containing runtime and kernel data.

    Steps:
    1. Resolve the kernel "*Name" string ids via
       `process_to_cuda_kernel_table`.
    2. Combine the result with the runtime table via
       `cuda.combine_runtime_gpu_dfs`. Per this module's contract, the
       kernel "start"/"end" columns come back as "gpu_start"/"gpu_end".

    Parameters
    ----------
    df_dict : dict
        Dictionary mapping table names to dataframes. It must contain:
        - "CUPTI_ACTIVITY_KIND_KERNEL": ["start", "end", "globalPid"]
            plus any of ["shortName", "mangledName", "demangledName"]
        - "CUPTI_ACTIVITY_KIND_RUNTIME": ["start", "end", "globalTid"]
        - "StringIds": ["id", "value"]
    """
    refined_kernel_df = process_to_cuda_kernel_table(df_dict)
    runtime_df = df_dict["CUPTI_ACTIVITY_KIND_RUNTIME"]
    return cuda.combine_runtime_gpu_dfs(runtime_df, refined_kernel_df)


def get_cuda_kernel_dict():
    """Return the tables and columns needed to build the CUDA kernel table."""
    kernel_columns = [
        "correlationId",
        "globalPid",
        "start",
        "end",
        "deviceId",
    ]
    # String-id columns that process_to_cuda_kernel_table resolves.
    kernel_columns += ["shortName", "mangledName", "demangledName"]

    return dict(
        CUPTI_ACTIVITY_KIND_KERNEL=kernel_columns,
        StringIds=["id", "value"],
    )


def process_to_cuda_kernel_table(df_dict):
    """Generate refined CUDA GPU kernel table.

    Replaces each of the "shortName", "mangledName" and "demangledName"
    columns that is present in the kernel table with its string value from
    "StringIds". Missing name columns are skipped. If none are present, the
    kernel dataframe is returned as-is.

    Parameters
    ----------
    df_dict : dict
        Dictionary mapping table names to dataframes. It must contain:
        - "CUPTI_ACTIVITY_KIND_KERNEL": ["start", "end", "globalPid"]
            plus any of ["shortName", "mangledName", "demangledName"]
        - "StringIds": ["id", "value"]
    """
    kernel_df = df_dict["CUPTI_ACTIVITY_KIND_KERNEL"]
    string_ids_df = df_dict["StringIds"]

    # Choose the columns to resolve up front, from the columns the kernel
    # table was read with.
    available = set(kernel_df.columns)
    to_resolve = [
        column
        for column in ("shortName", "mangledName", "demangledName")
        if column in available
    ]

    refined_df = kernel_df
    for column in to_resolve:
        refined_df = data_utils.replace_id_with_value(
            refined_df, string_ids_df, column
        )
    return refined_df


def get_nvtx_dict():
    """Return the tables and columns needed to build the NVTX and NCCL tables.

    "domainId" and "eventType" are read so that the NCCL refinement can
    filter events by domain.
    """
    event_columns = [
        "text",
        "start",
        "end",
        "textId",
        "globalTid",
        "endGlobalTid",
        "domainId",
        "eventType",
    ]
    return {"NVTX_EVENTS": event_columns, "StringIds": ["id", "value"]}


def process_to_nvtx_table(df_dict):
    """Generate refined NVTX table.

    Delegates to `nvtx.combine_text_fields`, which merges the "text" and
    "textId" fields into one "text" field, resolving string ids through
    "StringIds".

    Parameters
    ----------
    df_dict : dict
        Dictionary mapping table names to dataframes. It must contain:
        - "NVTX_EVENTS": ["textId", "text"]
        - "StringIds": ["id", "value"]
    """
    events_df = df_dict["NVTX_EVENTS"]
    string_ids_df = df_dict["StringIds"]
    return nvtx.combine_text_fields(events_df, string_ids_df)


def process_to_nccl_table(df_dict):
    """Generate refined NCCL table.

    Keeps only the NVTX events of the "NCCL" domain, then merges their
    "text" and "textId" fields into one "text" field, resolving string ids
    through "StringIds".

    Parameters
    ----------
    df_dict : dict
        Dictionary mapping table names to dataframes. It must contain:
        - "NVTX_EVENTS": ["textId", "text", "eventType", "domainId"]
        - "StringIds": ["id", "value"]
    """
    all_events_df = df_dict["NVTX_EVENTS"]
    nccl_events_df = nvtx.filter_by_domain_name(all_events_df, "NCCL")
    string_ids_df = df_dict["StringIds"]
    return nvtx.combine_text_fields(nccl_events_df, string_ids_df)


def get_nic_dict():
    """Return the tables and columns needed to build the NIC metrics table."""
    metric_columns = [
        "start",
        "end",
        "globalId",
        "metricsListId",
        "metricsIdx",
        "value",
    ]
    return dict(
        NIC_ID_MAP=["globalId", "nicId"],
        NET_NIC_METRIC=metric_columns,
        TARGET_INFO_NETWORK_METRICS=["metricsListId", "metricsIdx", "name"],
        TARGET_INFO_NIC_INFO=["GUID", "nicId", "name"],
    )


def process_to_nic_table(df_dict):
    """Generate refined Network Interface Controller (NIC) table.

    This function:
    - joins "NET_NIC_METRIC" with "NIC_ID_MAP" on "globalId" to attach the
      "nicId" of every metric sample;
    - joins the result with "TARGET_INFO_NIC_INFO" on "nicId" to attach the
      NIC "GUID" (formatted as a hexadecimal string) and the NIC name, which
      is renamed to "nic_name";
    - merges with "TARGET_INFO_NETWORK_METRICS" on ["metricsListId",
      "metricsIdx"] to attach each metric's name, renamed to "metric_name".

    No rows are filtered out based on their metric value. Samples whose NIC
    has no entry in "TARGET_INFO_NIC_INFO" are kept, with missing (NaN)
    "GUID" and "nic_name". Samples whose metric is not listed in
    "TARGET_INFO_NETWORK_METRICS" are dropped by the inner merge.

    Parameters
    ----------
    df_dict : dict
        Dictionary mapping table names to dataframes. It must contain:
        - "NIC_ID_MAP": ["globalId", "nicId"]
        - "NET_NIC_METRIC": ["metricsListId", "metricsIdx", "globalId", "value"]
        - "TARGET_INFO_NETWORK_METRICS": ["metricsListId", "metricsIdx", "name"]
        - "TARGET_INFO_NIC_INFO": ["nicId", "GUID", "name"]

    Returns
    -------
    pandas.DataFrame
        The "NET_NIC_METRIC" columns plus "nicId", "GUID", "nic_name" and
        "metric_name". The input dataframes are not modified.
    """
    nic_df = df_dict["NET_NIC_METRIC"].join(
        df_dict["NIC_ID_MAP"].set_index("globalId"), on="globalId"
    )

    # Format GUIDs on the lookup table, before the left join. Unmatched rows
    # would otherwise turn the joined column into float64 (NaN), which loses
    # precision for integers above 2**53 and makes `hex` raise TypeError.
    nic_info_df = (
        df_dict["TARGET_INFO_NIC_INFO"]
        .set_index("nicId")
        .assign(GUID=lambda df: df["GUID"].apply(hex))
    )

    nic_metric_with_nic_info_df = nic_df.join(nic_info_df, on="nicId").rename(
        columns={"name": "nic_name"}
    )

    nic_metric_df = nic_metric_with_nic_info_df.merge(
        df_dict["TARGET_INFO_NETWORK_METRICS"], on=["metricsListId", "metricsIdx"]
    )
    return nic_metric_df.rename(columns={"name": "metric_name"})


def get_ib_switch_dict():
    """Return the tables and columns needed to build the IB switch table."""
    switch_metric_columns = [
        "start",
        "end",
        "globalId",
        "metricsListId",
        "metricsIdx",
        "value",
    ]
    metric_name_columns = ["metricsListId", "metricsIdx", "name"]
    return dict(
        NET_IB_SWITCH_METRIC=switch_metric_columns,
        TARGET_INFO_NETWORK_METRICS=metric_name_columns,
    )


def process_to_ib_switch_table(df_dict):
    """Generate refined IB Switch table.

    This function:
    - merges "NET_IB_SWITCH_METRIC" with "TARGET_INFO_NETWORK_METRICS" on
      ["metricsListId", "metricsIdx"] to attach each metric's name (rows
      whose metric is not listed are dropped by the inner merge);
    - formats "globalId" as a hexadecimal string and renames it to "GUID";
    - renames the metric "name" column to "metric_name".

    No rows are filtered out based on their metric value; zero-valued
    samples are kept.

    Parameters
    ----------
    df_dict : dict
        Dictionary mapping table names to dataframes. It must contain:
        - "NET_IB_SWITCH_METRIC": ["metricsListId", "metricsIdx", "globalId", "value"]
        - "TARGET_INFO_NETWORK_METRICS": ["metricsListId", "metricsIdx", "name"]

    Returns
    -------
    pandas.DataFrame
        The switch metric columns with "globalId" replaced by "GUID", plus
        "metric_name". The input dataframes are not modified.
    """
    ib_switch_df = df_dict["NET_IB_SWITCH_METRIC"].merge(
        df_dict["TARGET_INFO_NETWORK_METRICS"], on=["metricsListId", "metricsIdx"]
    )
    # `merge` returns a new frame, so this assignment does not touch df_dict.
    ib_switch_df["globalId"] = ib_switch_df["globalId"].apply(hex)

    return ib_switch_df.rename(columns={"globalId": "GUID", "name": "metric_name"})


def get_mpi_dict():
    """Return the tables and columns needed to build the MPI table.

    All MPI event tables are read with the same columns; each table gets its
    own new list.
    """
    event_columns = ("globalTid", "start", "end", "textId")
    event_tables = (
        "MPI_P2P_EVENTS",
        "MPI_START_WAIT_EVENTS",
        "MPI_OTHER_EVENTS",
        "MPI_COLLECTIVES_EVENTS",
    )
    column_dict = {table: list(event_columns) for table in event_tables}
    column_dict["StringIds"] = ["id", "value"]
    return column_dict


def process_to_mpi_table(df_dict):
    """Generate refined MPI table.

    This function combines the MPI event tables present in `df_dict` into
    one dataframe, sorted by "start" and re-indexed 0..n-1. Any integer
    representations in the 'textId' field are replaced with their
    corresponding string values in the 'text' field.

    Parameters
    ----------
    df_dict : dict
        Dictionary mapping table names to dataframes. It must contain at
        least one of the following (missing ones are skipped):
        - "MPI_P2P_EVENTS": ["start", "textId"]
        - "MPI_START_WAIT_EVENTS": ["start", "textId"]
        - "MPI_OTHER_EVENTS": ["start", "textId"]
        - "MPI_COLLECTIVES_EVENTS": ["start", "textId"]
        It must also contain:
        - "StringIds": ["id", "value"]

    Raises
    ------
    ValueError
        If none of the MPI event tables is present (same convention as
        `process_to_cuda_gpu_table`).
    """
    mpi_table_names = [
        "MPI_P2P_EVENTS",
        "MPI_START_WAIT_EVENTS",
        "MPI_OTHER_EVENTS",
        "MPI_COLLECTIVES_EVENTS",
    ]
    # Previously every table was indexed unconditionally, so a single
    # missing table raised KeyError despite the documented contract.
    mpi_dfs = [df_dict[table] for table in mpi_table_names if table in df_dict]
    if not mpi_dfs:
        raise ValueError("No MPI tables found in the df dictionary.")

    mpi_df = pd.concat(mpi_dfs).sort_values("start").reset_index(drop=True)

    return data_utils.replace_id_with_value(
        mpi_df, df_dict["StringIds"], "textId", "text"
    )


def get_ucx_dict():
    """Return the tables and columns needed to build the UCX table.

    All UCP event tables are read with the same columns; each table gets its
    own new list.
    """
    event_columns = ("globalTid", "start", "end", "textId")
    column_dict = {
        table: list(event_columns)
        for table in ("UCP_SUBMIT_EVENTS", "UCP_PROGRESS_EVENTS", "UCP_EVENTS")
    }
    column_dict["StringIds"] = ["id", "value"]
    return column_dict


def process_to_ucx_table(df_dict):
    """Generate refined UCX table.

    This function combines the UCX event tables present in `df_dict` into
    one dataframe, sorted by "start" and re-indexed 0..n-1. Any integer
    representations in the 'textId' field are replaced with their
    corresponding string values in the 'text' field.

    Parameters
    ----------
    df_dict : dict
        Dictionary mapping table names to dataframes. It must contain at
        least one of the following (missing ones are skipped):
        - "UCP_SUBMIT_EVENTS": ["start", "textId"]
        - "UCP_PROGRESS_EVENTS": ["start", "textId"]
        - "UCP_EVENTS": ["start", "textId"]
        It must also contain:
        - "StringIds": ["id", "value"]

    Raises
    ------
    ValueError
        If none of the UCX event tables is present (same convention as
        `process_to_cuda_gpu_table`).
    """
    ucx_table_names = [
        "UCP_SUBMIT_EVENTS",
        "UCP_PROGRESS_EVENTS",
        "UCP_EVENTS",
    ]
    # Previously every table was indexed unconditionally, so a single
    # missing table raised KeyError despite the documented contract.
    ucx_dfs = [df_dict[table] for table in ucx_table_names if table in df_dict]
    if not ucx_dfs:
        raise ValueError("No UCX tables found in the df dictionary.")

    ucx_df = pd.concat(ucx_dfs).sort_values("start").reset_index(drop=True)

    return data_utils.replace_id_with_value(
        ucx_df, df_dict["StringIds"], "textId", "text"
    )


def get_gpu_metrics_dict():
    """Return the tables and columns needed to build the GPU metrics table."""
    sample_columns = ["timestamp", "typeId", "metricId", "value"]
    metric_info_columns = ["metricId", "metricName", "typeId"]
    return dict(
        GPU_METRICS=sample_columns,
        TARGET_INFO_GPU_METRICS=metric_info_columns,
    )


def process_to_gpu_metrics_table(df_dict):
    """Generate refined GPU metrics table.

    Attaches metric names to the raw samples, then pivots so that every
    distinct metric name becomes its own column, indexed by
    ("timestamp", "typeId") before being reset to regular columns. Metric
    columns whose names vary across Nsys versions are renamed to stable
    display names when the corresponding `gpu_metrics` lookup succeeds:
    - SMs Active
    - SM Issue
    - Tensor Active
    - Unallocated Warps in Active SMs

    Parameters
    ----------
    df_dict : dict
        Dictionary mapping table names to dataframes. It must contain:
        - "GPU_METRICS": ["timestamp", "typeId", "metricId", "value"]
        - "TARGET_INFO_GPU_METRICS": ["metricId", "metricName", "typeId"]
    """
    named_samples_df = df_dict["GPU_METRICS"].merge(
        df_dict["TARGET_INFO_GPU_METRICS"], on=["metricId", "typeId"]
    )

    wide_df = named_samples_df.pivot(
        index=["timestamp", "typeId"], columns="metricName", values="value"
    )
    # Move the index back into columns and drop the "metricName" axis label.
    wide_df = wide_df.reset_index().rename_axis(None, axis=1)

    display_names = (
        (gpu_metrics.get_sm_active_name, "SMs Active"),
        (gpu_metrics.get_sm_issue_name, "SM Issue"),
        (gpu_metrics.get_tensor_active_name, "Tensor Active"),
        (gpu_metrics.get_unallocated_warps_name, "Unallocated Warps in Active SMs"),
    )
    rename_map = {}
    for lookup, display_name in display_names:
        try:
            source_name = lookup(wide_df)
        except ValueError:
            # The lookup could not resolve a column name (presumably the
            # metric is absent from this report); leave it unrenamed.
            continue
        rename_map[source_name] = display_name

    return wide_df.rename(columns=rename_map)


def get_perf_events_dict():
    """Return the tables and columns needed to build the perf events table."""
    raw_event_columns = [
        "start",
        "end",
        "vmId",
        "componentId",
        "eventId",
        "count",
    ]
    component_columns = ["componentId", "name", "instance", "parentId"]
    return dict(
        PERF_EVENT_SOC_OR_CPU_RAW_EVENT=raw_event_columns,
        TARGET_INFO_PERF_METRIC=["id", "name"],
        TARGET_INFO_COMPONENT=component_columns,
    )


def process_to_perf_events_table(df_dict):
    """
    Generate refined Perf Events table.

    Enriches the raw perf event samples with component and event names:
    - "TARGET_INFO_COMPONENT" is self-merged on "parentId" so that every
      component also knows its parent's "instance". The CPU number ("cpu")
      is the component's own instance, falling back to the parent's.
    - The raw events are left-merged with the components on "componentId",
      adding "componentType" (the component "name") and "cpu".
    - The raw events are left-merged with "TARGET_INFO_PERF_METRIC" on
      "eventId" (its "id" column), adding the event "name".
    - "componentId" is dropped from the result.

    Parameters
    ----------
    df_dict : dict
        Dictionary mapping table names to dataframes. It must contain:
        -   "PERF_EVENT_SOC_OR_CPU_RAW_EVENT": [
                "start", "end", "vmId", "componentId", "eventId", "count"
            ]
        -   "TARGET_INFO_PERF_METRIC": ["id", "name"]
        -   "TARGET_INFO_COMPONENT": [
                "componentId", "name", "instance", "parentId"
            ]

    Returns
    -------
    Dataframe with the following columns:
        - start: Event start timestamp (ns).
        - end: Event end timestamp (ns).
        - vmId: VM ID.
        - eventId: Event ID.
        - count: Counter data value.
        - componentType: Type of the component: Core, Cache, SocketN.
        - cpu: CPU number (nullable UInt32).
        - name: Name of the event.
    Example:
    start     end            vmId eventId count componentType cpu                name
     8912 3655504 281474976710656     113  3382          Core   0        INST_RETIRED
     8912 3655504 281474976710656     117    29         Cache   0      L2D_TLB_REFILL
    30224 3657904 281474976710656     4361 3622       Socket0 <NA>   SCF/cmem_rd_data
    """
    components = df_dict["TARGET_INFO_COMPONENT"].astype(
        {"instance": pd.UInt32Dtype()}
    )

    # Instance lookup keyed by the id that child components reference
    # through their "parentId" column.
    parents = components.loc[:, ["componentId", "instance"]].rename(
        columns={"componentId": "parentId", "instance": "parentInstance"}
    )

    with_parent = components.merge(parents, on="parentId", how="left")
    with_parent = with_parent.assign(
        cpu=with_parent["instance"].combine_first(with_parent["parentInstance"])
    )
    component_info = with_parent[["componentId", "name", "cpu"]].rename(
        columns={"name": "componentType"}
    )

    event_names = df_dict["TARGET_INFO_PERF_METRIC"].rename(
        columns={"id": "eventId"}
    )

    events = df_dict["PERF_EVENT_SOC_OR_CPU_RAW_EVENT"]
    events = events.merge(component_info, on="componentId", how="left")
    events = events.merge(event_names, on="eventId", how="left")
    return events.drop(columns=["componentId"])


def update_existing_keys(default_dict, extra_dict):
    """Merge "extra_dict" into "default_dict".

    Only keys that exist in "default_dict" are updated with the values
    from "extra_dict". Any keys in "extra_dict" that do not exist in
    "default_dict" are ignored.

    For each updated key, the resulting column list contains every column
    once. The columns from "default_dict" come first, in their original
    order, followed by the new columns from "extra_dict" in the order given.
    Extra columns may be given as a list, a tuple, or a single column-name
    string. "default_dict" itself is not modified.

    Parameters
    ----------
    default_dict : dict
        Mapping of table name to list of column names.
    extra_dict : dict or None
        Mapping of table name to additional column names.

    Returns
    -------
    dict
        "default_dict" itself when "extra_dict" is None, otherwise a new
        dictionary with the merged column lists.
    """
    if extra_dict is None:
        return default_dict

    merged_dict = default_dict.copy()
    for key, extra_columns in extra_dict.items():
        if key not in default_dict:
            continue
        if isinstance(extra_columns, str):
            # A bare string is one column name, not a sequence of characters.
            extra_columns = [extra_columns]
        # dict.fromkeys de-duplicates while keeping first-seen order. A set
        # does not: its iteration order for strings depends on per-process
        # hash randomization, which made the column order vary between runs.
        merged_dict[key] = list(
            dict.fromkeys([*default_dict[key], *extra_columns])
        )
    return merged_dict


def get_table_column_dict(table, extra_dict=None):
    """Return the recommended tables and columns for a composite table.

    Parameters
    ----------
    table : CompositeTable
        Value from the CompositeTable enum.
    extra_dict : dict, optional
        Dictionary mapping table names to additional column names to be read.
        Entries for tables that the composite table does not use are ignored
        and do not appear in the result.

    Returns
    -------
    table_column_dict : dict
        Dictionary mapping table names to lists of column names.

    Raises
    ------
    NotImplementedError
        If `table` has no registered column dictionary.
    """
    dict_builders = {
        CompositeTable.CUDA_GPU: get_cuda_gpu_dict,
        CompositeTable.CUDA_GPU_GRAPH: get_cuda_gpu_graph_dict,
        CompositeTable.CUDA_COMBINED: get_cuda_combined_dict,
        CompositeTable.CUDA_COMBINED_KERNEL: get_cuda_combined_kernel_dict,
        CompositeTable.CUDA_KERNEL: get_cuda_kernel_dict,
        # NCCL is derived from NVTX events, so it reads the same columns.
        CompositeTable.NVTX: get_nvtx_dict,
        CompositeTable.NCCL: get_nvtx_dict,
        CompositeTable.NIC: get_nic_dict,
        CompositeTable.IB_SWITCH: get_ib_switch_dict,
        CompositeTable.MPI: get_mpi_dict,
        CompositeTable.UCX: get_ucx_dict,
        CompositeTable.GPU_METRICS: get_gpu_metrics_dict,
        CompositeTable.PERF_EVENTS: get_perf_events_dict,
    }

    builder = dict_builders.get(table)
    if builder is None:
        raise NotImplementedError("Invalid table name.")
    return update_existing_keys(builder(), extra_dict)


def get_refine_func(table):
    """Return the refinement function for a composite table.

    Parameters
    ----------
    table : CompositeTable
        Value from the CompositeTable enum.

    Returns
    -------
    func : function
        Function that takes a dict of table name -> dataframe and returns
        one refined dataframe.

    Raises
    ------
    NotImplementedError
        If `table` has no registered refinement function.
    """
    refiners = {
        # Both CUDA GPU variants concatenate the same activity tables; the
        # graph variant only reads extra columns.
        CompositeTable.CUDA_GPU: process_to_cuda_gpu_table,
        CompositeTable.CUDA_GPU_GRAPH: process_to_cuda_gpu_table,
        CompositeTable.CUDA_COMBINED: process_to_cuda_combined_table,
        CompositeTable.CUDA_COMBINED_KERNEL: process_to_cuda_combined_kernel_table,
        CompositeTable.CUDA_KERNEL: process_to_cuda_kernel_table,
        CompositeTable.NVTX: process_to_nvtx_table,
        CompositeTable.NCCL: process_to_nccl_table,
        CompositeTable.NIC: process_to_nic_table,
        CompositeTable.IB_SWITCH: process_to_ib_switch_table,
        CompositeTable.MPI: process_to_mpi_table,
        CompositeTable.UCX: process_to_ucx_table,
        CompositeTable.GPU_METRICS: process_to_gpu_metrics_table,
        CompositeTable.PERF_EVENTS: process_to_perf_events_table,
    }

    refine_func = refiners.get(table)
    if refine_func is None:
        raise NotImplementedError("Invalid table name.")
    return refine_func
