/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Arrays;
import java.util.HashMap;
import java.util.concurrent.Callable;
import org.apache.commons.lang3.tuple.MutableTriple;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;

/**
 * Column encoder that assigns each value of one column to one of {@code _numBin} equal-width bins.
 * <p>
 * {@link #build} derives the column's (non-NaN) min/max and splits that range into equal-width bins.
 * {@link #getCode} maps a value to its 1-based bin id, or NaN when the value is NaN or out of range.
 * Bin boundaries can also be restored from / written to frame metadata and Java serialization.
 */
public class ColumnEncoderBin
extends ColumnEncoder {
    public static final String MIN_PREFIX = "min";
    public static final String MAX_PREFIX = "max";
    public static final String NBINS_PREFIX = "nbins";
    private static final long serialVersionUID = 1917445005206076078L;
    /** Number of bins to build; -1 when created through the no-arg constructor. */
    protected int _numBin = -1;
    /** Lower boundary of every bin; null until built, initialized from metadata, or deserialized. */
    private double[] _binMins = null;
    /** Upper boundary of every bin; null until built, initialized from metadata, or deserialized. */
    private double[] _binMaxs = null;
    /** Column minimum recorded by {@link #buildPartial}; reset to -1 by {@link #prepareBuildPartial}. */
    private double _colMins = -1.0;
    /** Column maximum recorded by {@link #buildPartial}; reset to -1 by {@link #prepareBuildPartial}. */
    private double _colMaxs = -1.0;

    public ColumnEncoderBin() {
        super(-1);
    }

    public ColumnEncoderBin(int colID, int numBin) {
        super(colID);
        this._numBin = numBin;
    }

    public ColumnEncoderBin(int colID, int numBin, double[] binMins, double[] binMaxs) {
        super(colID);
        this._numBin = numBin;
        this._binMins = binMins;
        this._binMaxs = binMaxs;
    }

    /** @return the column minimum recorded by the last {@link #buildPartial}, or -1 if reset */
    public double getColMins() {
        return this._colMins;
    }

    /** @return the column maximum recorded by the last {@link #buildPartial}, or -1 if reset */
    public double getColMaxs() {
        return this._colMaxs;
    }

    /** @return the internal (not copied) array of bin lower boundaries, possibly null */
    public double[] getBinMins() {
        return this._binMins;
    }

    /** @return the internal (not copied) array of bin upper boundaries, possibly null */
    public double[] getBinMaxs() {
        return this._binMaxs;
    }

    /**
     * Builds the equal-width bin boundaries from the min/max of the whole column.
     *
     * @param in input block holding this encoder's column ({@code _colID} is 1-based)
     */
    @Override
    public void build(CacheBlock in) {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        if (!this.isApplicable()) {
            return;
        }
        double[] pairMinMax = getMinMaxOfCol(in, this._colID, 0, -1);
        this.computeBins(pairMinMax[0], pairMinMax[1]);
        if (DMLScript.STATISTICS) {
            Statistics.incTransformBinningBuildTime(System.nanoTime() - t0);
        }
    }

    /**
     * Maps the value at {@code row} of this encoder's column to its 1-based bin id.
     *
     * @return 1.0 (with a warning) when no boundaries are available, NaN when the value is NaN or
     *         outside [first bin min, last bin max], otherwise the 1-based bin index
     */
    @Override
    protected double getCode(CacheBlock in, int row) {
        // Also treat never-built (null) boundaries as "no boundaries"; the original length check
        // alone would throw a NullPointerException here.
        if (this._binMins == null || this._binMaxs == null
            || this._binMins.length == 0 || this._binMaxs.length == 0) {
            LOG.warn("ColumnEncoderBin: applyValue without bucket boundaries, assign 1");
            return 1.0;
        }
        double inVal = in.getDoubleNaN(row, this._colID - 1);
        if (Double.isNaN(inVal) || inVal < this._binMins[0] || inVal > this._binMaxs[this._binMaxs.length - 1]) {
            return Double.NaN;
        }
        // binarySearch returns the match index, or -(insertionPoint) - 1 when there is no exact match;
        // either way we end up at the first upper boundary >= inVal.
        int ix = Arrays.binarySearch(this._binMaxs, inVal);
        return (ix < 0 ? Math.abs(ix + 1) : ix) + 1;
    }

    @Override
    protected ColumnEncoder.TransformType getTransformType() {
        return ColumnEncoder.TransformType.BIN;
    }

    /**
     * Computes the min and max over the non-NaN values of rows
     * [startRow, getEndIndex(numRows, startRow, blockSize)) of the given (1-based) column.
     *
     * @return {min, max}; {+Infinity, -Infinity} when the row range holds no non-NaN value
     */
    private static double[] getMinMaxOfCol(CacheBlock in, int colID, int startRow, int blockSize) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        // The end index does not change during the scan, so compute it once.
        int endRow = UtilFunctions.getEndIndex(in.getNumRows(), startRow, blockSize);
        for (int i = startRow; i < endRow; ++i) {
            // Read the value the same way getCode does (getDoubleNaN). Values that getCode would see
            // as NaN then do not widen the bin range. NOTE(review): this assumes getDoubleNaN yields
            // NaN for missing entries where getDouble may not; confirm against CacheBlock impls.
            double inVal = in.getDoubleNaN(i, colID - 1);
            if (Double.isNaN(inVal)) {
                continue;
            }
            min = Math.min(min, inVal);
            max = Math.max(max, inVal);
        }
        return new double[]{min, max};
    }

    @Override
    public Callable<Object> getBuildTask(CacheBlock in) {
        return new ColumnBinBuildTask(this, in);
    }

    @Override
    public Callable<Object> getPartialBuildTask(CacheBlock in, int startRow, int blockSize, HashMap<Integer, Object> ret) {
        return new BinPartialBuildTask(in, this._colID, startRow, blockSize, ret);
    }

    @Override
    public Callable<Object> getPartialMergeBuildTask(HashMap<Integer, ?> ret) {
        return new BinMergePartialBuildTask(this, ret);
    }

    /**
     * Fills the bin boundaries with {@code _numBin} equal-width bins spanning [min, max].
     *
     * @param min lower end of the value range
     * @param max upper end of the value range
     */
    public void computeBins(double min, double max) {
        // Reallocate not only when missing but also when the existing arrays (e.g. restored by
        // initMetaData or passed to the constructor) do not have _numBin entries. Otherwise the
        // fill would overrun them, or leave stale trailing boundaries that getCode would still use.
        if (this._binMins == null || this._binMaxs == null
            || this._binMins.length != this._numBin || this._binMaxs.length != this._numBin) {
            this._binMins = new double[this._numBin];
            this._binMaxs = new double[this._numBin];
        }
        this.fillEquiWidthBins(min, max);
    }

    /** Writes {@code _numBin} equal-width boundaries over [min, max] into the already allocated arrays. */
    private void fillEquiWidthBins(double min, double max) {
        for (int i = 0; i < this._numBin; ++i) {
            this._binMins[i] = min + (double)i * (max - min) / (double)this._numBin;
            this._binMaxs[i] = min + (double)(i + 1) * (max - min) / (double)this._numBin;
        }
    }

    /** Resets the per-partition column min/max before a partial build. */
    @Override
    public void prepareBuildPartial() {
        this._colMins = -1.0;
        this._colMaxs = -1.0;
    }

    /** Records the min/max of this encoder's column in {@code in} without computing bins. */
    @Override
    public void buildPartial(FrameBlock in) {
        if (!this.isApplicable()) {
            return;
        }
        double[] pairMinMax = getMinMaxOfCol(in, this._colID, 0, -1);
        this._colMins = pairMinMax[0];
        this._colMaxs = pairMinMax[1];
    }

    // NOTE(review): startRow and blk are not forwarded to the task, so each sparse task covers
    // whatever range the 4-arg ColumnApplyTask constructor defaults to. The 6-arg
    // BinSparseApplyTask constructor exists but is unused here; confirm whether this is intended.
    @Override
    protected ColumnEncoder.ColumnApplyTask<? extends ColumnEncoder> getSparseTask(CacheBlock in, MatrixBlock out, int outputCol, int startRow, int blk) {
        return new BinSparseApplyTask(this, in, out, outputCol);
    }

    /**
     * Merges another bin encoder of the same column by widening the value range to cover both
     * encoders' outermost boundaries. It then recomputes {@code _numBin} equal-width bins over that
     * range in freshly allocated arrays. Other encoder types are delegated to the superclass.
     */
    @Override
    public void mergeAt(ColumnEncoder other) {
        if (!(other instanceof ColumnEncoderBin)) {
            super.mergeAt(other);
            return;
        }
        ColumnEncoderBin otherBin = (ColumnEncoderBin)other;
        assert (other._colID == this._colID);
        double min = Math.min(this._binMins[0], otherBin._binMins[0]);
        double max = Math.max(this._binMaxs[this._binMaxs.length - 1],
            otherBin._binMaxs[otherBin._binMaxs.length - 1]);
        this._binMins = new double[this._numBin];
        this._binMaxs = new double[this._numBin];
        this.fillEquiWidthBins(min, max);
    }

    /**
     * Writes the bin boundaries into {@code meta}: row i of this column holds
     * "binMin·binMax" of bin i, and the column's num-distinct is set to {@code _numBin}.
     */
    @Override
    public FrameBlock getMetaData(FrameBlock meta) {
        meta.ensureAllocatedColumns(this._binMaxs.length);
        meta.getColumnMetadata(this._colID - 1).setNumDistinct(this._numBin);
        for (int i = 0; i < this._binMaxs.length; ++i) {
            String sb = this._binMins[i] + "\u00b7" + this._binMaxs[i];
            meta.set(i, this._colID - 1, sb);
        }
        return meta;
    }

    /**
     * Restores the bin boundaries from {@code meta}, unless {@code meta} is null or boundaries are
     * already present. Reads as many "min·max" rows as the column's num-distinct metadata says.
     */
    @Override
    public void initMetaData(FrameBlock meta) {
        if (meta == null || this._binMaxs != null) {
            return;
        }
        int nbins = (int)meta.getColumnMetadata()[this._colID - 1].getNumDistinct();
        this._binMins = new double[nbins];
        this._binMaxs = new double[nbins];
        for (int i = 0; i < nbins; ++i) {
            String[] tmp = meta.get(i, this._colID - 1).toString().split("\u00b7");
            this._binMins[i] = Double.parseDouble(tmp[0]);
            this._binMaxs[i] = Double.parseDouble(tmp[1]);
        }
    }

    /**
     * Serializes {@code _numBin}, a presence flag and then, if present, the (max, min) boundary pair
     * of every bin.
     */
    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        super.writeExternal(out);
        out.writeInt(this._numBin);
        out.writeBoolean(this._binMaxs != null);
        if (this._binMaxs != null) {
            for (int j = 0; j < this._binMaxs.length; ++j) {
                out.writeDouble(this._binMaxs[j]);
                out.writeDouble(this._binMins[j]);
            }
        }
    }

    /** Reads the format written by {@link #writeExternal}; boundaries stay null if absent. */
    @Override
    public void readExternal(ObjectInput in) throws IOException {
        super.readExternal(in);
        this._numBin = in.readInt();
        boolean minmax = in.readBoolean();
        if (!minmax) {
            this._binMaxs = null;
            this._binMins = null;
            return;
        }
        this._binMaxs = new double[this._numBin];
        this._binMins = new double[this._numBin];
        for (int j = 0; j < this._binMaxs.length; ++j) {
            this._binMaxs[j] = in.readDouble();
            this._binMins[j] = in.readDouble();
        }
    }

    /** Task that runs a full {@link ColumnEncoderBin#build} on one input block. */
    private static class ColumnBinBuildTask
    implements Callable<Object> {
        private final ColumnEncoderBin _encoder;
        private final CacheBlock _input;

        protected ColumnBinBuildTask(ColumnEncoderBin encoder, CacheBlock input) {
            this._encoder = encoder;
            this._input = input;
        }

        @Override
        public Void call() throws Exception {
            this._encoder.build(this._input);
            return null;
        }

        @Override
        public String toString() {
            return this.getClass().getSimpleName() + "<ColId: " + this._encoder._colID + ">";
        }
    }

    /**
     * Task that reduces the per-row-block {min, max} pairs collected by {@link BinPartialBuildTask}
     * into a global range and computes the bins from it.
     */
    private static class BinMergePartialBuildTask
    implements Callable<Object> {
        private final HashMap<Integer, ?> _partialMaps;
        private final ColumnEncoderBin _encoder;

        private BinMergePartialBuildTask(ColumnEncoderBin encoderBin, HashMap<Integer, ?> partialMaps) {
            this._partialMaps = partialMaps;
            this._encoder = encoderBin;
        }

        @Override
        public Object call() throws Exception {
            long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;
            for (Object minMax : this._partialMaps.values()) {
                // Each value is the double[]{min, max} stored by BinPartialBuildTask.
                min = Math.min(min, ((double[])minMax)[0]);
                max = Math.max(max, ((double[])minMax)[1]);
            }
            this._encoder.computeBins(min, max);
            if (DMLScript.STATISTICS) {
                Statistics.incTransformBinningBuildTime(System.nanoTime() - t0);
            }
            return null;
        }

        @Override
        public String toString() {
            return this.getClass().getSimpleName() + "<ColId: " + this._encoder._colID + ">";
        }
    }

    /**
     * Task that computes {min, max} of one row block of the column and stores it in the shared map,
     * keyed by the block's start row.
     */
    private static class BinPartialBuildTask
    implements Callable<Object> {
        private final CacheBlock _input;
        private final int _blockSize;
        private final int _startRow;
        private final int _colID;
        private final HashMap<Integer, Object> _partialMinMax;

        protected BinPartialBuildTask(CacheBlock input, int colID, int startRow, int blocksize, HashMap<Integer, Object> partialMinMax) {
            this._input = input;
            this._blockSize = blocksize;
            this._colID = colID;
            this._startRow = startRow;
            this._partialMinMax = partialMinMax;
        }

        /** Stores this block's {min, max} in the shared map; always returns null. */
        @Override
        public double[] call() throws Exception {
            long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            double[] minMax = getMinMaxOfCol(this._input, this._colID, this._startRow, this._blockSize);
            // The map is a plain HashMap shared between partial tasks, so writes lock on it.
            synchronized (this._partialMinMax) {
                this._partialMinMax.put(this._startRow, minMax);
            }
            if (DMLScript.STATISTICS) {
                Statistics.incTransformBinningBuildTime(System.nanoTime() - t0);
            }
            return null;
        }

        @Override
        public String toString() {
            return this.getClass().getSimpleName() + "<Start row: " + this._startRow + "; Block size: " + this._blockSize + ">";
        }
    }

    /** Task that applies the binning into the sparse block of {@code _out}, if one exists. */
    private static class BinSparseApplyTask
    extends ColumnEncoder.ColumnApplyTask<ColumnEncoderBin> {
        public BinSparseApplyTask(ColumnEncoderBin encoder, CacheBlock input, MatrixBlock out, int outputCol, int startRow, int blk) {
            super(encoder, input, out, outputCol, startRow, blk);
        }

        private BinSparseApplyTask(ColumnEncoderBin encoder, CacheBlock input, MatrixBlock out, int outputCol) {
            super(encoder, input, out, outputCol);
        }

        @Override
        public Object call() throws Exception {
            long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            if (this._out.getSparseBlock() == null) {
                return null;
            }
            ((ColumnEncoderBin)this._encoder).applySparse(this._input, this._out, this._outputCol, this._startRow, this._blk);
            if (DMLScript.STATISTICS) {
                Statistics.incTransformBinningApplyTime(System.nanoTime() - t0);
            }
            return null;
        }

        @Override
        public String toString() {
            return this.getClass().getSimpleName() + "<ColId: " + ((ColumnEncoderBin)this._encoder)._colID + ">";
        }
    }
}

