/*
 * Copyright (c) 2021, 2026 Contributors to the Eclipse Foundation
 *
 * This program and the accompanying materials are made
 * available under the terms of the Eclipse Public License 2.0
 * which is available at https://www.eclipse.org/legal/epl-2.0/
 *
 * SPDX-License-Identifier: EPL-2.0
 */
package org.eclipse.lsat.scheduler;

import static org.slf4j.LoggerFactory.getLogger;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;

import org.eclipse.core.runtime.IProgressMonitor;
import org.eclipse.lsat.common.scheduler.algorithm.BellmanFordScheduler;
import org.eclipse.lsat.common.scheduler.graph.Task;
import org.eclipse.lsat.common.scheduler.graph.TaskDependencyGraph;
import org.eclipse.lsat.common.scheduler.schedule.Schedule;
import org.eclipse.lsat.common.scheduler.schedule.ScheduledTask;
import org.eclipse.lsat.timing.calculator.MotionCalculatorExtension;
import org.eclipse.lsat.timing.util.TimingCalculator;
import org.slf4j.Logger;

import distributions.CalculationMode;
import lsat_graph.lsat_graphFactory;
import setting.Settings;

/**
 * Stochastic impact analysis of a {@link TaskDependencyGraph}.
 * <p>
 * The original graph is scheduled {@code sampleLength} times. For every sample the execution times are added by
 * {@link AddExecutionTimes} using a {@link TimingCalculator} in {@link CalculationMode#DISTRIBUTED} mode, the
 * resulting graph is scheduled with a {@link BellmanFordScheduler} and a {@link CriticalPathAnalysis} is run on the
 * schedule. Per task, the start time, the duration and whether the task was critical are recorded in a
 * {@link CollectedScheduleData}. Finally a schedule is created in which every task is placed at its mean start time
 * and annotated with statistics over all samples.
 *
 * @param <T> the task type of the analyzed graph
 */
public class StochasticImpactAnalysis<T extends Task> {
    /**
     * Number of decimal places of the criticality, i.e. the fraction of samples in which a task was critical.
     */
    private static final int SCALE_CRITICALITY = 1;

    private static final Logger LOGGER = getLogger(StochasticImpactAnalysis.class);

    private final lsat_graphFactory graphFactory = lsat_graphFactory.eINSTANCE;

    /** The graph that is used as input for every sample. */
    private final TaskDependencyGraph<T> orgGraph;

    /** Settings passed to the {@link TimingCalculator}. */
    private final Settings settings;

    /** Number of samples (schedules) to compute. */
    private final int sampleLength;

    /**
     * Creates a new stochastic impact analysis.
     *
     * @param graph the task dependency graph to analyze
     * @param settings the settings used to create the {@link TimingCalculator}
     * @param sampleLength the number of samples (schedules) to compute
     */
    public StochasticImpactAnalysis(TaskDependencyGraph<T> graph, Settings settings, int sampleLength) {
        this.orgGraph = graph;
        this.settings = settings;
        this.sampleLength = sampleLength;
    }

