/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.internal.sql.engine.rel.agg;

import com.google.common.collect.ImmutableList;
import it.unimi.dsi.fastutil.ints.IntList;
import java.math.BigDecimal;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelDataTypeSystem;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.ignite.internal.lang.IgniteStringFormatter;
import org.apache.ignite.internal.sql.engine.rel.IgniteProject;
import org.apache.ignite.internal.sql.engine.rel.IgniteRel;
import org.apache.ignite.internal.sql.engine.sql.fun.IgniteSqlOperatorTable;
import org.apache.ignite.internal.sql.engine.type.IgniteTypeFactory;
import org.apache.ignite.internal.sql.engine.util.Commons;
import org.jetbrains.annotations.TestOnly;

/**
 * Splits a {@link LogicalAggregate} into a two-phase MAP/REDUCE aggregation.
 *
 * <p>Every aggregate call is rewritten into one or more MAP calls, evaluated over the original input, and
 * one or more REDUCE calls, evaluated over the MAP output. When needed, a projection is placed between the
 * phases to adapt MAP outputs to REDUCE inputs. A final projection on top of REDUCE computes the values of
 * the original aggregates (for example {@code AVG = SUM / COUNT}).
 */
public class MapReduceAggregates {
    /** Names of the aggregate functions that {@link #createMapReduceAggCall} knows how to split. */
    private static final Set<String> AGG_SUPPORTING_MAP_REDUCE = Set.of("COUNT", "MIN", "MAX", "SUM", "$SUM0", "EVERY", "SOME", "ANY", "AVG", "SINGLE_VALUE", "ANY_VALUE", "GROUPING");

    /** Expression builder that references the first argument column of the given input unchanged. */
    private static final MakeReduceExpr USE_INPUT_FIELD = (rexBuilder, input, args, typeFactory) -> rexBuilder.makeInputRef(input, args.getInt(0));

    private MapReduceAggregates() {
    }

    /**
     * Checks whether every given aggregate call can be split into MAP and REDUCE phases.
     *
     * <p>NOTE(review): only the function name is checked here; DISTINCT calls are not rejected.
     * Presumably callers rewrite or exclude DISTINCT aggregates beforehand — confirm.
     *
     * @param aggCalls Aggregate calls to check.
     * @return {@code true} if all calls use an aggregate function listed in {@link #AGG_SUPPORTING_MAP_REDUCE}.
     */
    public static boolean canBeImplementedAsMapReduce(List<AggregateCall> aggCalls) {
        for (AggregateCall call : aggCalls) {
            SqlAggFunction agg = call.getAggregation();

            if (!AGG_SUPPORTING_MAP_REDUCE.contains(agg.getName())) {
                return false;
            }
        }

        return true;
    }

