from abc import ABC, abstractmethod
from typing import List
from typing import Optional
from enum import Enum, auto
from nsys_cpu_stats.trace_utils import FrameDurations, TimeSlice, GPUMetric, CallStack, CPUConfig


class TraceLoaderSupport(Enum):
    """Optional capabilities a trace loader may list in its ``supported`` list."""

    TIMESLICES = 1
    CORE_COUNT = 2
    CPU_CONFIG = 3
    GPU_METRICS = 4
    ANALYSIS_DURATION = 5
    CALLSTACKS = 6
    # Where context-switch (CSwitch) callstacks live relative to a timeslice.
    CSWITCH_CALLSTACK_FRONT = 7  # CSwitch callstacks are at the front of a timeslice
    CSWITCH_CALLSTACK_BACK = 8  # CSwitch callstacks are at the back of a timeslice
    CSWITCH_CALLSTACK_SEPARATE = 9  # Whether CSwitch callstacks are detected/identified at load time


class TraceLoaderRegions(Enum):
    """Region kinds whose durations a trace loader may be able to provide."""

    CPU_FRAMETIMES = 1
    GPU_FRAMETIMES = 2
    CUDNN_KERNEL_LAUNCHES = 3
    CUDNN_GPU_KERNELS = 4


class TraceLoaderGPUMetrics(Enum):
    """GPU metrics a trace loader may expose."""

    GPU_UTILISATION = 1
    PCIE_BAR1_READS = 2
    PCIE_BAR1_WRITES = 3


class TraceLoaderEvents(Enum):
    """Event categories a trace loader may be able to return."""

    DX12_API_CALLS = 1
    CUDA_API_CALLS = 2
    CUDA_GPU_KERNELS = 3
    NVTX_MARKERS = 4
    NVTX_GPU_MARKERS = 5
    MPI_MARKERS = 6
    PIX_MARKERS = 7
    DXGKRNL_PROFILE_RANGE = 8
    ETW_EVENTS = 9
    DX12_GPU_WORKLOAD = 10