    /**
     * Runs the stochastic analysis and creates the resulting schedule.
     *
     * @param name the name given to every created schedule
     * @param collectedData output parameter: it is cleared and then filled with one entry per node of the original
     *     graph, holding the start time, duration and criticality of that node for every sample
     * @param monitor progress monitor; it is checked for cancellation before every sample and passed to the
     *     critical path analysis
     * @return the schedule with mean start/end times and stochastic annotations, or {@code null} when the monitor
     *     was canceled
     * @throws IllegalStateException if a task of the original graph is missing from a sampled schedule
     * @throws Exception if scheduling or the critical path analysis fails
     */
    public Schedule<T> transformModel(String name, HashMap<T, CollectedScheduleData> collectedData,
            IProgressMonitor monitor) throws Exception {
        LOGGER.debug("Starting stochastic analysis");
        var motionCalculator = MotionCalculatorExtension.getSelectedMotionCalculator();
        var timingCalculator = new TimingCalculator(settings, motionCalculator, CalculationMode.DISTRIBUTED);

        // add execution times to all nodes in the graph
        var addExecutionTimes = new AddExecutionTimes(timingCalculator);
        var scheduler = new BellmanFordScheduler<T>();
        var criticalPathAnalysis = new CriticalPathAnalysis<T>();
        var scheduledTasks = orgGraph.allNodesInTopologicalOrder();
        collectedData.clear();
        collectedData.putAll(scheduledTasks.stream().collect(Collectors.toMap(t -> t,
                t -> new CollectedScheduleData(sampleLength), (a, b) -> a, LinkedHashMap::new)));

        for (int sample = 0; sample < sampleLength; sample++) {
            if (monitor.isCanceled()) {
                return null;
            }
            // a fresh graph per sample, so every sample gets its own execution times
            var graph = addExecutionTimes.transformModel(orgGraph);
            var schedule = scheduler.createSchedule(graph);
            schedule.setName(name);
            var nodeMap = schedule.getNodes().stream().collect(Collectors.toMap(ScheduledTask::getTask, st -> st));
            criticalPathAnalysis.transformModel(schedule, monitor);
            for (var entry : collectedData.entrySet()) {
                var node = entry.getKey();
                var collect = entry.getValue();
                var scheduledTask = nodeMap.get(node);
                if (scheduledTask == null) {
                    throw new IllegalStateException(
                            "Task " + node + " is missing from the schedule of sample " + sample);
                }
                collect.critical[sample] = isCritical(scheduledTask) ? 1 : 0;
                collect.startTime[sample] = scheduledTask.getStartTime().doubleValue();
                collect.duration[sample] = scheduledTask.getDuration().doubleValue();
            }
        }

        var result = createFinalSchedule(name, orgGraph, collectedData);
        LOGGER.debug("Finished stochastic analysis");
        return result;
    }

    /**
     * Tells whether the critical path analysis marked the given task as critical.
     *
     * @param scheduledTask the scheduled task to check
     * @return {@code true} if one of its aspects is named {@link CriticalPathAnalysis#CRITICAL}
     */
    private static <T extends Task> boolean isCritical(ScheduledTask<T> scheduledTask) {
        // constant first: an aspect without a name is simply not critical
        return scheduledTask.getAspects().stream().anyMatch(a -> CriticalPathAnalysis.CRITICAL.equals(a.getName()));
    }

    /**
     * Creates the final schedule from the collected sample data.
     * <p>
     * The graph is scheduled once to obtain the schedule structure. The start time of every task in
     * {@code collectedData} is then replaced by its mean start time, and its end time by the mean start time plus
     * the mean duration. The scale of the existing start/end values is kept. Finally a stochastic annotation is
     * added for the task.
     *
     * @param name the name of the created schedule
     * @param graph the graph to schedule
     * @param collectedData the data collected per task over all samples
     * @return the schedule with mean times and stochastic annotations
     * @throws IllegalStateException if a task of {@code collectedData} is not part of the created schedule
     * @throws Exception if scheduling fails
     */
    public Schedule<T> createFinalSchedule(String name, TaskDependencyGraph<T> graph,
            Map<T, CollectedScheduleData> collectedData) throws Exception {
        var scheduler = new BellmanFordScheduler<T>();
        var schedule = scheduler.createSchedule(graph);
        schedule.setName(name);

        // index once instead of scanning all scheduled nodes for every task;
        // on duplicates keep the first scheduled task, as a linear findFirst would
        var nodeMap = schedule.getNodes().stream()
                .collect(Collectors.toMap(ScheduledTask::getTask, st -> st, (first, second) -> first));

        for (var entry : collectedData.entrySet()) {
            var node = entry.getKey();
            var collect = entry.getValue();
            var scheduledTask = nodeMap.get(node);
            if (scheduledTask == null) {
                throw new IllegalStateException("Task " + node + " is missing from schedule " + name);
            }
            var stat = new ScheduleStatistics(collect);
            // use the statistics to determine the mean start and end time
            double startTimeMean = stat.startTime.getMean();
            scheduledTask.setStartTime(bd(startTimeMean, scheduledTask.getStartTime().scale()));
            scheduledTask.setEndTime(bd(startTimeMean + stat.duration.getMean(), scheduledTask.getEndTime().scale()));
            addStochasticAnnotation(scheduledTask, stat, scheduledTask.getStartTime().scale());
        }
        return schedule;
    }