    /**
     * Builds the MAP aggregate, an optional projection feeding REDUCE, the REDUCE aggregate and, when the
     * REDUCE output does not already match the original row type, a final projection.
     *
     * @param agg Aggregate to split.
     * @param builder Factory for the MAP aggregate, intermediate projection and REDUCE aggregate nodes.
     * @param fieldMappingOnReduce Mapping applied to the group set and the grouping sets of {@code agg}
     *         to obtain those used by the REDUCE phase.
     * @return The REDUCE node, or an {@link IgniteProject} on top of it whose row type is {@code agg}'s row type.
     */
    public static IgniteRel buildAggregates(LogicalAggregate agg, AggregateRelBuilder builder, Mapping fieldMappingOnReduce) {
        int groupByColumns = agg.getGroupSet().cardinality();
        List<AggregateCall> aggCalls = agg.getAggCallList();
        List<RelDataTypeField> outputFields = agg.getRowType().getFieldList();

        // Split every aggregate call into its MAP and REDUCE parts. Arguments of the REDUCE calls
        // follow the group-by columns in the MAP output, in the order the calls are listed.
        List<MapReduceAgg> mapReduceAggs = new ArrayList<>(aggCalls.size());
        List<AggregateCall> mapAggCalls = new ArrayList<>(aggCalls.size());
        int argumentOffset = groupByColumns;

        for (AggregateCall call : aggCalls) {
            // Without GROUP BY (or with a FILTER), an aggregate may see no rows at all.
            boolean canBeNull = agg.getGroupCount() == 0 || call.hasFilter();
            MapReduceAgg mapReduceAgg = createMapReduceAggCall(Commons.cluster(), call, argumentOffset, agg.getInput().getRowType(), canBeNull);

            argumentOffset += mapReduceAgg.reduceCalls.size();
            mapReduceAggs.add(mapReduceAgg);
            mapAggCalls.addAll(mapReduceAgg.mapCalls);
        }

        assert mapAggCalls.size() >= aggCalls.size()
                : IgniteStringFormatter.format("The number of MAP aggregates is not correct. Original: {}\nMAP: {}", aggCalls, mapAggCalls);

        IgniteRel map = builder.makeMapAgg(agg.getCluster(), agg.getInput(), agg.getGroupSet(), agg.getGroupSets(), mapAggCalls);

        RexBuilder rexBuilder = agg.getCluster().getRexBuilder();
        IgniteTypeFactory typeFactory = (IgniteTypeFactory) agg.getCluster().getTypeFactory();

        // REDUCE input: start from an identity projection of the MAP output and let every aggregate
        // rewrite the references to its own MAP columns.
        List<RelDataTypeField> mapFields = map.getRowType().getFieldList();
        List<RexNode> reduceInputExprs = new ArrayList<>(mapFields.size());

        for (int i = 0; i < mapFields.size(); i++) {
            reduceInputExprs.add(new RexInputRef(i, mapFields.get(i).getType()));
        }

        boolean additionalProjectionsForReduce = false;
        int argIdx = groupByColumns;

        for (MapReduceAgg mapReduceAgg : mapReduceAggs) {
            for (int j = 0; j < mapReduceAgg.reduceCalls.size(); j++) {
                RexNode projExpr = mapReduceAgg.makeReduceInputExpr.makeExpr(rexBuilder, map, IntList.of(argIdx), typeFactory);

                reduceInputExprs.set(argIdx, projExpr);

                if (mapReduceAgg.makeReduceInputExpr != USE_INPUT_FIELD) {
                    additionalProjectionsForReduce = true;
                }

                argIdx++;
            }
        }

        IgniteRel reduceInputNode;

        if (additionalProjectionsForReduce) {
            RelDataTypeFactory.Builder projectRow = new RelDataTypeFactory.Builder(agg.getCluster().getTypeFactory());

            for (int i = 0; i < reduceInputExprs.size(); i++) {
                projectRow.add(String.valueOf(i), reduceInputExprs.get(i).getType());
            }

            reduceInputNode = builder.makeProject(agg.getCluster(), map, reduceInputExprs, projectRow.build());
        } else {
            reduceInputNode = map;
        }

        // REDUCE aggregate: group-by columns first, then the REDUCE calls of every aggregate.
        RelDataTypeFactory.Builder reduceType = new RelDataTypeFactory.Builder(Commons.typeFactory());

        for (int i = 0; i < groupByColumns; i++) {
            reduceType.add("f" + reduceType.getFieldCount(), outputFields.get(i).getType());
        }

        List<AggregateCall> reduceAggCalls = new ArrayList<>();
        List<Map.Entry<IntList, MakeReduceExpr>> projection = new ArrayList<>(mapReduceAggs.size());
        boolean sameAggsForBothPhases = true;

        for (MapReduceAgg mapReduceAgg : mapReduceAggs) {
            int callIdx = 0;

            for (AggregateCall reduceCall : mapReduceAgg.reduceCalls) {
                reduceType.add("f" + callIdx + "_" + reduceType.getFieldCount(), reduceCall.getType());
                reduceAggCalls.add(reduceCall);
                callIdx++;
            }

            MakeReduceExpr projectionExpr = mapReduceAgg.makeReduceOutputExpr;

            projection.add(new AbstractMap.SimpleEntry<>(mapReduceAgg.argList, projectionExpr));

            // Any non-trivial output expression requires a final projection on top of REDUCE.
            if (projectionExpr != USE_INPUT_FIELD) {
                sameAggsForBothPhases = false;
            }
        }

        RelDataType reduceTypeToUse = sameAggsForBothPhases ? agg.getRowType() : reduceType.build();

        assert mapAggCalls.size() <= reduceAggCalls.size()
                : IgniteStringFormatter.format("The number of MAP/REDUCE aggregates is not correct. MAP: {}\nREDUCE: {}", mapAggCalls, reduceAggCalls);

        ImmutableBitSet groupSetOnReduce = Mappings.apply(fieldMappingOnReduce, agg.getGroupSet());
        List<ImmutableBitSet> groupSetsOnReduce = agg.getGroupSets().stream()
                .map(g -> Mappings.apply(fieldMappingOnReduce, g))
                .collect(Collectors.toList());

        IgniteRel reduce = builder.makeReduceAgg(agg.getCluster(), reduceInputNode, groupSetOnReduce, groupSetsOnReduce, reduceAggCalls, reduceTypeToUse);

        if (sameAggsForBothPhases) {
            return reduce;
        }

        // Final projection: pass group-by columns through and compute every original aggregate
        // from its REDUCE columns.
        List<RexNode> projectionList = new ArrayList<>(projection.size() + groupByColumns);

        for (int i = 0; i < groupByColumns; i++) {
            projectionList.add(new RexInputRef(i, outputFields.get(i).getType()));
        }

        for (Map.Entry<IntList, MakeReduceExpr> entry : projection) {
            projectionList.add(entry.getValue().makeExpr(rexBuilder, reduce, entry.getKey(), typeFactory));
        }

        assert projectionList.size() == outputFields.size()
                : IgniteStringFormatter.format("Projection size does not match. Expected: {} but got {}", outputFields.size(), projectionList.size());

        for (int i = 0; i < projectionList.size(); i++) {
            RexNode resultExpr = projectionList.get(i);
            RelDataType expectedType = outputFields.get(i).getType();

            assert resultExpr.getType().equals(expectedType)
                    : IgniteStringFormatter.format("Type at position#{} does not match. Expected: {} but got {}.\nREDUCE aggregates: {}\nRow: {}.\nExpr: {}",
                            i, expectedType, resultExpr.getType(), reduceAggCalls, outputFields, resultExpr);
        }

        return new IgniteProject(agg.getCluster(), reduce.getTraitSet(), reduce, projectionList, agg.getRowType());
    }

