/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.parfor.opt;

import java.util.HashMap;
import java.util.HashSet;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.ParForStatementBlock;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.opt.CostEstimator;
import org.apache.sysds.runtime.controlprogram.parfor.opt.OptNode;
import org.apache.sysds.runtime.controlprogram.parfor.opt.OptTree;
import org.apache.sysds.runtime.controlprogram.parfor.opt.OptimizerRuleBased;

/**
 * Variant of {@link OptimizerRuleBased} that reports {@code POptMode.CONSTRAINED}.
 *
 * <p>Each overridden rewrite first checks whether the corresponding setting is
 * already fixed on the plan node (a parameter other than {@code UNSPECIFIED},
 * a non-null exec type, a positive degree of parallelism). If so, that value is
 * pushed to the mapped {@link ParForProgramBlock} ("forced"); otherwise the
 * decision is delegated to the superclass implementation.
 */
public class OptimizerConstrained
extends OptimizerRuleBased {
    private static final Log LOG = LogFactory.getLog(OptimizerConstrained.class.getName());

    /**
     * @return always {@code POptMode.CONSTRAINED}
     */
    @Override
    public ParForProgramBlock.POptMode getOptMode() {
        return ParForProgramBlock.POptMode.CONSTRAINED;
    }

    /**
     * Runs the rewrite sequence over the given plan.
     *
     * @param sb   parfor statement block (not read by this method)
     * @param pb   parfor program block (not read by this method)
     * @param plan plan whose root node is rewritten
     * @param est  cost estimator used for all memory estimates
     * @param ec   execution context supplying the variable map
     * @return always {@code true}
     */
    @Override
    public boolean optimize(ParForStatementBlock sb, ParForProgramBlock pb, OptTree plan, CostEstimator est, ExecutionContext ec) {
        LOG.debug("--- " + this.getOptMode() + " OPTIMIZER -------");
        this._cost = est;
        this._plan = plan;
        OptNode pn = this._plan.getRoot();

        // nothing to rewrite for a plan that consists of a single leaf
        if (pn.isLeaf()) {
            return true;
        }

        super.analyzeProblemAndInfrastructure(pn);
        LOG.debug(this.getOptMode() + " OPT: Optimize with local_max_mem=" + OptimizerConstrained.toMB(this._lm) + " and remote_max_mem=" + OptimizerConstrained.toMB(this._rm) + ").");
        if (this._rnk <= 0 || this._rk <= 0) {
            LOG.warn(this.getOptMode() + " OPT: Optimize for inactive cluster (num_nodes=" + this._rnk + ", num_map_slots=" + this._rk + ").");
        }

        // estimate memory under a temporary serial configuration, then restore
        // the exec type and k that were on the node before
        OptNode.ExecType oldET = pn.getExecType();
        int oldK = pn.getK();
        pn.setSerialParFor();
        double M0a = this._cost.getEstimate(CostEstimator.TestMeasure.MEMORY_USAGE, pn);
        pn.setExecType(oldET);
        pn.setK(oldK);
        LOG.debug(this.getOptMode() + " OPT: estimated mem (serial exec) M=" + OptimizerConstrained.toMB(M0a));

        // data partitioning, followed by the rewrites that depend on the new estimate
        HashMap<String, ParForProgramBlock.PartitionFormat> partitionedMatrices = new HashMap<String, ParForProgramBlock.PartitionFormat>();
        this.rewriteSetDataPartitioner(pn, ec.getVariables(), partitionedMatrices, OptimizerUtils.getLocalMemBudget(), true);
        double M0b = this._cost.getEstimate(CostEstimator.TestMeasure.MEMORY_USAGE, pn);
        this.rewriteRemoveUnnecessaryCompareMatrix(pn, ec);
        boolean flagLIX = super.rewriteSetResultPartitioning(pn, M0b, ec.getVariables());

        double M1 = this._cost.getEstimate(CostEstimator.TestMeasure.MEMORY_USAGE, pn);
        LOG.debug(this.getOptMode() + " OPT: estimated new mem (serial exec) M=" + OptimizerConstrained.toMB(M1));
        double M2 = this._cost.getEstimate(CostEstimator.TestMeasure.MEMORY_USAGE, pn, Types.ExecType.CP);
        LOG.debug(this.getOptMode() + " OPT: estimated new mem (serial exec, all CP) M=" + OptimizerConstrained.toMB(M2));
        double M3 = this._cost.getEstimate(CostEstimator.TestMeasure.MEMORY_USAGE, pn, true);
        LOG.debug(this.getOptMode() + " OPT: estimated new mem (cond partitioning) M=" + OptimizerConstrained.toMB(M3));

        // remember the exec mode configured on the program block before the
        // execution-strategy rewrite overwrites it; needed for the fused rewrite
        ParForProgramBlock.PExecMode tmpmode = this.getPExecMode(pn);
        boolean flagRecompMR = this.rewriteSetExecutionStategy(pn, M0a, M1, M2, M3, flagLIX);

        if (pn.getExecType() == this.getRemoteExecType()) {
            // re-run data partitioning with the M3 budget when M1 exceeds the
            // remote budget but M3 fits
            if (M1 > this._rm && M3 <= this._rm) {
                this.rewriteSetDataPartitioner(pn, ec.getVariables(), partitionedMatrices, M3, true);
                M1 = this._cost.getEstimate(CostEstimator.TestMeasure.MEMORY_USAGE, pn);
            }
            if (flagRecompMR) {
                this.rewriteSetOperationsExecType(pn, flagRecompMR);
                M1 = this._cost.getEstimate(CostEstimator.TestMeasure.MEMORY_USAGE, pn);
            }
            super.rewriteDataColocation(pn, ec.getVariables());
            super.rewriteSetPartitionReplicationFactor(pn, partitionedMatrices, ec.getVariables());
            super.rewriteSetExportReplicationFactor(pn, ec.getVariables());
            this.rewriteSetDegreeOfParallelism(pn, this._cost, ec.getVariables(), M1, false);
            this.rewriteSetTaskPartitioner(pn, false, flagLIX);
            this.rewriteSetFusedDataPartitioningExecution(pn, M1, flagLIX, partitionedMatrices, ec.getVariables(), tmpmode);
            HashSet<ParForStatementBlock.ResultVar> inplaceResultVars = new HashSet<ParForStatementBlock.ResultVar>();
            super.rewriteSetInPlaceResultIndexing(pn, this._cost, ec.getVariables(), inplaceResultVars, ec);
        } else {
            this.rewriteSetDegreeOfParallelism(pn, this._cost, ec.getVariables(), M1, false);
            this.rewriteSetTaskPartitioner(pn, false, false);
            HashSet<ParForStatementBlock.ResultVar> inplaceResultVars = new HashSet<ParForStatementBlock.ResultVar>();
            super.rewriteSetInPlaceResultIndexing(pn, this._cost, ec.getVariables(), inplaceResultVars, ec);
            super.rewriteInjectSparkLoopCheckpointing(pn);
            super.rewriteInjectSparkRepartition(pn, ec.getVariables());
            super.rewriteSetSparkEagerRDDCaching(pn, ec.getVariables());
        }

        // rewrites applied regardless of the chosen exec type
        this.rewriteSetResultMerge(pn, ec.getVariables(), true);
        super.rewriteSetRecompileMemoryBudget(pn);
        super.rewriteRemoveRecursiveParFor(pn, ec.getVariables());
        super.rewriteRemoveUnnecessaryParFor(pn);

        this._numEvaluatedPlans = 1L;
        return true;
    }

    /**
     * Runs the superclass rewrite (which fills {@code partitionedMatrices}) and then,
     * if the node carried a data partitioner other than {@code UNSPECIFIED} before
     * that call, re-applies that partitioner to the program block.
     *
     * @return the value returned by the superclass rewrite
     */
    @Override
    protected boolean rewriteSetDataPartitioner(OptNode n, LocalVariableMap vars, HashMap<String, ParForProgramBlock.PartitionFormat> partitionedMatrices, double thetaM, boolean constrained) {
        // capture the setting BEFORE delegating to the superclass.
        // NOTE(review): if the superclass rewrite updates DATA_PARTITIONER on the node,
        // a later invocation on the same node would treat that value as forced - confirm.
        String initPlan = n.getParam(OptNode.ParamType.DATA_PARTITIONER);
        boolean blockwise = super.rewriteSetDataPartitioner(n, vars, partitionedMatrices, thetaM, constrained);

        if (OptimizerConstrained.isForced(initPlan, ParForProgramBlock.PDataPartitioner.UNSPECIFIED.name())) {
            ParForProgramBlock pfpb = this.getParForProgramBlock(n);
            pfpb.setDataPartitioner(ParForProgramBlock.PDataPartitioner.valueOf(initPlan));
            LOG.debug(this.getOptMode() + " OPT: forced 'set data partitioner' - result=" + initPlan);
        }
        return blockwise;
    }

    /**
     * Forces the exec mode implied by the node's exec type ({@code REMOTE_SPARK} for
     * {@code SPARK}, otherwise {@code LOCAL}) when the node has an exec type and
     * parallel parfor is enabled; otherwise delegates to the superclass.
     *
     * @return in the forced case, {@code true} iff the mode is {@code REMOTE_SPARK}
     *         and the node is not CP-only; otherwise the superclass result
     */
    @Override
    protected boolean rewriteSetExecutionStategy(OptNode n, double M0, double M, double M2, double M3, boolean flagLIX) {
        if (n.getExecType() == null || !ConfigurationManager.isParallelParFor()) {
            return super.rewriteSetExecutionStategy(n, M0, M, M2, M3, flagLIX);
        }

        ParForProgramBlock pfpb = this.getParForProgramBlock(n);
        ParForProgramBlock.PExecMode mode = n.getExecType() == OptNode.ExecType.SPARK
            ? ParForProgramBlock.PExecMode.REMOTE_SPARK
            : ParForProgramBlock.PExecMode.LOCAL;
        boolean ret = mode == ParForProgramBlock.PExecMode.REMOTE_SPARK && !n.isCPOnly();
        pfpb.setExecMode(mode);
        LOG.debug(this.getOptMode() + " OPT: forced 'set execution strategy' - result=" + mode);
        return ret;
    }

    /**
     * Forces the node's own {@code k} as degree of parallelism when it is positive
     * and parallel parfor is enabled, then distributes the remaining parallelism;
     * otherwise delegates to the superclass.
     */
    @Override
    protected void rewriteSetDegreeOfParallelism(OptNode n, CostEstimator cost, LocalVariableMap vars, double M, boolean flagNested) {
        if (n.getK() <= 0 || !ConfigurationManager.isParallelParFor()) {
            super.rewriteSetDegreeOfParallelism(n, cost, vars, M, flagNested);
            return;
        }

        ParForProgramBlock pfpb = this.getParForProgramBlock(n);
        pfpb.setDegreeOfParallelism(n.getK());
        int remainParforK = OptimizerConstrained.getRemainingParallelismParFor(n.getK(), n.getK());
        int remainOpsK = OptimizerConstrained.getRemainingParallelismOps(this._lkmaxCP, n.getK());
        this.rAssignRemainingParallelism(n, remainParforK, remainOpsK);
        LOG.debug(this.getOptMode() + " OPT: forced 'set degree of parallelism' - result=(see EXPLAIN)");
    }

    /**
     * Forces the task partitioner (and, if present, the task size) found on the node;
     * when no task partitioner is specified, delegates to the superclass and warns
     * if a task size was given without a partitioner.
     */
    @Override
    protected void rewriteSetTaskPartitioner(OptNode pn, boolean flagNested, boolean flagLIX) {
        String partitioner = pn.getParam(OptNode.ParamType.TASK_PARTITIONER);
        String taskSize = pn.getParam(OptNode.ParamType.TASK_SIZE);

        if (!OptimizerConstrained.isForced(partitioner, ParForProgramBlock.PTaskPartitioner.UNSPECIFIED.name())) {
            if (taskSize != null) {
                LOG.warn("Cannot force task size without forcing task partitioner.");
            }
            super.rewriteSetTaskPartitioner(pn, flagNested, flagLIX);
            return;
        }

        ParForProgramBlock pfpb = this.getParForProgramBlock(pn);
        pfpb.setTaskPartitioner(ParForProgramBlock.PTaskPartitioner.valueOf(partitioner));
        String tsExt = "";
        if (taskSize != null) {
            pfpb.setTaskSize(Integer.parseInt(taskSize));
            tsExt = "," + taskSize;
        }
        LOG.debug(this.getOptMode() + " OPT: forced 'set task partitioner' - result=" + partitioner + tsExt);
    }

    /**
     * Forces the result-merge setting found on the node, or delegates to the
     * superclass when it is unspecified.
     */
    @Override
    protected void rewriteSetResultMerge(OptNode n, LocalVariableMap vars, boolean inLocal) {
        String resultMerge = n.getParam(OptNode.ParamType.RESULT_MERGE);
        if (!OptimizerConstrained.isForced(resultMerge, ParForProgramBlock.PResultMerge.UNSPECIFIED.name())) {
            super.rewriteSetResultMerge(n, vars, inLocal);
            return;
        }

        ParForProgramBlock pfpb = this.getParForProgramBlock(n);
        pfpb.setResultMerge(ParForProgramBlock.PResultMerge.valueOf(resultMerge));
        LOG.debug(this.getOptMode() + " OPT: force 'set result merge' - result=" + resultMerge);
    }

    /**
     * Attempts to force fused data partitioning and execution when {@code emode} is
     * {@code REMOTE_SPARK_DP}; for any other mode the five-argument superclass
     * rewrite is used.
     *
     * <p>Only the first key of {@code partitionedMatrices} is considered. The rewrite
     * is applied when that matrix is accessed by the iteration variable and its
     * partitioning lines up with {@code _N}; the debug log reports whether it was
     * actually applied.
     *
     * @param emode exec mode that was configured on the program block before the
     *              execution-strategy rewrite ran
     */
    protected void rewriteSetFusedDataPartitioningExecution(OptNode pn, double M, boolean flagLIX, HashMap<String, ParForProgramBlock.PartitionFormat> partitionedMatrices, LocalVariableMap vars, ParForProgramBlock.PExecMode emode) {
        if (emode != ParForProgramBlock.PExecMode.REMOTE_SPARK_DP) {
            super.rewriteSetFusedDataPartitioningExecution(pn, M, flagLIX, partitionedMatrices, vars);
            return;
        }

        ParForProgramBlock pfpb = this.getParForProgramBlock(pn);
        if (partitionedMatrices.isEmpty()) {
            LOG.debug(this.getOptMode() + " OPT: unable to force 'set fused data partitioning and execution' - result=false");
            return;
        }

        String moVarname = partitionedMatrices.keySet().iterator().next();
        ParForProgramBlock.PartitionFormat moDpf = partitionedMatrices.get(moVarname);
        MatrixObject mo = (MatrixObject)vars.get(moVarname);

        boolean applied = false;
        if (this.rIsAccessByIterationVariable(pn, moVarname, pfpb.getIterVar()) && this.partitioningAlignsWithN(mo, moDpf)) {
            pn.addParam(OptNode.ParamType.DATA_PARTITIONER, "REMOTE_SPARK(fused)");
            pfpb.setExecMode(ParForProgramBlock.PExecMode.REMOTE_SPARK_DP);
            int k = (int)Math.min(this._N, (long)this._rk2);
            pn.setK(k);
            pfpb.setDataPartitioner(ParForProgramBlock.PDataPartitioner.NONE);
            pfpb.enableColocatedPartitionedMatrix(moVarname);
            pfpb.setDegreeOfParallelism(k);
            applied = true;
        }
        // report the real outcome instead of an unconditional "true"
        LOG.debug(this.getOptMode() + " OPT: force 'set fused data partitioning and execution' - result=" + applied);
    }

    /**
     * Checks the matrix dimensions against {@code _N} for the given partition format:
     * exact match for row-/column-wise, and at most {@code _N * dpf._N} for the
     * block-wise-N formats.
     */
    private boolean partitioningAlignsWithN(MatrixObject mo, ParForProgramBlock.PartitionFormat dpf) {
        if (dpf == ParForProgramBlock.PartitionFormat.ROW_WISE && mo.getNumRows() == this._N) {
            return true;
        }
        if (dpf == ParForProgramBlock.PartitionFormat.COLUMN_WISE && mo.getNumColumns() == this._N) {
            return true;
        }
        if (dpf._dpf == ParForProgramBlock.PDataPartitionFormat.ROW_BLOCK_WISE_N && mo.getNumRows() <= this._N * (long)dpf._N) {
            return true;
        }
        return dpf._dpf == ParForProgramBlock.PDataPartitionFormat.COLUMN_BLOCK_WISE_N && mo.getNumColumns() <= this._N * (long)dpf._N;
    }

    /** Returns the exec mode currently set on the program block mapped to {@code pn}. */
    private ParForProgramBlock.PExecMode getPExecMode(OptNode pn) {
        return this.getParForProgramBlock(pn).getExecMode();
    }

    /** Looks up the program block stored at index 1 of the plan mapping for {@code n}. */
    private ParForProgramBlock getParForProgramBlock(OptNode n) {
        return (ParForProgramBlock)this._plan.getMappedProg(n.getID())[1];
    }

    /**
     * Tells whether a node parameter carries a forced value: present and different
     * from the given "unspecified" marker. A missing ({@code null}) parameter is
     * treated as not forced.
     */
    private static boolean isForced(String value, String unspecifiedName) {
        return value != null && !value.equals(unspecifiedName);
    }
}

