package org.apache.flink.table.planner.plan.rules.physical.batch;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexProgram;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.table.api.config.OptimizerConfigOptions;
import org.apache.flink.table.connector.source.DynamicTableSource;
import org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.plan.abilities.source.AggregatePushDownSpec;
import org.apache.flink.table.planner.plan.abilities.source.ProjectPushDownSpec;
import org.apache.flink.table.planner.plan.abilities.source.SourceAbilityContext;
import org.apache.flink.table.planner.plan.abilities.source.SourceAbilitySpec;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalCalc;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan;
import org.apache.flink.table.planner.plan.schema.TableSourceTable;
import org.apache.flink.table.planner.plan.stats.FlinkStatistic;
import org.apache.flink.table.planner.plan.utils.RexNodeExtractor;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.table.types.logical.RowType;

/* loaded from: input_file:org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.class */
/**
 * Base class for planner rules that push a local (partial) batch aggregate into a table source
 * implementing {@link SupportsAggregatePushDown}.
 *
 * <p>Subclasses define the matched operand pattern. This class provides the shared eligibility
 * checks, the index translation needed when a {@link BatchPhysicalCalc} sits between the
 * aggregate and the scan, and the rewrite itself.
 *
 * <p>The rewrite reads {@code call.rel(0)} as a {@link BatchPhysicalExchange}. Subclass patterns
 * are therefore expected to be rooted at the exchange above the local aggregate.
 */
public abstract class PushLocalAggIntoScanRuleBase extends RelOptRule {

    public PushLocalAggIntoScanRuleBase(RelOptRuleOperand operand, String description) {
        super(operand, description);
    }

    /**
     * Decides whether {@code aggregate} may be pushed into {@code tableSourceScan}.
     *
     * <p>Returns {@code false} in any of these cases:
     *
     * <ul>
     *   <li>{@link OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED} is
     *       off.
     *   <li>The aggregate is final or has no aggregate calls.
     *   <li>A call is DISTINCT or approximate, takes more than one argument, has a FILTER
     *       argument, or carries a non-empty collation.
     *   <li>The scan has no {@link TableSourceTable}.
     *   <li>The source does not implement {@link SupportsAggregatePushDown}.
     *   <li>An {@link AggregatePushDownSpec} is already applied to the scan.
     * </ul>
     *
     * @param call the current rule call, used to reach the table configuration
     * @param aggregate the candidate local aggregate
     * @param tableSourceScan the scan the aggregate would be pushed into
     * @return {@code true} if push-down may be attempted
     */
    public boolean canPushDown(
            RelOptRuleCall call,
            BatchPhysicalGroupAggregateBase aggregate,
            BatchPhysicalTableSourceScan tableSourceScan) {
        boolean pushDownEnabled =
                ShortcutUtils.unwrapContext(call.getPlanner())
                        .getTableConfig()
                        .get(OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED);
        if (!pushDownEnabled
                || aggregate.isFinal()
                || aggregate.getAggCallList().isEmpty()) {
            return false;
        }

        for (AggregateCall aggCall : JavaScalaConversionUtil.toJava(aggregate.getAggCallList())) {
            if (aggCall.isDistinct()
                    || aggCall.isApproximate()
                    || aggCall.getArgList().size() > 1
                    || aggCall.hasFilter()
                    || !aggCall.getCollation().getFieldCollations().isEmpty()) {
                return false;
            }
        }

        TableSourceTable tableSourceTable = tableSourceScan.tableSourceTable();
        // An aggregate must not be pushed into the same scan twice.
        return tableSourceTable != null
                && tableSourceTable.tableSource() instanceof SupportsAggregatePushDown
                && Arrays.stream(tableSourceTable.abilitySpecs())
                        .noneMatch(spec -> spec instanceof AggregatePushDownSpec);
    }

