/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.viatra.query.runtime.localsearch.planner.cost.impl;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.eclipse.viatra.query.runtime.localsearch.matcher.integration.AbstractLocalSearchResultProvider;
import org.eclipse.viatra.query.runtime.localsearch.planner.cost.IConstraintEvaluationContext;
import org.eclipse.viatra.query.runtime.localsearch.planner.cost.ICostFunction;
import org.eclipse.viatra.query.runtime.matchers.backend.IQueryResultProvider;
import org.eclipse.viatra.query.runtime.matchers.context.IInputKey;
import org.eclipse.viatra.query.runtime.matchers.planning.helpers.FunctionalDependencyHelper;
import org.eclipse.viatra.query.runtime.matchers.planning.helpers.StatisticsHelper;
import org.eclipse.viatra.query.runtime.matchers.psystem.IQueryReference;
import org.eclipse.viatra.query.runtime.matchers.psystem.PConstraint;
import org.eclipse.viatra.query.runtime.matchers.psystem.PVariable;
import org.eclipse.viatra.query.runtime.matchers.psystem.analysis.QueryAnalyzer;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.AggregatorConstraint;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.ExportedParameter;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.ExpressionEvaluation;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.Inequality;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.NegativePatternCall;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.PatternMatchCounter;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.TypeFilterConstraint;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.BinaryReflexiveTransitiveClosure;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.BinaryTransitiveClosure;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.ConstantValue;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.PositivePatternCall;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.TypeConstraint;
import org.eclipse.viatra.query.runtime.matchers.tuple.TupleMask;
import org.eclipse.viatra.query.runtime.matchers.util.Accuracy;
import org.eclipse.viatra.query.runtime.matchers.util.Preconditions;

/**
 * Base class for local search cost functions that estimate the cost of a {@link PConstraint} from
 * statistics supplied through {@link #projectionSize} and {@link #bucketSize}.
 *
 * <p>Subclasses supply statistics by overriding {@link #projectionSize} or the deprecated
 * {@link #countTuples}. The default implementations of these two methods call each other, so a subclass
 * MUST override at least one of them. Otherwise any statistics query recurses until a
 * {@link StackOverflowError}.
 */