    /**
     * Splits a single aggregate call into its MAP and REDUCE parts.
     *
     * @param cluster Cluster whose type factory is used (only for AVG).
     * @param call Aggregate call to split; its function must be in {@link #AGG_SUPPORTING_MAP_REDUCE}.
     * @param reduceArgumentOffset Index of the first MAP output column consumed by the REDUCE call(s).
     * @param input Row type of the aggregate's input (only used for AVG).
     * @param canBeNull Whether the aggregate may be evaluated over no rows (only used for AVG).
     * @return MAP/REDUCE description of the call.
     */
    public static MapReduceAgg createMapReduceAggCall(RelOptCluster cluster, AggregateCall call, int reduceArgumentOffset, RelDataType input, boolean canBeNull) {
        String aggName = call.getAggregation().getName();

        assert AGG_SUPPORTING_MAP_REDUCE.contains(aggName) : "Aggregate does not support MAP/REDUCE " + call;

        switch (aggName) {
            case "COUNT":
                return createCountAgg(call, reduceArgumentOffset);
            case "AVG":
                return createAvgAgg(cluster, call, reduceArgumentOffset, input, canBeNull);
            case "GROUPING":
                return createGroupingAgg(call, reduceArgumentOffset);
            default:
                return createSimpleAgg(call, reduceArgumentOffset);
        }
    }

