/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.opensearch.planner.rules;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.IntStream;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.rules.SubstitutionRule;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.immutables.value.Value;
import org.opensearch.sql.calcite.plan.rel.LogicalDedup;
import org.opensearch.sql.calcite.plan.rule.OpenSearchRuleConfig;
import org.opensearch.sql.calcite.utils.PPLHintUtils;
import org.opensearch.sql.calcite.utils.PlanUtils;
import org.opensearch.sql.opensearch.planner.rules.ImmutableDedupPushdownRule;
import org.opensearch.sql.opensearch.planner.rules.InterruptibleRelRule;
import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan;
import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan;
import org.opensearch.sql.utils.Utils;

@Value.Enclosing
public class DedupPushdownRule
extends InterruptibleRelRule<Config>
implements SubstitutionRule {
    private static final Logger LOG = LogManager.getLogger();

    protected DedupPushdownRule(Config config) {
        super(config);
    }

    @Override
    protected void onMatchImpl(RelOptRuleCall call) {
        LogicalDedup logicalDedup = (LogicalDedup)call.rel(0);
        LogicalProject projectWithExpr = (LogicalProject)call.rel(1);
        CalciteLogicalIndexScan scan = (CalciteLogicalIndexScan)call.rel(2);
        this.apply(call, logicalDedup, projectWithExpr, scan);
    }

    protected void apply(RelOptRuleCall call, LogicalDedup dedup, LogicalProject project, CalciteLogicalIndexScan scan) {
        List dedupColumns = dedup.getDedupeFields();
        if (dedupColumns.stream().filter(rex -> rex.isA(SqlKind.INPUT_REF)).anyMatch(rex -> rex.getType().getSqlTypeName() == SqlTypeName.MAP || rex.getType().getSqlTypeName() == SqlTypeName.ARRAY)) {
            LOG.debug("Cannot pushdown the dedup since the dedup fields contains MAP/ARRAY type");
            return;
        }
        RelBuilder relBuilder = call.builder();
        relBuilder.push((RelNode)project);
        ArrayList<RexNode> targetProjections = new ArrayList<RexNode>();
        HashSet<Integer> dedupFieldsIndexSet = new HashSet<Integer>();
        for (RexNode dedupColumn : dedupColumns) {
            if (dedupColumn instanceof RexInputRef) {
                RexInputRef ref = (RexInputRef)dedupColumn;
                targetProjections.add(dedupColumn);
                dedupFieldsIndexSet.add(ref.getIndex());
                continue;
            }
            LOG.warn("The dedup column {} is illegal.", (Object)dedupColumn);
            return;
        }
        IntStream.range(0, project.getProjects().size()).boxed().filter(index -> !dedupFieldsIndexSet.contains(index)).map(arg_0 -> ((RelBuilder)relBuilder).field(arg_0)).forEach(targetProjections::add);
        relBuilder.project(targetProjections);
        LogicalProject targetChildProject = (LogicalProject)relBuilder.peek();
        if (targetChildProject.getNamedProjects().stream().limit(dedupColumns.size()).anyMatch(pair -> Utils.resolveNestedPath((String)((String)pair.getValue()), scan.getOsIndex().getFieldTypes()) != null)) {
            return;
        }
        List<Integer> newGroupByList = IntStream.range(0, dedupColumns.size()).boxed().toList();
        relBuilder.aggregate(relBuilder.groupKey((Iterable)relBuilder.fields(newGroupByList)), new RelBuilder.AggCall[]{relBuilder.literalAgg((Object)dedup.getAllowedDuplication())});
        PPLHintUtils.addIgnoreNullBucketHintToAggregate((RelBuilder)relBuilder);
        LogicalAggregate aggregate = (LogicalAggregate)relBuilder.build();
        assert (aggregate.getGroupSet().asList().equals(newGroupByList)) : "The group set of aggregate should be exactly the same as the generated group list";
        CalciteLogicalIndexScan newScan = (CalciteLogicalIndexScan)scan.pushDownAggregate((Aggregate)aggregate, (Project)targetChildProject);
        if (newScan != null) {
            call.transformTo((RelNode)newScan.copyWithNewSchema(dedup.getRowType()));
            PlanUtils.tryPruneRelNodes((RelOptRuleCall)call);
        }
    }

    @Value.Immutable
    public static interface Config
    extends OpenSearchRuleConfig {
        public static final Config DEFAULT = ImmutableDedupPushdownRule.Config.builder().build().withDescription("Dedup-to-Aggregate").withOperandSupplier(b0 -> b0.operand(LogicalDedup.class).predicate(dedup -> dedup.getKeepEmpty() == false).oneInput(b1 -> b1.operand(LogicalProject.class).oneInput(b2 -> b2.operand(CalciteLogicalIndexScan.class).predicate(Config::tableScanChecker).noInputs())));

        private static boolean tableScanChecker(AbstractCalciteIndexScan scan) {
            return Predicate.not(AbstractCalciteIndexScan::isLimitPushed).and(AbstractCalciteIndexScan::noAggregatePushed).test(scan);
        }

        default public DedupPushdownRule toRule() {
            return new DedupPushdownRule(this);
        }
    }
}