    /**
     * Pushes {@code localAgg} directly into {@code oldScan}.
     *
     * <p>Use this overload when no calc sits between the aggregate and the scan, so aggregate
     * argument indices already refer to scan fields.
     */
    public void pushLocalAggregateIntoScan(
            RelOptRuleCall call,
            BatchPhysicalGroupAggregateBase localAgg,
            BatchPhysicalTableSourceScan oldScan) {
        pushLocalAggregateIntoScan(call, localAgg, oldScan, null);
    }

    /**
     * Tries to push {@code localAgg} into a copy of {@code oldScan}'s source.
     *
     * <p>On success, registers an equivalent plan via {@code call.transformTo}. In that plan the
     * exchange matched as {@code call.rel(0)} reads directly from the new scan. If the source
     * rejects the push-down, the call is left untouched.
     *
     * @param call the current rule call; operand 0 must be a {@link BatchPhysicalExchange}
     * @param localAgg the local aggregate to push down
     * @param oldScan the scan being replaced
     * @param calcRefFields if non-null, maps the aggregate's input indices (positions in an
     *     intermediate calc's output) to scan field indices; {@code null} means no translation
     */
    public void pushLocalAggregateIntoScan(
            RelOptRuleCall call,
            BatchPhysicalGroupAggregateBase localAgg,
            BatchPhysicalTableSourceScan oldScan,
            int[] calcRefFields) {
        RowType inputType = FlinkTypeFactory.toLogicalRowType(oldScan.getRowType());
        // Grouping keys followed by auxiliary grouping keys, treated as a single grouping set.
        List<int[]> groupingSets =
                Collections.singletonList(
                        ArrayUtils.addAll(localAgg.grouping(), localAgg.auxGrouping()));
        List<AggregateCall> aggCalls = JavaScalaConversionUtil.toJava(localAgg.getAggCallList());

        // Re-express indices that refer to the calc's output as indices into the scan.
        if (calcRefFields != null) {
            groupingSets = translateGroupingArgIndex(groupingSets, calcRefFields);
            aggCalls = translateAggCallArgIndex(aggCalls, calcRefFields);
        }

        RowType producedType = FlinkTypeFactory.toLogicalRowType(localAgg.getRowType());
        TableSourceTable oldTableSourceTable = oldScan.tableSourceTable();
        // Work on a copy so the original source is untouched if push-down is rejected.
        DynamicTableSource newTableSource = oldScan.tableSource().copy();

        boolean pushedDown =
                AggregatePushDownSpec.apply(
                        inputType,
                        groupingSets,
                        aggCalls,
                        producedType,
                        newTableSource,
                        SourceAbilityContext.from(oldScan));
        if (!pushedDown) {
            return;
        }

        AggregatePushDownSpec aggregatePushDownSpec =
                new AggregatePushDownSpec(inputType, groupingSets, aggCalls, producedType);
        // The old statistics describe un-aggregated rows, so they are reset to UNKNOWN.
        TableSourceTable newTableSourceTable =
                oldTableSourceTable
                        .copy(
                                newTableSource,
                                localAgg.getRowType(),
                                new SourceAbilitySpec[] {aggregatePushDownSpec})
                        .copy(FlinkStatistic.UNKNOWN());

        BatchPhysicalTableSourceScan newScan =
                oldScan.copy(oldScan.getTraitSet(), newTableSourceTable);
        BatchPhysicalExchange oldExchange = (BatchPhysicalExchange) call.rel(0);
        call.transformTo(
                oldExchange.copy(oldExchange.getTraitSet(), newScan, oldExchange.getDistribution()));
    }

    /**
     * Returns {@code true} if no {@link ProjectPushDownSpec} has been applied to the scan's
     * table. Returns {@code false} if the scan has no {@link TableSourceTable}.
     */
    public boolean isProjectionNotPushedDown(BatchPhysicalTableSourceScan tableSourceScan) {
        TableSourceTable tableSourceTable = tableSourceScan.tableSourceTable();
        return tableSourceTable != null
                && Arrays.stream(tableSourceTable.abilitySpecs())
                        .noneMatch(spec -> spec instanceof ProjectPushDownSpec);
    }