    /**
     * Adds a {@code stochasticSensitivity} annotation for the given task to the schedule that contains it.
     * <p>
     * The annotation holds:
     * <ul>
     * <li>the criticality, i.e. the fraction of samples in which the task was critical, with scale
     * {@link #SCALE_CRITICALITY};</li>
     * <li>a confidence interval for that criticality;</li>
     * <li>the mean, min, max, standard deviation and (when finite) skewness of the start time, with scale
     * {@code scale}.</li>
     * </ul>
     *
     * @param at the scheduled task to annotate
     * @param stat statistics over all samples of this task
     * @param scale the scale used for the start-time statistics
     */
    @SuppressWarnings("unchecked")
    private void addStochasticAnnotation(ScheduledTask<T> at, ScheduleStatistics stat, int scale) {
        var realDist = new RealDistributionLibrary();
        // critical values are recorded as 1/0, so (presumably) the sum counts the critical samples
        // and N is the number of samples
        var criticalPathCounter = stat.critical.getSum();
        var sampleLength = stat.critical.getN();

        // Beta(k + 1/2, n - k + 1/2) corresponds to a Jeffreys prior on the criticality proportion;
        // the quantiles returned as "lower"/"upper" are chosen by RealDistributionLibrary
        var bounds = realDist.betaDistributionInverseCumulativeProbability(
                BigDecimal.valueOf(criticalPathCounter + 0.5),
                BigDecimal.valueOf(sampleLength - criticalPathCounter + 0.5));
        var confidenceInterval = graphFactory.createBounds();
        confidenceInterval.setLower(bounds.get("lower"));
        confidenceInterval.setUpper(bounds.get("upper"));

        var ann = graphFactory.createStochasticAnnotation();
        ann.setName("stochasticSensitivity");
        // explicit cast guarantees floating-point division regardless of the statistic's numeric type
        ann.setCriticality(bd((double) criticalPathCounter / sampleLength, SCALE_CRITICALITY));
        ann.setMean(bd(stat.startTime.getMean(), scale));
        ann.setMin(bd(stat.startTime.getMin(), scale));
        ann.setMax(bd(stat.startTime.getMax(), scale));
        var skewness = stat.startTime.getSkewness();
        // skewness may be undefined (NaN), e.g. for too few samples; BigDecimal cannot hold non-finite values
        if (Double.isFinite(skewness)) {
            ann.setSkewness(bd(skewness, scale));
        }
        ann.setStandardDeviation(bd(stat.startTime.getStandardDeviation(), scale));
        ann.setConfidenceInterval(confidenceInterval);
        ann.getNodes().add((ScheduledTask<Task>)at);
        ((Schedule<Task>)at.getGraph()).getAspects().add(ann);
    }

    /**
     * Converts a double into a {@link BigDecimal} rounded half-up to the given scale.
     *
     * @param d the value to convert; must be finite
     * @param scale the scale of the result
     * @return the rounded value
     * @throws NumberFormatException if {@code d} is NaN or infinite (e.g. statistics over zero samples)
     */
    private static BigDecimal bd(double d, int scale) {
        if (!Double.isFinite(d)) {
            // same exception type BigDecimal.valueOf would throw, but with a readable message
            throw new NumberFormatException("Cannot convert non-finite value " + d + " to BigDecimal");
        }
        return BigDecimal.valueOf(d).setScale(scale, RoundingMode.HALF_UP);
    }

}