class TraceLoaderInterface(ABC):
    """Abstract interface implemented by trace loaders.

    Most public query methods are thin wrappers that first check whether the
    loader advertises the relevant capability (via the ``*_supported`` lists
    created in ``__init__``). They delegate to the matching ``_``-prefixed
    abstract method only when the capability is advertised; otherwise they
    return a fallback value documented on each method.

    ``get_timeslices`` and ``get_callstacks`` differ: they ``assert`` the
    capability instead of returning a fallback.
    """

    def __init__(self):
        # Capability lists. They start empty; presumably the concrete loader's
        # determine_support() fills them in — TODO confirm per subclass.
        self.supported = []  # TraceLoaderSupport members
        self.regions_supported = []  # TraceLoaderRegions usable with get_region_durations
        self.derived_regions_supported = []  # TraceLoaderRegions usable with get_derived_region_durations
        self.pipeline_regions_supported = []  # TraceLoaderRegions; in this class only read by is_region_supported
        self.gpu_metrics_supported = []  # TraceLoaderGPUMetrics members
        self.events_supported = []  # TraceLoaderEvents members
        self.gpu_metric_names = {}  # TraceLoaderGPUMetrics -> loader-specific metric name

    ####################################################
    #
    # Utility and helper functions
    #
    ####################################################
    @abstractmethod
    def determine_support(self):
        """Determine which capabilities this loader supports.

        Presumably populates the capability lists created in ``__init__``.
        """

    def is_supported(self, support: TraceLoaderSupport) -> bool:
        """Return True if ``support`` is listed in ``self.supported``."""
        return support in self.supported

    def is_region_supported(self, support: TraceLoaderRegions) -> bool:
        """Return True if ``support`` appears in any region capability list.

        Checks the plain, derived and pipeline region lists.
        """
        region_lists = (self.regions_supported,
                        self.derived_regions_supported,
                        self.pipeline_regions_supported)
        return any(support in region_list for region_list in region_lists)

    def is_gpu_metric_supported(self, support: TraceLoaderGPUMetrics) -> bool:
        """Return True if ``support`` is listed in ``self.gpu_metrics_supported``."""
        return support in self.gpu_metrics_supported

    def get_gpu_metric_name(self, metric: TraceLoaderGPUMetrics) -> Optional[str]:
        """Return the loader's name for ``metric``, or None if none is registered."""
        return self.gpu_metric_names.get(metric)

    @abstractmethod
    def init_thread_name_dict(self):
        """Build the loader's thread-name dictionary (implementation-defined)."""
        pass

    @abstractmethod
    def init_process_name_dict(self):
        """Build the loader's process-name dictionary (implementation-defined)."""
        pass

    @abstractmethod
    def is_graphics_workload(self,
                             start_time_ns: int,
                             end_time_ns: int,
                             target_pid: int):
        """Query whether ``target_pid`` has graphics work in the given time range.

        The return type is implementation-defined (presumably a bool).
        """
        pass

    @abstractmethod
    def is_compute_workload(self,
                            start_time_ns: int,
                            end_time_ns: int,
                            target_pid: int):
        """Query whether ``target_pid`` has compute work in the given time range.

        The return type is implementation-defined (presumably a bool).
        """
        pass

    ####################################################
    #
    # Main interface to the loader.
    #
    ####################################################
    def get_timeslices(self,
                       start_time_ns: Optional[float] = None,
                       end_time_ns: Optional[float] = None,
                       target_pid: Optional[int] = None,
                       quiet: Optional[bool] = False) -> List[TimeSlice]:
        """Get the timeslices from the loader.

        All arguments are forwarded unchanged to ``_get_timeslices``.

        Raises:
            AssertionError: if TIMESLICES is not in ``self.supported``. The
                check is skipped when Python runs with ``-O``.
        """
        assert TraceLoaderSupport.TIMESLICES in self.supported  # Should always be supported
        return self._get_timeslices(start_time_ns, end_time_ns, target_pid, quiet)

    @abstractmethod
    def _get_timeslices(self,
                        start_time_ns: Optional[float] = None,
                        end_time_ns: Optional[float] = None,
                        target_pid: Optional[int] = None,
                        quiet: Optional[bool] = False) -> List[TimeSlice]:
        """Loader-specific implementation of ``get_timeslices``."""
        pass

    ####################################################
    #
    # Get the call chains - flattened
    #
    ####################################################
    def get_callstacks(self,
                       start_time_ns: float,
                       end_time_ns: float,
                       target_pid: int,
                       target_tid: Optional[int]) -> List[CallStack]:
        """Get the callstacks for the given time range and process/thread.

        All arguments are forwarded unchanged to ``_get_callstacks``.

        Raises:
            AssertionError: if CALLSTACKS is not in ``self.supported``. The
                check is skipped when Python runs with ``-O``.
        """
        # NOTE(review): unlike TIMESLICES, CALLSTACKS looks like an optional
        # capability; callers presumably check is_supported() first — confirm.
        assert TraceLoaderSupport.CALLSTACKS in self.supported
        return self._get_callstacks(start_time_ns, end_time_ns, target_pid, target_tid)

    @abstractmethod
    def _get_callstacks(self,
                        start_time_ns: float,
                        end_time_ns: float,
                        target_pid: int,
                        target_tid: Optional[int]) -> List[CallStack]:
        """Loader-specific implementation of ``get_callstacks``."""
        pass

    ####################################################
    #
    # Get region durations
    #
    ####################################################
    def get_region_durations(self,
                             region_type: TraceLoaderRegions,
                             start_time_ns: Optional[float] = None,
                             end_time_ns: Optional[float] = None,
                             target_pid: Optional[int] = None,
                             ) -> tuple[float, Optional[list[FrameDurations]]]:
        """Get the durations of a region type, if supported.

        Returns:
            ``(average_duration, durations)``. Returns ``(0, None)`` when
            ``region_type`` is not in ``self.regions_supported``.
        """
        if region_type in self.regions_supported:
            return self._get_region_durations(region_type, start_time_ns, end_time_ns, target_pid)
        return 0, None

    @abstractmethod
    def _get_region_durations(self,
                              region_type: TraceLoaderRegions,
                              start_time_ns: Optional[float] = None,
                              end_time_ns: Optional[float] = None,
                              target_pid: Optional[int] = None,
                              ) -> tuple[float, Optional[list[FrameDurations]]]:
        """Loader-specific implementation of ``get_region_durations``."""
        pass

    def get_derived_region_durations(self,
                                     region_type: TraceLoaderRegions,
                                     base_durations: list[FrameDurations],
                                     start_time_ns: Optional[float] = None,
                                     end_time_ns: Optional[float] = None,
                                     target_pid: Optional[int] = None,
                                     ) -> Optional[list[FrameDurations]]:
        """Get region durations derived from a base frametime list.

        Returns:
            The derived list of durations, or None when ``region_type`` is not
            in ``self.derived_regions_supported``.
        """
        if region_type in self.derived_regions_supported:
            return self._get_derived_region_durations(region_type, base_durations, start_time_ns, end_time_ns, target_pid)
        return None

    @abstractmethod
    def _get_derived_region_durations(self,
                                      region_type: TraceLoaderRegions,
                                      base_durations: list[FrameDurations],
                                      start_time_ns: Optional[float] = None,
                                      end_time_ns: Optional[float] = None,
                                      target_pid: Optional[int] = None,
                                      ) -> list[FrameDurations]:
        """Loader-specific implementation of ``get_derived_region_durations``."""
        pass

    ####################################################
    #
    # Get the CPU core count
    #
    ####################################################
    def get_core_count(self) -> int:
        """Get the CPU core count, or 0 if CORE_COUNT is not supported."""
        if TraceLoaderSupport.CORE_COUNT in self.supported:
            return self._get_core_count()
        return 0

    @abstractmethod
    def _get_core_count(self) -> int:
        """Loader-specific implementation of ``get_core_count``."""
        pass

    ####################################################
    #
    # Get the CPU config
    #
    ####################################################
    def get_cpu_config(self) -> Optional[CPUConfig]:
        """Get the CPU configuration, or None if CPU_CONFIG is not supported."""
        if TraceLoaderSupport.CPU_CONFIG in self.supported:
            return self._get_cpu_config()
        return None

    @abstractmethod
    def _get_cpu_config(self) -> CPUConfig:
        """Loader-specific implementation of ``get_cpu_config``."""
        pass

    ####################################################
    #
    # Get the duration of the analysis
    #
    ####################################################
    def get_analysis_duration(self) -> Optional[float]:
        """Get the duration of the analysis.

        Returns 0 (not None) when ANALYSIS_DURATION is not supported.
        """
        if TraceLoaderSupport.ANALYSIS_DURATION in self.supported:
            return self._get_analysis_duration()
        return 0

    @abstractmethod
    def _get_analysis_duration(self) -> Optional[float]:
        """Loader-specific implementation of ``get_analysis_duration``."""
        pass

    ####################################################
    #
    # Decode a string
    #
    ####################################################
    @abstractmethod
    def get_string(self, sid) -> str:
        """Decode the string identified by ``sid``."""

    ####################################################
    #
    # Decode a module string
    #
    ####################################################
    @abstractmethod
    def get_module_string(self, sid) -> str:
        """Decode a module-name string. Implementations might need to strip folders."""

    ####################################################
    #
    # Get the average value of one GPU metric
    #
    ####################################################
    def get_average_gpu_metrics(self,
                                metric: TraceLoaderGPUMetrics,
                                start_time_ns: Optional[float] = None,
                                end_time_ns: Optional[float] = None) -> Optional[float]:
        """Get the average value of ``metric`` over the given time range.

        Returns None unless GPU_METRICS is supported and ``metric`` is in
        ``self.gpu_metrics_supported``.
        """
        if TraceLoaderSupport.GPU_METRICS in self.supported and metric in self.gpu_metrics_supported:
            return self._get_average_gpu_metrics(metric, start_time_ns, end_time_ns)
        return None

    @abstractmethod
    def _get_average_gpu_metrics(self,
                                 metric: TraceLoaderGPUMetrics,
                                 start_time_ns: Optional[float] = None,
                                 end_time_ns: Optional[float] = None) -> Optional[float]:
        """Loader-specific implementation of ``get_average_gpu_metrics``."""
        pass

    ####################################################
    #
    # Get the average of all GPU metrics
    #
    ####################################################
    def get_all_average_gpu_metrics(self,
                                    start_time_ns: Optional[float] = None,
                                    end_time_ns: Optional[float] = None) -> tuple[Optional[dict], int]:
        """Get all GPU metrics for the given time range and average them.

        Returns:
            A ``(dict, int)`` pair from the loader. Returns ``(None, 0)`` when
            GPU_METRICS is not supported.
        """
        if TraceLoaderSupport.GPU_METRICS in self.supported:
            return self._get_all_average_gpu_metrics(start_time_ns=start_time_ns, end_time_ns=end_time_ns)
        return None, 0

    @abstractmethod
    def _get_all_average_gpu_metrics(self,
                                     start_time_ns: Optional[float] = None,
                                     end_time_ns: Optional[float] = None) -> tuple[dict, int]:
        """Loader-specific implementation of ``get_all_average_gpu_metrics``."""
        pass

    ####################################################
    #
    # Get GPU metric list as frame durations
    #
    ####################################################
    def get_gpu_metric_frame_list(self,
                                  metric_type: TraceLoaderGPUMetrics,
                                  min_metric: Optional[float] = None,
                                  max_metric: Optional[float] = None,
                                  min_percent: Optional[float] = None,
                                  max_percent: Optional[float] = None,
                                  start_time_ns: Optional[float] = None,
                                  end_time_ns: Optional[float] = None,
                                  ) -> tuple[Optional[List[GPUMetric]], Optional[List[FrameDurations]]]:
        """Get a GPU metric as a list of metric samples and a list of frame durations.

        Returns:
            ``(gpu_metrics, frame_durations)``. Returns ``(None, None)`` unless
            GPU_METRICS is supported and ``metric_type`` is in
            ``self.gpu_metrics_supported``.
        """
        if TraceLoaderSupport.GPU_METRICS in self.supported and metric_type in self.gpu_metrics_supported:
            return self._get_gpu_metric_frame_list(metric_type=metric_type,
                                                   min_metric=min_metric,
                                                   max_metric=max_metric,
                                                   min_percent=min_percent,
                                                   max_percent=max_percent,
                                                   start_time_ns=start_time_ns,
                                                   end_time_ns=end_time_ns)
        return None, None

    @abstractmethod
    def _get_gpu_metric_frame_list(self,
                                   metric_type: TraceLoaderGPUMetrics,
                                   min_metric: Optional[float] = None,
                                   max_metric: Optional[float] = None,
                                   min_percent: Optional[float] = None,
                                   max_percent: Optional[float] = None,
                                   start_time_ns: Optional[float] = None,
                                   end_time_ns: Optional[float] = None,
                                   ) -> tuple[List[GPUMetric], List[FrameDurations]]:
        """Loader-specific implementation of ``get_gpu_metric_frame_list``."""
        pass

    ####################################################
    #
    # Get events for the given time range and tid/pid
    #
    ####################################################
    def get_events(self,
                   event_type: TraceLoaderEvents,
                   start_time_ns: int,
                   end_time_ns: int,
                   target_pid: int,
                   target_tid: int) -> tuple[Optional[dict], Optional[dict]]:
        """Get events of ``event_type`` for the given time range and pid/tid.

        Returns:
            A ``(dict, dict)`` pair from the loader. Returns ``(None, None)``
            when ``event_type`` is not in ``self.events_supported``.
        """
        if event_type in self.events_supported:
            return self._get_events(event_type, start_time_ns=start_time_ns, end_time_ns=end_time_ns, target_pid=target_pid, target_tid=target_tid)
        return None, None

    @abstractmethod
    def _get_events(self,
                    event_type: TraceLoaderEvents,
                    start_time_ns: int,
                    end_time_ns: int,
                    target_pid: int,
                    target_tid: int) -> tuple[dict, dict]:
        """Loader-specific implementation of ``get_events``."""
        pass

    ####################################################
    #
    # Get ordered events for the given time range and tid/pid
    #
    ####################################################
    def get_ordered_events(self,
                           event_type: TraceLoaderEvents,
                           start_time_ns: int,
                           end_time_ns: int,
                           target_pid: int,
                           target_tid: int) -> Optional[List]:
        """Get events of ``event_type`` as a list for the given time range and pid/tid.

        Returns None when ``event_type`` is not in ``self.events_supported``.
        The ordering is defined by the loader's ``_get_ordered_events``.
        """
        if event_type in self.events_supported:
            return self._get_ordered_events(event_type, start_time_ns=start_time_ns, end_time_ns=end_time_ns, target_pid=target_pid, target_tid=target_tid)
        return None

    @abstractmethod
    def _get_ordered_events(self,
                            event_type: TraceLoaderEvents,
                            start_time_ns: int,
                            end_time_ns: int,
                            target_pid: int,
                            target_tid: int) -> List:
        """Loader-specific implementation of ``get_ordered_events``."""
        pass
