package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.table.api.config.OptimizerConfigOptions;
import org.apache.flink.table.planner.plan.rules.logical.ImmutableRewriteMultiJoinConditionRule;
import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.util.Preconditions;
import org.immutables.value.Value;

/**
 * Planner rule that matches a {@link MultiJoin} and extends its join filter with equality
 * predicates that follow from the existing ones through a shared operand.
 *
 * <p>For example, if the join filter contains {@code a = b} and {@code a = c}, the rule adds
 * {@code b = c}. Conjuncts that are not equalities are carried over unchanged.
 */
@Value.Enclosing
public class RewriteMultiJoinConditionRule
        extends RelRule<RewriteMultiJoinConditionRule.RewriteMultiJoinConditionRuleConfig> {

    /** Shared rule instance built from {@link RewriteMultiJoinConditionRuleConfig#DEFAULT}. */
    public static final RewriteMultiJoinConditionRule INSTANCE =
            RewriteMultiJoinConditionRuleConfig.DEFAULT.toRule();

    /** Rule configuration; the default one matches any {@link MultiJoin} with any inputs. */
    @Value.Immutable(singleton = false)
    public interface RewriteMultiJoinConditionRuleConfig extends RelRule.Config {
        RewriteMultiJoinConditionRuleConfig DEFAULT =
                ImmutableRewriteMultiJoinConditionRule.RewriteMultiJoinConditionRuleConfig.builder()
                        .build()
                        .withOperandSupplier(
                                operandBuilder ->
                                        operandBuilder.operand(MultiJoin.class).anyInputs())
                        .withDescription("RewriteMultiJoinConditionRule");

        @Override
        default RewriteMultiJoinConditionRule toRule() {
            return new RewriteMultiJoinConditionRule(this);
        }
    }

    private RewriteMultiJoinConditionRule(
            RewriteMultiJoinConditionRuleConfig rewriteMultiJoinConditionRuleConfig) {
        super(rewriteMultiJoinConditionRuleConfig);
    }

    /**
     * Decides whether the rule applies to the matched {@link MultiJoin}.
     *
     * <p>All of the following must hold, checked in this order:
     *
     * <ol>
     *   <li>the number of inputs is greater than the value configured for {@link
     *       OptimizerConfigOptions#TABLE_OPTIMIZER_BUSHY_JOIN_REORDER_THRESHOLD};
     *   <li>the multi-join is not a full outer join;
     *   <li>every join type is {@link JoinRelType#INNER};
     *   <li>the join filter contains more than one equality conjunct.
     * </ol>
     *
     * @param relOptRuleCall the rule call whose first operand is the {@link MultiJoin}
     * @return {@code true} if the rule should fire
     */
    @Override
    public boolean matches(RelOptRuleCall relOptRuleCall) {
        MultiJoin multiJoin = relOptRuleCall.rel(0);

        int bushyJoinReorderThreshold =
                ShortcutUtils.unwrapContext(multiJoin)
                        .getTableConfig()
                        .get(OptimizerConfigOptions.TABLE_OPTIMIZER_BUSHY_JOIN_REORDER_THRESHOLD);
        if (multiJoin.getInputs().size() <= bushyJoinReorderThreshold) {
            return false;
        }
        if (multiJoin.isFullOuterJoin()) {
            return false;
        }
        boolean allInnerJoins =
                multiJoin.getJoinTypes().stream()
                        .allMatch(joinRelType -> joinRelType == JoinRelType.INNER);
        if (!allInnerJoins) {
            return false;
        }
        return partitionJoinFilters(multiJoin).f0.size() > 1;
    }

    /**
     * Rewrites the join filter of the matched {@link MultiJoin}.
     *
     * <p>Every operand of an equality conjunct is mapped to the operands it is compared with.
     * For each operand that is compared with more than one other operand, an equality is created
     * for every pair of those other operands and appended unless an equal filter is already
     * present. If nothing new was derived the call is left untransformed.
     *
     * @param relOptRuleCall the rule call whose first operand is the {@link MultiJoin}
     */
    @Override
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        MultiJoin multiJoin = relOptRuleCall.rel(0);
        Tuple2<List<RexNode>, List<RexNode>> partitioned = partitionJoinFilters(multiJoin);
        List<RexNode> equiJoinFilters = partitioned.f0;
        List<RexNode> nonEquiJoinFilters = partitioned.f1;

        // operand -> operands it is directly equated with (recorded in both directions)
        Map<RexNode, List<RexNode>> equatedOperands = new HashMap<>();
        for (RexNode filter : equiJoinFilters) {
            if (!(filter instanceof RexCall)) {
                continue;
            }
            Preconditions.checkState(filter.isA(SqlKind.EQUALS));
            List<RexNode> operands = ((RexCall) filter).getOperands();
            RexNode left = operands.get(0);
            RexNode right = operands.get(1);
            equatedOperands.computeIfAbsent(left, key -> new ArrayList<>()).add(right);
            equatedOperands.computeIfAbsent(right, key -> new ArrayList<>()).add(left);
        }

        // Only an operand equated with at least two others can yield a new predicate.
        List<List<RexNode>> candidateGroups =
                equatedOperands.values().stream()
                        .filter(group -> group.size() > 1)
                        .collect(Collectors.toList());
        if (candidateGroups.isEmpty()) {
            return;
        }

        List<RexNode> newEquiJoinFilters = new ArrayList<>(equiJoinFilters);
        RexBuilder rexBuilder = multiJoin.getCluster().getRexBuilder();
        for (List<RexNode> group : candidateGroups) {
            for (int i = 0; i < group.size(); i++) {
                RexNode first = group.get(i);
                for (RexNode second : group.subList(i + 1, group.size())) {
                    RexNode derived =
                            rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, first, second);
                    if (!containEquiJoinFilter(derived, newEquiJoinFilters)) {
                        newEquiJoinFilters.add(derived);
                    }
                }
            }
        }

        // Nothing was derived: leave the plan as it is.
        if (newEquiJoinFilters.size() == equiJoinFilters.size()) {
            return;
        }

        // Equality filters (original + derived) first, then the remaining conjuncts.
        List<RexNode> allFilters = new ArrayList<>(newEquiJoinFilters);
        allFilters.addAll(nonEquiJoinFilters);
        RexNode newJoinFilter = relOptRuleCall.builder().and(allFilters);

        relOptRuleCall.transformTo(
                new MultiJoin(
                        multiJoin.getCluster(),
                        multiJoin.getInputs(),
                        newJoinFilter,
                        multiJoin.getRowType(),
                        multiJoin.isFullOuterJoin(),
                        multiJoin.getOuterJoinConditions(),
                        multiJoin.getJoinTypes(),
                        multiJoin.getProjFields(),
                        multiJoin.getJoinFieldRefCountsMap(),
                        multiJoin.getPostJoinFilter()));
    }

    /**
     * Checks whether {@code list} already holds a filter equal to {@code rexNode}.
     *
     * <p>The comparison is {@code element.equals(rexNode)}, so an equality with swapped operands
     * only counts as present if {@code equals} treats it as such.
     *
     * @param rexNode the filter to look for
     * @param list the filters to search
     * @return {@code true} if an equal filter is found
     */
    private boolean containEquiJoinFilter(RexNode rexNode, List<RexNode> list) {
        return list.stream().anyMatch(existing -> existing.equals(rexNode));
    }

    /**
     * Splits the conjuncts of the join filter into equalities and everything else.
     *
     * @param multiJoin the multi-join whose join filter is decomposed
     * @return a tuple whose {@code f0} holds the conjuncts of kind {@link SqlKind#EQUALS} and whose
     *     {@code f1} holds the remaining conjuncts
     */
    private Tuple2<List<RexNode>, List<RexNode>> partitionJoinFilters(MultiJoin multiJoin) {
        List<RexNode> conjuncts = RelOptUtil.conjunctions(multiJoin.getJoinFilter());
        Map<Boolean, List<RexNode>> byIsEquals =
                conjuncts.stream()
                        .collect(
                                Collectors.partitioningBy(
                                        conjunct -> conjunct.isA(SqlKind.EQUALS)));
        return new Tuple2<>(byIsEquals.get(true), byIsEquals.get(false));
    }
}
