/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.IntList;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Permutation;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;

public class AggregateProjectPullUpConstantsRule
extends RelOptRule {
    public static final AggregateProjectPullUpConstantsRule INSTANCE = new AggregateProjectPullUpConstantsRule();

    private AggregateProjectPullUpConstantsRule() {
        super(AggregateProjectPullUpConstantsRule.operand(LogicalAggregate.class, null, Aggregate.IS_SIMPLE, AggregateProjectPullUpConstantsRule.operand(LogicalProject.class, AggregateProjectPullUpConstantsRule.any()), new RelOptRuleOperand[0]));
    }

    public void onMatch(RelOptRuleCall call) {
        LogicalAggregate newAggregate;
        LogicalAggregate aggregate = (LogicalAggregate)call.rel(0);
        LogicalProject child = (LogicalProject)call.rel(1);
        int groupCount = aggregate.getGroupCount();
        if (groupCount == 1) {
            return;
        }
        RexProgram program = RexProgram.create(child.getInput().getRowType(), child.getProjects(), null, child.getRowType(), child.getCluster().getRexBuilder());
        RelDataType childRowType = child.getRowType();
        IntList constantList = new IntList();
        HashMap<Integer, RexNode> constants = new HashMap<Integer, RexNode>();
        for (int i : aggregate.getGroupSet()) {
            RexLocalRef ref = program.getProjectList().get(i);
            if (!program.isConstant(ref)) continue;
            constantList.add(i);
            constants.put(i, program.gatherExpr(ref));
        }
        if (constantList.size() == 0) {
            return;
        }
        if (groupCount == constantList.size()) {
            constantList.remove(0);
        }
        int newGroupCount = groupCount - constantList.size();
        if ((Integer)constantList.get(0) == newGroupCount) {
            ArrayList<AggregateCall> newAggCalls = new ArrayList<AggregateCall>();
            for (AggregateCall aggCall : aggregate.getAggCallList()) {
                newAggCalls.add(aggCall.adaptTo(child, aggCall.getArgList(), groupCount, newGroupCount));
            }
            newAggregate = new LogicalAggregate(aggregate.getCluster(), child, false, ImmutableBitSet.range(newGroupCount), null, newAggCalls);
        } else {
            Permutation mapping = new Permutation(childRowType.getFieldCount());
            mapping.identity();
            int groupOrdinal = 0;
            int constOrdinal = newGroupCount;
            for (int i = 0; i < groupCount; ++i) {
                if (i >= groupCount) {
                    mapping.set(i, i);
                    continue;
                }
                if (constants.containsKey(i)) {
                    mapping.set(i, constOrdinal++);
                    continue;
                }
                mapping.set(i, groupOrdinal++);
            }
            RelNode project = AggregateProjectPullUpConstantsRule.createProjection(mapping, child);
            ArrayList<AggregateCall> newAggCalls = new ArrayList<AggregateCall>();
            for (AggregateCall aggCall : aggregate.getAggCallList()) {
                int argCount = aggCall.getArgList().size();
                ArrayList<Integer> args = new ArrayList<Integer>(argCount);
                for (int j = 0; j < argCount; ++j) {
                    Integer arg = aggCall.getArgList().get(j);
                    args.add(mapping.getTarget(arg));
                }
                newAggCalls.add(aggCall.adaptTo(project, args, groupCount, newGroupCount));
            }
            newAggregate = new LogicalAggregate(aggregate.getCluster(), project, false, ImmutableBitSet.range(newGroupCount), null, newAggCalls);
        }
        RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
        ArrayList<Pair<RexNode, String>> projects = new ArrayList<Pair<RexNode, String>>();
        int source = 0;
        for (RelDataTypeField field : aggregate.getRowType().getFieldList()) {
            RexNode expr;
            int i = field.getIndex();
            if (i >= groupCount) {
                expr = rexBuilder.makeInputRef(newAggregate, i - constantList.size());
            } else if (constantList.contains(i)) {
                expr = (RexNode)constants.get(i);
            } else {
                expr = rexBuilder.makeInputRef(newAggregate, source);
                ++source;
            }
            projects.add(Pair.of(expr, field.getName()));
        }
        RelNode inverseProject = RelOptUtil.createProject((RelNode)newAggregate, projects, false);
        call.transformTo(inverseProject);
    }

    private static RelNode createProjection(Mapping mapping, RelNode child) {
        assert (mapping.getMappingType().isA(MappingType.INVERSE_SURJECTION));
        RelDataType childRowType = child.getRowType();
        assert (mapping.getSourceCount() == childRowType.getFieldCount());
        ArrayList<Pair<RexNode, String>> projects = new ArrayList<Pair<RexNode, String>>();
        for (int target = 0; target < mapping.getTargetCount(); ++target) {
            int source = mapping.getSource(target);
            RexBuilder rexBuilder = child.getCluster().getRexBuilder();
            projects.add(Pair.of(rexBuilder.makeInputRef(child, source), childRowType.getFieldList().get(source).getName()));
        }
        return RelOptUtil.createProject(child, projects, false);
    }
}