    /**
     * Returns {@code true} if {@code calc} is a pure projection of input columns.
     *
     * <p>That means it has no filter condition, at least one projected expression, and every
     * projected expression (after expanding local refs) is a {@link RexInputRef}.
     */
    public boolean isInputRefOnly(BatchPhysicalCalc calc) {
        RexProgram program = calc.getProgram();
        // A calc that still filters rows, or projects nothing, is not a plain column mapping.
        if (program.getCondition() != null || program.getProjectList().isEmpty()) {
            return false;
        }
        return program.getProjectList().stream()
                .map(program::expandLocalRef)
                .allMatch(RexInputRef.class::isInstance);
    }

    /**
     * Returns the input field indices referenced by {@code calc}'s projection.
     *
     * <p>The indices are computed by {@link RexNodeExtractor} over the expanded project
     * expressions.
     *
     * <p>NOTE(review): callers index this array by the calc's output position. That is only a
     * faithful mapping if {@code extractRefInputFields} preserves order and does not deduplicate
     * repeated references (e.g. {@code SELECT a, a}). Confirm before relying on it.
     */
    public int[] getRefFiledIndex(BatchPhysicalCalc calc) {
        RexProgram program = calc.getProgram();
        return RexNodeExtractor.extractRefInputFields(
                program.getProjectList().stream()
                        .map(program::expandLocalRef)
                        .collect(Collectors.toList()));
    }

    /**
     * Maps every index of every grouping set through {@code refFields}, so that index {@code i}
     * becomes {@code refFields[i]}.
     *
     * @param groupingSets grouping sets whose indices refer to positions in {@code refFields};
     *     not modified
     * @param refFields translation table from calc output position to scan field index
     * @return a new mutable list of new arrays, in the same order as {@code groupingSets}
     * @throws ArrayIndexOutOfBoundsException if an index is outside {@code refFields}
     */
    protected List<int[]> translateGroupingArgIndex(List<int[]> groupingSets, int[] refFields) {
        return groupingSets.stream()
                .map(grouping -> remapIndices(grouping, refFields))
                .collect(Collectors.toCollection(ArrayList::new));
    }

    /**
     * Copies each aggregate call with its argument indices, and its filter argument if present,
     * mapped through {@code refFields}.
     *
     * <p>Collations are copied unchanged. {@link #canPushDown} only admits calls with an empty
     * collation.
     *
     * @param aggCallList calls whose argument indices refer to positions in {@code refFields};
     *     not modified
     * @param refFields translation table from calc output position to scan field index
     * @return a new mutable list of translated calls, in the same order as the input
     * @throws ArrayIndexOutOfBoundsException if an index is outside {@code refFields}
     */
    protected List<AggregateCall> translateAggCallArgIndex(
            List<AggregateCall> aggCallList, int[] refFields) {
        List<AggregateCall> translated = new ArrayList<>(aggCallList.size());
        for (AggregateCall aggCall : aggCallList) {
            List<Integer> newArgs =
                    aggCall.getArgList().stream()
                            .map(arg -> refFields[arg])
                            .collect(Collectors.toList());
            // filterArg is an input ordinal like the arguments; -1 means "no filter".
            int newFilterArg = aggCall.filterArg < 0 ? aggCall.filterArg : refFields[aggCall.filterArg];
            translated.add(aggCall.copy(newArgs, newFilterArg, aggCall.collation));
        }
        return translated;
    }

    /** Returns a new array where each element {@code i} of {@code indices} becomes {@code refFields[i]}. */
    private static int[] remapIndices(int[] indices, int[] refFields) {
        return Arrays.stream(indices).map(index -> refFields[index]).toArray();
    }
}