    /**
     * COUNT: MAP keeps the original call; REDUCE sums the partial counts with {@code $SUM0}, and the final
     * projection casts the sum to BIGINT.
     */
    private static MapReduceAgg createCountAgg(AggregateCall call, int reduceArgumentOffset) {
        IntList argList = IntList.of(reduceArgumentOffset);

        AggregateCall sum0 = AggregateCall.create(SqlStdOperatorTable.SUM0, call.isDistinct(), call.isApproximate(), call.ignoreNulls(),
                ImmutableList.of(), argList, -1, null, call.collation, call.type, "COUNT_" + reduceArgumentOffset + "_MAP_SUM");

        MakeReduceExpr exprBuilder = (rexBuilder, input, args, typeFactory) -> {
            RexInputRef ref = rexBuilder.makeInputRef(input, args.getInt(0));

            return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.BIGINT), ref, true, false);
        };

        return new MapReduceAgg(argList, call, sum0, exprBuilder);
    }

    /**
     * Aggregates whose REDUCE phase applies the same function to the partial results (MIN, MAX, SUM, ...):
     * MAP keeps the original call, and REDUCE repeats the function over the MAP output column.
     */
    private static MapReduceAgg createSimpleAgg(AggregateCall call, int reduceArgumentOffset) {
        IntList argList = IntList.of(reduceArgumentOffset);

        AggregateCall reduceCall = AggregateCall.create(call.getAggregation(), call.isDistinct(), call.isApproximate(), call.ignoreNulls(),
                ImmutableList.of(), argList, -1, call.distinctKeys, call.collation, call.type, call.name);

        return new MapReduceAgg(argList, call, reduceCall, USE_INPUT_FIELD);
    }

    /** GROUPING: MAP keeps the original call; REDUCE takes the MAP value with {@code SINGLE_VALUE}. */
    private static MapReduceAgg createGroupingAgg(AggregateCall call, int reduceArgumentOffset) {
        IntList argList = IntList.of(reduceArgumentOffset);

        AggregateCall reduceCall = AggregateCall.create(SqlStdOperatorTable.SINGLE_VALUE, call.isDistinct(), call.isApproximate(), call.ignoreNulls(),
                ImmutableList.of(), argList, -1, null, call.collation, call.type, "GROUPING" + reduceArgumentOffset);

        return new MapReduceAgg(argList, call, reduceCall, USE_INPUT_FIELD);
    }

    /**
     * AVG: MAP computes {@code SUM(x)} and {@code COUNT(x)}; REDUCE computes {@code SUM} of the partial sums
     * and {@code $SUM0} of the partial counts; the final projection divides the two.
     *
     * <p>If the argument column has type NULL, AVG is treated as a simple aggregate.
     */
    private static MapReduceAgg createAvgAgg(RelOptCluster cluster, AggregateCall call, int reduceArgumentOffset, RelDataType inputType, boolean canBeNull) {
        RelDataTypeFactory tf = cluster.getTypeFactory();
        RelDataTypeSystem typeSystem = tf.getTypeSystem();
        RelDataType fieldType = inputType.getFieldList().get(call.getArgList().get(0)).getType();

        if (fieldType.getSqlTypeName() == SqlTypeName.NULL) {
            return createSimpleAgg(call, reduceArgumentOffset);
        }

        // MAP: SUM(x) -- nullable when there may be no rows -- and COUNT(x).
        RelDataType mapSumType = typeSystem.deriveSumType(tf, fieldType);

        if (canBeNull) {
            mapSumType = tf.createTypeWithNullability(mapSumType, true);
        }

        AggregateCall mapSum0 = AggregateCall.create(SqlStdOperatorTable.SUM, call.isDistinct(), call.isApproximate(), call.ignoreNulls(),
                ImmutableList.of(), call.getArgList(), call.filterArg, null, call.collation, mapSumType, "AVG_SUM" + reduceArgumentOffset);

        RelDataType mapCountType = tf.createSqlType(SqlTypeName.BIGINT);

        AggregateCall mapCount0 = AggregateCall.create(SqlStdOperatorTable.COUNT, call.isDistinct(), call.isApproximate(), call.ignoreNulls(),
                ImmutableList.of(), call.getArgList(), call.filterArg, null, call.collation, mapCountType, "AVG_COUNT" + reduceArgumentOffset);

        // REDUCE: SUM of the partial sums and $SUM0 of the partial counts.
        IntList reduceSumArgs = IntList.of(reduceArgumentOffset);
        RelDataType reduceSumType = typeSystem.deriveSumType(tf, mapSumType);

        if (canBeNull) {
            reduceSumType = tf.createTypeWithNullability(reduceSumType, true);
        }

        AggregateCall reduceSum0 = AggregateCall.create(SqlStdOperatorTable.SUM, call.isDistinct(), call.isApproximate(), call.ignoreNulls(),
                ImmutableList.of(), reduceSumArgs, -1, null, call.collation, reduceSumType, "AVG_SUM" + reduceArgumentOffset);

        RelDataType reduceSumCountType = typeSystem.deriveSumType(tf, mapCount0.type);
        IntList reduceSumCountArgs = IntList.of(reduceArgumentOffset + 1);

        AggregateCall reduceSumCount = AggregateCall.create(SqlStdOperatorTable.SUM0, call.isDistinct(), call.isApproximate(), call.ignoreNulls(),
                ImmutableList.of(), reduceSumCountArgs, -1, null, call.collation, reduceSumCountType, "AVG_SUM0" + reduceArgumentOffset);

        RelDataType finalReduceSumType = reduceSumType;

        // Adapts MAP outputs to the REDUCE argument types: the sum column is cast only when the types differ
        // (ignoring nullability); the count column is always cast to the type of the REDUCE count sum.
        MakeReduceExpr reduceInputExpr = (rexBuilder, input, args, typeFactory) -> {
            RexInputRef argExpr = rexBuilder.makeInputRef(input, args.getInt(0));

            if (args.getInt(0) == reduceArgumentOffset) {
                if (!SqlTypeUtil.equalSansNullability(finalReduceSumType, argExpr.getType())) {
                    return rexBuilder.makeCast(finalReduceSumType, argExpr, true, false);
                }

                return argExpr;
            }

            return rexBuilder.makeCast(reduceSumCount.type, argExpr, true, false);
        };

        // Final projection: sum / count, computed with DECIMAL_DIVIDE at the precision and scale of the
        // decimal counterpart of AVG's result type.
        MakeReduceExpr reduceOutputExpr = (rexBuilder, input, args, typeFactory) -> {
            RexNode numeratorRef = rexBuilder.ensureType(mapSum0.type, rexBuilder.makeInputRef(input, args.getInt(0)), true);
            RexInputRef denominatorRef = rexBuilder.makeInputRef(input, args.getInt(1));

            RelDataType resultType = typeFactory.decimalOf(call.type);
            RexLiteral p = rexBuilder.makeExactLiteral(BigDecimal.valueOf(resultType.getPrecision()), tf.createSqlType(SqlTypeName.INTEGER));
            RexLiteral s = rexBuilder.makeExactLiteral(BigDecimal.valueOf(resultType.getScale()), tf.createSqlType(SqlTypeName.INTEGER));

            RexNode sumDivCnt = rexBuilder.makeCall(IgniteSqlOperatorTable.DECIMAL_DIVIDE, numeratorRef, denominatorRef, p, s);

            if (call.getType().getSqlTypeName() != SqlTypeName.DECIMAL) {
                sumDivCnt = rexBuilder.makeCast(call.getType(), sumDivCnt, false, false);
            }

            if (canBeNull) {
                // CASE WHEN cnt = 0 THEN NULL ELSE sum / cnt END.
                // The guard must test the count: a zero sum with a non-zero count is a legitimate average
                // of 0, while a zero count means no value contributed and the result is NULL.
                RexLiteral zero = rexBuilder.makeExactLiteral(BigDecimal.ZERO, denominatorRef.getType());
                RexNode eqZero = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, denominatorRef, zero);
                RexLiteral nullRes = rexBuilder.makeNullLiteral(call.getType());

                return rexBuilder.makeCall(SqlStdOperatorTable.CASE, eqZero, nullRes, sumDivCnt);
            }

            return sumDivCnt;
        };

        IntList argList = IntList.of(reduceArgumentOffset, reduceArgumentOffset + 1);

        return new MapReduceAgg(argList, List.of(mapSum0, mapCount0), reduceInputExpr, List.of(reduceSum0, reduceSumCount), reduceOutputExpr);
    }

    /** MAP/REDUCE description of a single original aggregate call. */
    public static class MapReduceAgg {
        /** Columns of the REDUCE output that {@link #makeReduceOutputExpr} reads. */
        final IntList argList;

        /** Calls evaluated by the MAP phase. */
        final List<AggregateCall> mapCalls;

        /** Calls evaluated by the REDUCE phase. */
        final List<AggregateCall> reduceCalls;

        /** Builds the REDUCE input expression for each REDUCE argument column. */
        final MakeReduceExpr makeReduceInputExpr;

        /** Builds the final value of the original aggregate from the REDUCE output. */
        final MakeReduceExpr makeReduceOutputExpr;

        /** Single MAP call and single REDUCE call; REDUCE reads the MAP column as-is. */
        MapReduceAgg(IntList argList, AggregateCall mapCall, AggregateCall reduceCall, MakeReduceExpr makeReduceOutputExpr) {
            this(argList, List.of(mapCall), USE_INPUT_FIELD, List.of(reduceCall), makeReduceOutputExpr);
        }

        MapReduceAgg(IntList argList, List<AggregateCall> mapCalls, MakeReduceExpr makeReduceInputExpr, List<AggregateCall> reduceCalls, MakeReduceExpr makeReduceOutputExpr) {
            this.argList = argList;
            this.mapCalls = mapCalls;
            this.reduceCalls = reduceCalls;
            this.makeReduceInputExpr = makeReduceInputExpr;
            this.makeReduceOutputExpr = makeReduceOutputExpr;
        }

        /** Returns the first REDUCE call. */
        @TestOnly
        public AggregateCall getReduceCall() {
            return reduceCalls.get(0);
        }
    }

    /** Factory for the relational nodes that make up a MAP/REDUCE aggregation. */
    public interface AggregateRelBuilder {
        /** Creates the MAP aggregate over the original input. */
        IgniteRel makeMapAgg(RelOptCluster cluster, RelNode input, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls);

        /** Creates a projection over the MAP aggregate that feeds the REDUCE aggregate. */
        IgniteRel makeProject(RelOptCluster cluster, RelNode input, List<RexNode> projects, RelDataType rowType);

        /** Creates the REDUCE aggregate with the given output row type. */
        IgniteRel makeReduceAgg(RelOptCluster cluster, RelNode input, ImmutableBitSet groupSet, List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls, RelDataType rowType);
    }

    /** Builds an expression over {@code input} from the given argument column indexes. */
    @FunctionalInterface
    private interface MakeReduceExpr {
        RexNode makeExpr(RexBuilder rexBuilder, RelNode input, IntList args, IgniteTypeFactory typeFactory);
    }
}

