/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.paramserv;

import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import org.apache.commons.lang3.concurrent.ConcurrentUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.paramserv.PSWorker;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamServer;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.utils.stats.ParamServStatistics;

/**
 * In-process parameter-server worker, executed as a {@link Callable}.
 *
 * <p>The worker walks its feature/label partition in mini-batches of {@code _batchSize} rows,
 * computes gradients with the configured update-function instruction, and exchanges parameters
 * with the {@link ParamServer} according to the configured update frequency
 * ({@code BATCH}, {@code EPOCH} or {@code NBATCHES}).
 */
public class LocalPSWorker
extends PSWorker
implements Callable<Void> {
    protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName());
    private static final long serialVersionUID = 5195390748495357295L;

    /** No-argument constructor; performs no initialization in this class. */
    protected LocalPSWorker() {
    }

    /**
     * Creates a local worker. All arguments are forwarded unchanged to {@link PSWorker}.
     *
     * @param workerID  id of this worker, used for pull/push and log messages
     * @param updFunc   name of the gradient update function
     * @param freq      parameter exchange frequency
     * @param epochs    number of passes over the local data
     * @param batchSize number of rows per mini-batch
     * @param ec        execution context used to run the update function
     * @param ps        parameter server to pull from / push to
     * @param nbatches  number of batches between pushes for {@code NBATCHES} frequency
     * @param modelAvg  if true, push the locally updated model instead of gradients
     */
    public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps, int nbatches, boolean modelAvg) {
        super(workerID, updFunc, freq, epochs, batchSize, ec, ps, nbatches, modelAvg);
    }

    @Override
    public String getWorkerName() {
        return String.format("Local worker_%d", this._workerID);
    }

    /**
     * Runs the worker to completion for the configured update frequency.
     *
     * @return always {@code null}
     * @throws DMLRuntimeException wrapping any failure, including an unsupported frequency
     */
    @Override
    public Void call() throws Exception {
        this.incWorkerNumber();
        try {
            long dataSize = this._features.getNumRows();
            // Number of mini-batches per epoch; the last batch may be smaller than _batchSize.
            int batchIter = (int) Math.ceil((double) dataSize / (double) this._batchSize);
            switch (this._freq) {
                case BATCH:
                    this.computeBatch(dataSize, batchIter);
                    break;
                case EPOCH:
                    this.computeEpoch(dataSize, batchIter);
                    break;
                case NBATCHES:
                    this.computeNBatches(dataSize, batchIter);
                    break;
                default:
                    throw new DMLRuntimeException(String.format("%s not support update frequency %s", this.getWorkerName(), this._freq));
            }
            if (LOG.isDebugEnabled()) {
                LOG.debug(String.format("%s: job finished.", this.getWorkerName()));
            }
        }
        catch (Exception e) {
            throw new DMLRuntimeException(String.format("%s failed", this.getWorkerName()), e);
        }
        return null;
    }

    /**
     * Lazily creates the pool used for asynchronous gradient accrual if none has been set yet.
     */
    private void ensureThreadPool() {
        if (this._tpool == null) {
            this._tpool = CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
        }
    }

    /**
     * EPOCH frequency. Pulls the model once per epoch and applies every batch but the last to a
     * local copy. Pushes once per epoch: the locally updated model when model averaging is
     * enabled, otherwise the gradients accrued over the epoch.
     */
    private void computeEpoch(long dataSize, int batchIter) {
        for (int i = 0; i < this._epochs; ++i) {
            ListObject params = this.pullModel();
            Future<ListObject> accGradients = ConcurrentUtils.constantFuture(null);
            this.ensureThreadPool();
            try {
                for (int j = 0; j < batchIter; ++j) {
                    ListObject gradients = this.computeGradients(params, dataSize, batchIter, i, j);
                    // The last batch of the epoch is not applied locally (unless model averaging);
                    // !localUpdate is passed as the last accrueGradients flag for that batch.
                    boolean localUpdate = j < batchIter - 1;
                    // Wait for the previous accrual so accumulation stays sequential.
                    ListObject accGradientsPrev = accGradients.get();
                    accGradients = this._modelAvg ? ConcurrentUtils.constantFuture(null) : this._tpool.submit(() -> ParamservUtils.accrueGradients(accGradientsPrev, gradients, false, !localUpdate));
                    if (localUpdate || this._modelAvg) {
                        params = this.updateModel(params, gradients, i, j, batchIter);
                    }
                    this.accNumBatches(1);
                }
                this.pushGradients(this._modelAvg ? params : accGradients.get());
                if (!this._modelAvg) {
                    ParamservUtils.cleanupListObject(this._ec, "model");
                }
            }
            catch (InterruptedException ex) {
                // Restore the interrupt status before propagating.
                Thread.currentThread().interrupt();
                throw new DMLRuntimeException(ex);
            }
            catch (ExecutionException ex) {
                throw new DMLRuntimeException(ex);
            }
            this.accNumEpochs(1);
            if (LOG.isDebugEnabled()) {
                LOG.debug(String.format("%s: finished %d epoch.", this.getWorkerName(), i + 1));
            }
        }
    }

    /**
     * NBATCHES frequency. Pulls the model every {@code _nbatches} batches and applies each batch
     * locally. Pushes after every {@code _nbatches} batches and after the last batch of an epoch:
     * the local model when model averaging is enabled, otherwise the gradients accrued since the
     * previous push.
     */
    private void computeNBatches(long dataSize, int batchIter) {
        this.ensureThreadPool();
        ListObject model = null;
        Future<ListObject> accGradients = ConcurrentUtils.constantFuture(null);
        for (int i = 0; i < this._epochs; ++i) {
            try {
                for (int j = 0; j < batchIter; ++j) {
                    if (j % this._nbatches == 0) {
                        model = this.pullModel();
                    }
                    ListObject gradients = this.computeGradients(model, dataSize, batchIter, i, j);
                    // Wait for the previous accrual so accumulation stays sequential.
                    ListObject accGradientsPrev = accGradients.get();
                    // NOTE(review): the former guard `localUpdate = j < batchIter` was always true,
                    // so every batch is applied locally and the last accrueGradients flag is always
                    // false. That behavior is kept here. Unlike computeEpoch, the push step is not
                    // treated specially; confirm whether that is intended.
                    accGradients = this._tpool.submit(() -> ParamservUtils.accrueGradients(accGradientsPrev, gradients, false, false));
                    model = this.updateModel(model, gradients, i, j, batchIter);
                    // Count each batch exactly once (it was previously counted twice).
                    this.accNumBatches(1);
                    boolean pushStep = j % this._nbatches == this._nbatches - 1 || j == batchIter - 1;
                    if (pushStep) {
                        this.pushGradients(this._modelAvg ? model : accGradients.get());
                        accGradients = ConcurrentUtils.constantFuture(null);
                    }
                }
            }
            catch (InterruptedException ex) {
                // Restore the interrupt status before propagating.
                Thread.currentThread().interrupt();
                throw new DMLRuntimeException(ex);
            }
            catch (ExecutionException ex) {
                throw new DMLRuntimeException(ex);
            }
            this.accNumEpochs(1);
            if (LOG.isDebugEnabled()) {
                LOG.debug(String.format("%s: finished %d epoch.", this.getWorkerName(), i + 1));
            }
        }
    }

    /**
     * Applies {@code gradients} to {@code globalParams} via the parameter server's local-update
     * function and records the elapsed time.
     *
     * @return the updated parameters returned by {@code _ps.updateLocalModel}
     */
    private ListObject updateModel(ListObject globalParams, ListObject gradients, int i, int j, int batchIter) {
        Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null;
        globalParams = this._ps.updateLocalModel(this._ec, gradients, globalParams);
        this.accLocalModelUpdateTime(tUpd);
        if (LOG.isDebugEnabled()) {
            // Report size in KB, consistent with pullModel/pushGradients.
            LOG.debug(String.format("%s: local global parameter [size:%d kb] updated. [Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]", this.getWorkerName(), globalParams.getDataSize() / 1024L, i + 1, this._epochs, j + 1, batchIter));
        }
        return globalParams;
    }

    /**
     * BATCH frequency. For every batch, pulls the model, computes gradients, and pushes either
     * the locally updated model (model averaging) or the raw gradients.
     */
    private void computeBatch(long dataSize, int totalIter) {
        for (int i = 0; i < this._epochs; ++i) {
            for (int j = 0; j < totalIter; ++j) {
                ListObject globalParams = this.pullModel();
                ListObject gradients = this.computeGradients(globalParams, dataSize, totalIter, i, j);
                if (this._modelAvg) {
                    ListObject model = this.updateModel(globalParams, gradients, i, j, totalIter);
                    this.pushGradients(model);
                } else {
                    this.pushGradients(gradients);
                    ParamservUtils.cleanupListObject(this._ec, "model");
                }
                this.accNumBatches(1);
            }
            this.accNumEpochs(1);
            if (LOG.isDebugEnabled()) {
                LOG.debug(String.format("%s: finished %d epoch.", this.getWorkerName(), i + 1));
            }
        }
    }

    /** Pulls the current global parameters for this worker from the parameter server. */
    private ListObject pullModel() {
        ListObject globalParams = this._ps.pull(this._workerID);
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("%s: successfully pull the global parameters [size:%d kb] from ps.", this.getWorkerName(), globalParams.getDataSize() / 1024L));
        }
        return globalParams;
    }

    /** Pushes gradients (or, under model averaging, a model) to the parameter server. */
    private void pushGradients(ListObject gradients) {
        this._ps.push(this._workerID, gradients);
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("%s: successfully push the gradients [size:%d kb] to ps.", this.getWorkerName(), gradients.getDataSize() / 1024L));
        }
    }

    /**
     * Computes gradients for batch {@code j}.
     *
     * <p>Binds {@code params} to {@code "model"}, slices rows {@code [j*_batchSize+1,
     * min((j+1)*_batchSize, dataSize)]} of features and labels (presumably 1-based inclusive;
     * confirm against {@code ParamservUtils.sliceMatrix}), and binds them to
     * {@code "features"}/{@code "labels"}. It then runs the update-function instruction, reads its
     * output list, and releases the batch slices.
     */
    private ListObject computeGradients(ListObject params, long dataSize, int batchIter, int i, int j) {
        this._ec.setVariable("model", params);
        long begin = (long) j * this._batchSize + 1L;
        long end = Math.min((long) (j + 1) * this._batchSize, dataSize);
        Timing tSlic = DMLScript.STATISTICS ? new Timing(true) : null;
        MatrixObject bFeatures = ParamservUtils.sliceMatrix(this._features, begin, end);
        MatrixObject bLabels = ParamservUtils.sliceMatrix(this._labels, begin, end);
        this.accBatchIndexingTime(tSlic);
        this._ec.setVariable("features", bFeatures);
        this._ec.setVariable("labels", bLabels);
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("%s: got batch data [size:%d kb] of index from %d to %d [last index: %d]. [Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]", this.getWorkerName(), bFeatures.getDataSize() / 1024L + bLabels.getDataSize() / 1024L, begin, end, dataSize, i + 1, this._epochs, j + 1, batchIter));
        }
        Timing tGrad = DMLScript.STATISTICS ? new Timing(true) : null;
        this._inst.processInstruction(this._ec);
        this.accGradientComputeTime(tGrad);
        ListObject gradients = this._ec.getListObject(this._output.getName());
        ParamservUtils.cleanupData(this._ec, "features");
        ParamservUtils.cleanupData(this._ec, "labels");
        return gradients;
    }

    @Override
    protected void incWorkerNumber() {
        if (DMLScript.STATISTICS) {
            ParamServStatistics.incWorkerNumber();
        }
    }

    /** Records local model update time; {@code time} is null when statistics were disabled. */
    @Override
    protected void accLocalModelUpdateTime(Timing time) {
        if (DMLScript.STATISTICS && time != null) {
            ParamServStatistics.accLocalModelUpdateTime((long) time.stop());
        }
    }

    /** Records batch slicing time; {@code time} is null when statistics were disabled. */
    @Override
    protected void accBatchIndexingTime(Timing time) {
        if (DMLScript.STATISTICS && time != null) {
            ParamServStatistics.accBatchIndexingTime((long) time.stop());
        }
    }

    /** Records gradient computation time; {@code time} is null when statistics were disabled. */
    @Override
    protected void accGradientComputeTime(Timing time) {
        if (DMLScript.STATISTICS && time != null) {
            ParamServStatistics.accGradientComputeTime((long) time.stop());
        }
    }
}

