/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import io.trino.hive.$internal.com.google.common.collect.ImmutableList;
import io.trino.hive.$internal.com.google.common.collect.Lists;
import io.trino.hive.$internal.org.slf4j.Logger;
import io.trino.hive.$internal.org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.rules.JoinCommuteRule;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin;

public class HiveJoinToMultiJoinRule
extends RelOptRule {
    public static final HiveJoinToMultiJoinRule INSTANCE = new HiveJoinToMultiJoinRule(HiveJoin.class, HiveRelFactories.HIVE_PROJECT_FACTORY);
    private final RelFactories.ProjectFactory projectFactory;
    private static final transient Logger LOG = LoggerFactory.getLogger(HiveJoinToMultiJoinRule.class);

    public HiveJoinToMultiJoinRule(Class<? extends Join> clazz, RelFactories.ProjectFactory projectFactory) {
        super(HiveJoinToMultiJoinRule.operand(clazz, (RelOptRuleOperand)HiveJoinToMultiJoinRule.operand(RelNode.class, (RelOptRuleOperandChildren)HiveJoinToMultiJoinRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[]{HiveJoinToMultiJoinRule.operand(RelNode.class, (RelOptRuleOperandChildren)HiveJoinToMultiJoinRule.any())}));
        this.projectFactory = projectFactory;
    }

    public void onMatch(RelOptRuleCall call) {
        HiveJoin newJoin;
        RelNode right;
        RelNode left;
        HiveJoin join = (HiveJoin)call.rel(0);
        RelNode multiJoin = HiveJoinToMultiJoinRule.mergeJoin(join, left = call.rel(1), right = call.rel(2));
        if (multiJoin != null) {
            call.transformTo(multiJoin);
            return;
        }
        RelNode swapped = JoinCommuteRule.swap((Join)join, (boolean)true);
        assert (swapped != null);
        Project topProject = null;
        if (swapped instanceof HiveJoin) {
            newJoin = (HiveJoin)swapped;
        } else {
            topProject = (Project)swapped;
            newJoin = (HiveJoin)swapped.getInput(0);
        }
        multiJoin = HiveJoinToMultiJoinRule.mergeJoin(newJoin, right, left);
        if (multiJoin != null) {
            if (topProject != null) {
                multiJoin = this.projectFactory.createProject(multiJoin, topProject.getChildExps(), topProject.getRowType().getFieldNames());
            }
            call.transformTo(multiJoin);
            return;
        }
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    private static RelNode mergeJoin(HiveJoin join, RelNode left, RelNode right) {
        RexNode filters;
        boolean combinable;
        List<RexNode> leftJoinFilters;
        List<JoinRelType> leftJoinTypes;
        List<Pair<Integer, Integer>> leftJoinInputs;
        RexNode leftCondition;
        RexBuilder rexBuilder = join.getCluster().getRexBuilder();
        ArrayList<RelNode> newInputs = Lists.newArrayList();
        ArrayList<RexNode> newJoinCondition = Lists.newArrayList();
        ArrayList<Pair<Integer, Integer>> joinInputs = Lists.newArrayList();
        ArrayList<JoinRelType> joinTypes = Lists.newArrayList();
        ArrayList<RexNode> joinFilters = Lists.newArrayList();
        if (!(left instanceof HiveJoin)) {
            if (!(left instanceof HiveMultiJoin)) return null;
        }
        if (left instanceof HiveJoin) {
            HiveJoin hj = (HiveJoin)left;
            leftCondition = hj.getCondition();
            leftJoinInputs = ImmutableList.of(Pair.of((Object)0, (Object)1));
            leftJoinTypes = ImmutableList.of(hj.getJoinType());
            leftJoinFilters = ImmutableList.of(hj.getJoinFilter());
            try {
                combinable = HiveJoinToMultiJoinRule.isCombinableJoin(join, hj);
            }
            catch (CalciteSemanticException e) {
                LOG.trace("Failed to merge join-join", e);
                return null;
            }
        }
        HiveMultiJoin hmj = (HiveMultiJoin)left;
        leftCondition = hmj.getCondition();
        leftJoinInputs = hmj.getJoinInputs();
        leftJoinTypes = hmj.getJoinTypes();
        leftJoinFilters = hmj.getJoinFilters();
        try {
            combinable = HiveJoinToMultiJoinRule.isCombinableJoin(join, hmj);
        }
        catch (CalciteSemanticException e) {
            LOG.trace("Failed to merge join-multijoin", e);
            return null;
        }
        if (!combinable) return null;
        newJoinCondition.add(leftCondition);
        for (int i = 0; i < leftJoinInputs.size(); ++i) {
            joinInputs.add(leftJoinInputs.get(i));
            joinTypes.add(leftJoinTypes.get(i));
            joinFilters.add(leftJoinFilters.get(i));
        }
        newInputs.addAll(left.getInputs());
        int numberLeftInputs = newInputs.size();
        newInputs.add(right);
        newJoinCondition.add(join.getCondition());
        if (newJoinCondition.size() == 1) {
            return null;
        }
        ImmutableList<RelDataTypeField> systemFieldList = ImmutableList.of();
        ArrayList<List<RexNode>> joinKeyExprs = new ArrayList<List<RexNode>>();
        ArrayList<Integer> filterNulls = new ArrayList<Integer>();
        for (int i = 0; i < newInputs.size(); ++i) {
            joinKeyExprs.add(new ArrayList());
        }
        try {
            filters = HiveRelOptUtil.splitHiveJoinCondition(systemFieldList, newInputs, join.getCondition(), joinKeyExprs, filterNulls, null);
        }
        catch (CalciteSemanticException e) {
            LOG.trace("Failed to merge joins", e);
            return null;
        }
        ImmutableBitSet.Builder keysInInputsBuilder = ImmutableBitSet.builder();
        for (int i = 0; i < newInputs.size(); ++i) {
            List partialCondition = (List)joinKeyExprs.get(i);
            if (partialCondition.isEmpty()) continue;
            keysInInputsBuilder.set(i);
        }
        ImmutableBitSet keysInInputs = keysInInputsBuilder.build();
        ImmutableBitSet leftReferencedInputs = keysInInputs.intersect(ImmutableBitSet.range((int)numberLeftInputs));
        ImmutableBitSet rightReferencedInputs = keysInInputs.intersect(ImmutableBitSet.range((int)numberLeftInputs, (int)newInputs.size()));
        if (join.getJoinType() != JoinRelType.INNER) {
            if (leftReferencedInputs.cardinality() > 1) return null;
            if (rightReferencedInputs.cardinality() > 1) {
                return null;
            }
        }
        if (join.getJoinType() != JoinRelType.INNER) {
            int leftInput = keysInInputs.nextSetBit(0);
            int rightInput = keysInInputs.nextSetBit(numberLeftInputs);
            joinInputs.add(Pair.of((Object)leftInput, (Object)rightInput));
            joinTypes.add(join.getJoinType());
            joinFilters.add(filters);
        } else {
            Iterator leftInput = leftReferencedInputs.iterator();
            while (leftInput.hasNext()) {
                int i = (Integer)leftInput.next();
                Iterator iterator = rightReferencedInputs.iterator();
                while (iterator.hasNext()) {
                    int j = (Integer)iterator.next();
                    joinInputs.add(Pair.of((Object)i, (Object)j));
                    joinTypes.add(join.getJoinType());
                    joinFilters.add(filters);
                }
            }
        }
        RexNode newCondition = RexUtil.flatten((RexBuilder)rexBuilder, (RexNode)RexUtil.composeConjunction((RexBuilder)rexBuilder, newJoinCondition, (boolean)false));
        ArrayList<RelNode> newInputsArray = Lists.newArrayList(newInputs);
        HiveCalciteUtil.JoinPredicateInfo joinPredInfo = null;
        try {
            joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(newInputsArray, systemFieldList, newCondition);
        }
        catch (CalciteSemanticException e) {
            throw new RuntimeException(e);
        }
        if (joinPredInfo.getEquiJoinPredicateElements().size() < newInputs.size() - 1) {
            return null;
        }
        int i = 0;
        while (i < newInputs.size()) {
            List<RexNode> joinKeys = null;
            for (int j = 0; j < joinPredInfo.getEquiJoinPredicateElements().size(); ++j) {
                List<RexNode> currJoinKeys = joinPredInfo.getEquiJoinPredicateElements().get(j).getJoinExprs(i);
                if (currJoinKeys.isEmpty()) continue;
                if (joinKeys == null) {
                    joinKeys = currJoinKeys;
                    continue;
                }
                if (!joinKeys.containsAll(currJoinKeys)) return null;
                if (currJoinKeys.containsAll(joinKeys)) continue;
                return null;
            }
            ++i;
        }
        return new HiveMultiJoin(join.getCluster(), newInputsArray, newCondition, join.getRowType(), joinInputs, joinTypes, joinFilters, joinPredInfo);
    }

    private static boolean isCombinableJoin(HiveJoin join, HiveJoin leftChildJoin) throws CalciteSemanticException {
        HiveCalciteUtil.JoinPredicateInfo joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(join, join.getCondition());
        HiveCalciteUtil.JoinPredicateInfo leftChildJoinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(leftChildJoin, leftChildJoin.getCondition());
        return HiveJoinToMultiJoinRule.isCombinablePredicate(joinPredInfo, leftChildJoinPredInfo, leftChildJoin.getInputs().size());
    }

    private static boolean isCombinableJoin(HiveJoin join, HiveMultiJoin leftChildJoin) throws CalciteSemanticException {
        HiveCalciteUtil.JoinPredicateInfo joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(join, join.getCondition());
        HiveCalciteUtil.JoinPredicateInfo leftChildJoinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(leftChildJoin, leftChildJoin.getCondition());
        return HiveJoinToMultiJoinRule.isCombinablePredicate(joinPredInfo, leftChildJoinPredInfo, leftChildJoin.getInputs().size());
    }

    private static boolean isCombinablePredicate(HiveCalciteUtil.JoinPredicateInfo joinPredInfo, HiveCalciteUtil.JoinPredicateInfo leftChildJoinPredInfo, int noLeftChildInputs) throws CalciteSemanticException {
        Set<Integer> keys = joinPredInfo.getProjsJoinKeysInChildSchema(0);
        if (keys.isEmpty()) {
            return false;
        }
        for (int i = 0; i < noLeftChildInputs; ++i) {
            if (!keys.equals(leftChildJoinPredInfo.getProjsJoinKeysInJoinSchema(i))) continue;
            return true;
        }
        return false;
    }
}