public abstract class StatisticsBasedConstraintCostFunction
implements ICostFunction {
    /** Not referenced in this class; presumably an upper cap for cost values used by subclasses/planners. */
    protected static final double MAX_COST = 250.0;
    /** Fallback cost used whenever no statistics are available. */
    protected static final double DEFAULT_COST = 150.0;
    /** Default extra cost for navigating a binary type constraint from its destination to its source end. */
    public static final double INVERSE_NAVIGATION_PENALTY_DEFAULT = 0.1;
    /** Alternative, smaller inverse navigation penalty; not referenced in this class. */
    public static final double INVERSE_NAVIGATION_PENALTY_GENERIC = 0.01;
    /** Multiplier for an unwinding expression evaluation that still has free variables to bind. */
    public static final double EVAL_UNWIND_EXTENSION_FACTOR = 3.0;
    /** Cost of a type constraint whose variables are all already bound, i.e. a pure check. */
    private static final double BOUND_CHECK_COST = 0.9;

    private final double inverseNavigationPenalty;

    /**
     * @param inverseNavigationPenalty cost added to a binary type constraint that is navigated backwards
     *        (source variable free, destination variable bound)
     */
    public StatisticsBasedConstraintCostFunction(double inverseNavigationPenalty) {
        this.inverseNavigationPenalty = inverseNavigationPenalty;
    }

    /** Creates a cost function using {@link #INVERSE_NAVIGATION_PENALTY_DEFAULT}. */
    public StatisticsBasedConstraintCostFunction() {
        this(INVERSE_NAVIGATION_PENALTY_DEFAULT);
    }

    /**
     * Counts the tuples of {@code supplierKey} via an exact identity-mask {@link #projectionSize} query.
     *
     * @return the tuple count, or {@code -1} if it is unknown
     * @deprecated superseded by {@link #projectionSize}, which supports projection masks and accuracy levels.
     *             Note that the default implementations of the two methods call each other.
     */
    @Deprecated
    public long countTuples(IConstraintEvaluationContext input, IInputKey supplierKey) {
        TupleMask allColumns = TupleMask.identity((int) supplierKey.getArity());
        return projectionSize(input, supplierKey, allColumns, Accuracy.EXACT_COUNT).orElse(-1L);
    }

    /**
     * Estimates the size of the projection of {@code supplierKey} onto {@code groupMask}.
     *
     * <p>The default implementation ignores {@code groupMask} and {@code requiredAccuracy} and falls back to
     * the legacy {@link #countTuples}. Subclasses are expected to override it; see the class documentation.
     *
     * @return the estimate, or empty if unknown (a negative legacy count is treated as unknown)
     */
    public Optional<Long> projectionSize(IConstraintEvaluationContext input, IInputKey supplierKey, TupleMask groupMask, Accuracy requiredAccuracy) {
        long legacyCount = countTuples(input, supplierKey);
        return legacyCount < 0L ? Optional.empty() : Optional.of(legacyCount);
    }

    /**
     * Estimates the number of matches of a called pattern per binding of the positions in {@code projMask}.
     *
     * <p>Local search result providers are asked for their own cost estimate. Any other provider is asked for an
     * approximate average bucket size.
     */
    public Optional<Double> bucketSize(IQueryReference patternCall, IConstraintEvaluationContext input, TupleMask projMask) {
        IQueryResultProvider resultProvider = input.resultProviderRequestor().requestResultProvider(patternCall, null);
        if (resultProvider instanceof AbstractLocalSearchResultProvider) {
            double estimatedCost = ((AbstractLocalSearchResultProvider) resultProvider).estimateCost(projMask);
            return Optional.of(estimatedCost);
        }
        return resultProvider.estimateAverageBucketSize(projMask, Accuracy.APPROXIMATION);
    }

    /** Computes the cost of the constraint held by {@code input}. */
    @Override
    public double apply(IConstraintEvaluationContext input) {
        return calculateCost(input.getConstraint(), input);
    }

    /** Constant values are free to evaluate. */
    protected double _calculateCost(ConstantValue constant, IConstraintEvaluationContext input) {
        return 0.0;
    }

    /**
     * Cost of a unary or binary type constraint. Binary constraints navigated backwards (source free,
     * destination bound) additionally pay the inverse navigation penalty.
     *
     * @throws UnsupportedOperationException for any other arity
     */
    protected double _calculateCost(TypeConstraint constraint, IConstraintEvaluationContext input) {
        Collection<PVariable> freeMaskVariables = input.getFreeVariables();
        Collection<PVariable> boundMaskVariables = input.getBoundVariables();
        IInputKey supplierKey = (IInputKey) constraint.getSupplierKey();
        long arity = supplierKey.getArity();
        if (arity == 1L) {
            return calculateUnaryConstraintCost(constraint, input);
        }
        if (arity == 2L) {
            PVariable srcVariable = (PVariable) constraint.getVariablesTuple().get(0);
            PVariable dstVariable = (PVariable) constraint.getVariablesTuple().get(1);
            boolean isInverse = freeMaskVariables.contains(srcVariable) && boundMaskVariables.contains(dstVariable);
            double binaryExtendCost = calculateBinaryCost(supplierKey, srcVariable, dstVariable, isInverse, input);
            return isInverse ? binaryExtendCost + inverseNavigationPenalty : binaryExtendCost;
        }
        throw new UnsupportedOperationException("Cost calculation for arity " + arity + " is not implemented yet");
    }

    /**
     * Legacy hook; no longer called by this class.
     *
     * @throws UnsupportedOperationException always
     * @deprecated override {@link #calculateBinaryCost} instead.
     */
    @Deprecated
    protected double calculateBinaryExtendCost(IInputKey supplierKey, PVariable srcVariable, PVariable dstVariable, boolean isInverse, long edgeCount, IConstraintEvaluationContext input) {
        throw new UnsupportedOperationException();
    }

    /**
     * Estimates the branching factor of evaluating a binary type constraint.
     *
     * <p>If both ends are free, the cost is the number of edges. When no edge count is available, the product of
     * the end-point counts is used instead. Otherwise the cost is the smallest of these estimates:
     * <ul>
     * <li>0.9 if both ends are bound (a pure check);</li>
     * <li>the edge count amortized over the nodes at the end navigation starts from;</li>
     * <li>the target/start node-count ratio, if the free variables functionally determine the bound ones;</li>
     * <li>the same ratio (at least 1) if no edge count is available;</li>
     * <li>1.0 if the bound variables functionally determine the free ones.</li>
     * </ul>
     * Falls back to {@link #DEFAULT_COST} when none of the estimates are available.
     */
    protected double calculateBinaryCost(IInputKey supplierKey, PVariable srcVariable, PVariable dstVariable, boolean isInverse, IConstraintEvaluationContext input) {
        Collection<PVariable> freeMaskVariables = input.getFreeVariables();
        PConstraint constraint = input.getConstraint();
        Optional<Long> edgeUpper = projectionSize(input, supplierKey, TupleMask.identity(2), Accuracy.BEST_UPPER_BOUND);
        Optional<Long> srcUpper = projectionSize(input, supplierKey, TupleMask.selectSingle(0, 2), Accuracy.BEST_UPPER_BOUND);
        Optional<Long> dstUpper = projectionSize(input, supplierKey, TupleMask.selectSingle(1, 2), Accuracy.BEST_UPPER_BOUND);
        boolean srcFree = freeMaskVariables.contains(srcVariable);
        boolean dstFree = freeMaskVariables.contains(dstVariable);
        if (srcFree && dstFree) {
            // Every edge is enumerated; without an edge count, assume the full cartesian product of end points.
            return edgeUpper.map(Long::doubleValue).orElse(
                    srcUpper.map(Long::doubleValue).orElse(DEFAULT_COST)
                    * dstUpper.map(Long::doubleValue).orElse(DEFAULT_COST));
        }
        Optional<Long> srcLower = projectionSize(input, supplierKey, TupleMask.selectSingle(0, 2), Accuracy.APPROXIMATION);
        Optional<Long> dstLower = projectionSize(input, supplierKey, TupleMask.selectSingle(1, 2), Accuracy.APPROXIMATION);
        // Navigation starts at the destination end when inverse, otherwise at the source end.
        Optional<Long> fromLower = isInverse ? dstLower : srcLower;
        Optional<Long> toUpper = isInverse ? srcUpper : dstUpper;

        Optional<Double> costEstimate = Optional.empty();
        if (!srcFree && !dstFree) {
            costEstimate = StatisticsHelper.min(costEstimate, BOUND_CHECK_COST);
        }
        costEstimate = StatisticsHelper.min(costEstimate, amortize(edgeUpper, fromLower));
        if (navigatesThroughFunctionalDependencyInverse(input, constraint)) {
            costEstimate = StatisticsHelper.min(costEstimate, amortize(toUpper, fromLower));
        }
        if (!edgeUpper.isPresent()) {
            // Last resort: node counts on both ends, never assuming less than one result.
            costEstimate = StatisticsHelper.min(costEstimate, amortize(toUpper, fromLower).map(ratio -> Math.max(1.0, ratio)));
        }
        if (navigatesThroughFunctionalDependency(input, constraint)) {
            costEstimate = StatisticsHelper.min(costEstimate, 1.0);
        }
        return costEstimate.orElse(DEFAULT_COST);
    }

    /**
     * Divides {@code total} by {@code over} when both are known.
     *
     * @return empty if either input is empty, 0.0 if {@code over} is zero, otherwise the quotient
     */
    private static Optional<Double> amortize(Optional<Long> total, Optional<Long> over) {
        return total.flatMap(numerator -> over.map(denominator ->
                denominator == 0L ? 0.0 : numerator.doubleValue() / denominator.doubleValue()));
    }

    /** Returns whether the bound variables of {@code input} functionally determine its free variables. */
    protected boolean navigatesThroughFunctionalDependency(IConstraintEvaluationContext input, PConstraint constraint) {
        return navigatesThroughFunctionalDependency(input, constraint, input.getBoundVariables(), input.getFreeVariables());
    }

    /** Returns whether the free variables of {@code input} functionally determine its bound variables. */
    protected boolean navigatesThroughFunctionalDependencyInverse(IConstraintEvaluationContext input, PConstraint constraint) {
        return navigatesThroughFunctionalDependency(input, constraint, input.getFreeVariables(), input.getBoundVariables());
    }

    /**
     * Returns whether {@code determining} implies every variable in {@code determined}. Only the functional
     * dependencies of {@code constraint} itself (non-strict analysis) are used.
     */
    protected boolean navigatesThroughFunctionalDependency(IConstraintEvaluationContext input, PConstraint constraint, Collection<PVariable> determining, Collection<PVariable> determined) {
        QueryAnalyzer queryAnalyzer = input.getQueryAnalyzer();
        Map<Set<PVariable>, Set<PVariable>> functionalDependencies =
                queryAnalyzer.getFunctionalDependencies(Collections.singleton(constraint), false);
        Set<PVariable> impliedVariables = FunctionalDependencyHelper.closureOf(determining, functionalDependencies);
        return impliedVariables != null && impliedVariables.containsAll(determined);
    }

    /**
     * Cost of a unary type constraint. It is 0.9 when the variable is already bound (a check). Otherwise it is
     * one plus the approximate number of tuples, or {@link #DEFAULT_COST} when that number is unknown.
     */
    protected double calculateUnaryConstraintCost(TypeConstraint constraint, IConstraintEvaluationContext input) {
        PVariable variable = (PVariable) constraint.getVariablesTuple().get(0);
        if (input.getBoundVariables().contains(variable)) {
            return BOUND_CHECK_COST;
        }
        return projectionSize(input, (IInputKey) constraint.getSupplierKey(), TupleMask.identity(1), Accuracy.APPROXIMATION)
                .map(count -> 1.0 + count.doubleValue())
                .orElse(DEFAULT_COST);
    }

    /** Exported parameters are free to evaluate. */
    protected double _calculateCost(ExportedParameter exportedParam, IConstraintEvaluationContext input) {
        return 0.0;
    }

    /** Type filter constraints are free to evaluate. */
    protected double _calculateCost(TypeFilterConstraint typeFilter, IConstraintEvaluationContext input) {
        return 0.0;
    }

    /**
     * Cost of a positive pattern call: the estimated bucket size for the call's currently bound parameter
     * positions, or {@link #DEFAULT_COST} if unknown.
     */
    protected double _calculateCost(PositivePatternCall patternCall, IConstraintEvaluationContext input) {
        int parameterCount = patternCall.getReferredQuery().getParameters().size();
        Collection<PVariable> boundVariables = input.getBoundVariables();
        List<Integer> boundPositions = new ArrayList<>();
        for (int position = 0; position < parameterCount; position++) {
            if (boundVariables.contains(patternCall.getVariableInTuple(position))) {
                boundPositions.add(position);
            }
        }
        TupleMask projMask = TupleMask.fromSelectedIndices(parameterCount, boundPositions);
        return bucketSize((IQueryReference) patternCall, input, projMask).orElse(DEFAULT_COST);
    }

    /**
     * Cost of an expression evaluation: the generic constraint cost. It is multiplied by
     * {@link #EVAL_UNWIND_EXTENSION_FACTOR} for unwinding evaluations that still have free variables.
     */
    protected double _calculateCost(ExpressionEvaluation evaluation, IConstraintEvaluationContext input) {
        double multiplier = evaluation.isUnwinding() && !input.getFreeVariables().isEmpty()
                ? EVAL_UNWIND_EXTENSION_FACTOR
                : 1.0;
        return _calculateCost((PConstraint) evaluation, input) * multiplier;
    }

    /** Inequalities use the generic constraint cost. */
    protected double _calculateCost(Inequality inequality, IConstraintEvaluationContext input) {
        return _calculateCost((PConstraint) inequality, input);
    }

    /** Aggregators use the generic constraint cost. */
    protected double _calculateCost(AggregatorConstraint aggregator, IConstraintEvaluationContext input) {
        return _calculateCost((PConstraint) aggregator, input);
    }

    /** Negative pattern calls use the generic constraint cost. */
    protected double _calculateCost(NegativePatternCall call, IConstraintEvaluationContext input) {
        return _calculateCost((PConstraint) call, input);
    }

    /** Pattern match counters use the generic constraint cost. */
    protected double _calculateCost(PatternMatchCounter counter, IConstraintEvaluationContext input) {
        return _calculateCost((PConstraint) counter, input);
    }

    /** Transitive closures have no statistics support here; they always cost {@link #DEFAULT_COST}. */
    protected double _calculateCost(BinaryTransitiveClosure closure, IConstraintEvaluationContext input) {
        return DEFAULT_COST;
    }

    /** Reflexive transitive closures always cost {@link #DEFAULT_COST}. */
    protected double _calculateCost(BinaryReflexiveTransitiveClosure closure, IConstraintEvaluationContext input) {
        return DEFAULT_COST;
    }

    /**
     * Generic fallback: 1.0 when every variable is already bound (a check), otherwise
     * {@link #DEFAULT_COST}.
     */
    protected double _calculateCost(PConstraint constraint, IConstraintEvaluationContext input) {
        if (input.getFreeVariables().isEmpty()) {
            return 1.0;
        }
        return DEFAULT_COST;
    }

    /**
     * Dispatches to the {@code _calculateCost} overload matching the runtime type of {@code constraint}.
     * Overloads are otherwise chosen at compile time, hence this explicit dispatch. Unknown constraint types
     * use the generic {@link PConstraint} overload.
     *
     * @throws IllegalArgumentException if {@code constraint} is null (via {@link Preconditions#checkArgument})
     */
    public double calculateCost(PConstraint constraint, IConstraintEvaluationContext input) {
        Preconditions.checkArgument(constraint != null, "Set constraint value correctly");
        if (constraint instanceof ExportedParameter) {
            return _calculateCost((ExportedParameter) constraint, input);
        }
        if (constraint instanceof TypeFilterConstraint) {
            return _calculateCost((TypeFilterConstraint) constraint, input);
        }
        if (constraint instanceof ConstantValue) {
            return _calculateCost((ConstantValue) constraint, input);
        }
        if (constraint instanceof PositivePatternCall) {
            return _calculateCost((PositivePatternCall) constraint, input);
        }
        if (constraint instanceof TypeConstraint) {
            return _calculateCost((TypeConstraint) constraint, input);
        }
        if (constraint instanceof ExpressionEvaluation) {
            return _calculateCost((ExpressionEvaluation) constraint, input);
        }
        if (constraint instanceof Inequality) {
            return _calculateCost((Inequality) constraint, input);
        }
        if (constraint instanceof AggregatorConstraint) {
            return _calculateCost((AggregatorConstraint) constraint, input);
        }
        if (constraint instanceof NegativePatternCall) {
            return _calculateCost((NegativePatternCall) constraint, input);
        }
        if (constraint instanceof PatternMatchCounter) {
            return _calculateCost((PatternMatchCounter) constraint, input);
        }
        if (constraint instanceof BinaryTransitiveClosure) {
            return _calculateCost((BinaryTransitiveClosure) constraint, input);
        }
        if (constraint instanceof BinaryReflexiveTransitiveClosure) {
            return _calculateCost((BinaryReflexiveTransitiveClosure) constraint, input);
        }
        return _calculateCost(constraint, input);
    }
}

