/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.runtime.operators.join;

import java.io.Serializable;
import org.apache.flink.api.common.functions.DefaultOpenContext;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.configuration.AlgorithmOptions;
import org.apache.flink.streaming.api.operators.BoundedMultiInput;
import org.apache.flink.streaming.api.operators.InputSelectable;
import org.apache.flink.streaming.api.operators.InputSelection;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.binary.BinaryRowData;
import org.apache.flink.table.data.utils.JoinedRowData;
import org.apache.flink.table.runtime.generated.GeneratedJoinCondition;
import org.apache.flink.table.runtime.generated.GeneratedProjection;
import org.apache.flink.table.runtime.generated.JoinCondition;
import org.apache.flink.table.runtime.generated.Projection;
import org.apache.flink.table.runtime.hashtable.BinaryHashPartition;
import org.apache.flink.table.runtime.hashtable.BinaryHashTable;
import org.apache.flink.table.runtime.hashtable.ProbeIterator;
import org.apache.flink.table.runtime.operators.TableStreamOperator;
import org.apache.flink.table.runtime.operators.join.HashJoinType;
import org.apache.flink.table.runtime.operators.join.SortMergeJoinFunction;
import org.apache.flink.table.runtime.typeutils.AbstractRowDataSerializer;
import org.apache.flink.table.runtime.util.RowIterator;
import org.apache.flink.table.runtime.util.StreamRecordCollector;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class HashJoinOperator
extends TableStreamOperator<RowData>
implements TwoInputStreamOperator<RowData, RowData, RowData>,
BoundedMultiInput,
InputSelectable {
    private static final Logger LOG = LoggerFactory.getLogger(HashJoinOperator.class);
    private final HashJoinParameter parameter;
    private final boolean reverseJoinFunction;
    private final HashJoinType type;
    private final boolean leftIsBuild;
    private final SortMergeJoinFunction sortMergeJoinFunction;
    private transient BinaryHashTable table;
    transient Collector<RowData> collector;
    transient RowData buildSideNullRow;
    private transient RowData probeSideNullRow;
    private transient JoinedRowData joinedRow;
    private transient boolean buildEnd;
    private transient JoinCondition condition;
    private transient boolean fallbackSMJ;

    HashJoinOperator(HashJoinParameter parameter) {
        this.parameter = parameter;
        this.type = parameter.type;
        this.leftIsBuild = parameter.leftIsBuild;
        this.reverseJoinFunction = parameter.reverseJoinFunction;
        this.sortMergeJoinFunction = parameter.sortMergeJoinFunction;
    }

    @Override
    public void open() throws Exception {
        super.open();
        ClassLoader cl = this.getContainingTask().getUserCodeClassLoader();
        AbstractRowDataSerializer buildSerializer = (AbstractRowDataSerializer)this.getOperatorConfig().getTypeSerializerIn1(this.getUserCodeClassloader());
        AbstractRowDataSerializer probeSerializer = (AbstractRowDataSerializer)this.getOperatorConfig().getTypeSerializerIn2(this.getUserCodeClassloader());
        boolean hashJoinUseBitMaps = (Boolean)this.getContainingTask().getEnvironment().getTaskConfiguration().get(AlgorithmOptions.HASH_JOIN_BLOOM_FILTERS);
        int parallel = this.getRuntimeContext().getTaskInfo().getNumberOfParallelSubtasks();
        this.condition = (JoinCondition)this.parameter.condFuncCode.newInstance(cl);
        this.condition.setRuntimeContext((RuntimeContext)this.getRuntimeContext());
        this.condition.open(DefaultOpenContext.INSTANCE);
        this.table = new BinaryHashTable(this.getContainingTask(), this.parameter.compressionEnabled, this.parameter.compressionBlockSize, buildSerializer, probeSerializer, (Projection)this.parameter.buildProjectionCode.newInstance(cl), (Projection)this.parameter.probeProjectionCode.newInstance(cl), this.getContainingTask().getEnvironment().getMemoryManager(), this.computeMemorySize(), this.getContainingTask().getEnvironment().getIOManager(), this.parameter.buildRowSize, this.parameter.buildRowCount / (long)parallel, hashJoinUseBitMaps, this.type, this.condition, this.reverseJoinFunction, this.parameter.filterNullKeys, this.parameter.tryDistinctBuildRow);
        this.collector = new StreamRecordCollector<RowData>(this.output);
        this.buildSideNullRow = new GenericRowData(buildSerializer.getArity());
        this.probeSideNullRow = new GenericRowData(probeSerializer.getArity());
        this.joinedRow = new JoinedRowData();
        this.buildEnd = false;
        this.fallbackSMJ = false;
        this.getMetricGroup().gauge("memoryUsedSizeInBytes", this.table::getUsedMemoryInBytes);
        this.getMetricGroup().gauge("numSpillFiles", this.table::getNumSpillFiles);
        this.getMetricGroup().gauge("spillInBytes", this.table::getSpillInBytes);
        this.parameter.condFuncCode = null;
        this.parameter.buildProjectionCode = null;
        this.parameter.probeProjectionCode = null;
    }

    public void processElement1(StreamRecord<RowData> element) throws Exception {
        Preconditions.checkState((!this.buildEnd ? 1 : 0) != 0, (Object)"Should not build ended.");
        this.table.putBuildRow((RowData)element.getValue());
    }

    public void processElement2(StreamRecord<RowData> element) throws Exception {
        Preconditions.checkState((boolean)this.buildEnd, (Object)"Should build ended.");
        if (this.table.tryProbe((RowData)element.getValue())) {
            this.joinWithNextKey();
        }
    }

    public InputSelection nextSelection() {
        return this.buildEnd ? InputSelection.SECOND : InputSelection.FIRST;
    }

    public void endInput(int inputId) throws Exception {
        switch (inputId) {
            case 1: {
                Preconditions.checkState((!this.buildEnd ? 1 : 0) != 0, (Object)"Should not build ended.");
                LOG.info("Finish build phase.");
                this.buildEnd = true;
                this.table.endBuild();
                break;
            }
            case 2: {
                Preconditions.checkState((boolean)this.buildEnd, (Object)"Should build ended.");
                LOG.info("Finish probe phase.");
                while (this.table.nextMatching()) {
                    this.joinWithNextKey();
                }
                LOG.info("Finish rebuild phase.");
                this.fallbackSMJProcessPartition();
            }
        }
    }

    private void joinWithNextKey() throws Exception {
        this.join(this.table.getBuildSideIterator(), this.table.getCurrentProbeRow());
    }

    public abstract void join(RowIterator<BinaryRowData> var1, RowData var2) throws Exception;

    void innerJoin(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
        this.collect((RowData)buildIter.getRow(), probeRow);
        while (buildIter.advanceNext()) {
            this.collect((RowData)buildIter.getRow(), probeRow);
        }
    }

    void buildOuterJoin(RowIterator<BinaryRowData> buildIter) throws Exception {
        this.collect((RowData)buildIter.getRow(), this.probeSideNullRow);
        while (buildIter.advanceNext()) {
            this.collect((RowData)buildIter.getRow(), this.probeSideNullRow);
        }
    }

    void collect(RowData row1, RowData row2) throws Exception {
        if (this.reverseJoinFunction) {
            this.collector.collect((Object)this.joinedRow.replace(row2, row1));
        } else {
            this.collector.collect((Object)this.joinedRow.replace(row1, row2));
        }
    }

    public void close() throws Exception {
        super.close();
        this.closeHashTable();
        if (this.condition != null) {
            this.condition.close();
        }
        if (this.fallbackSMJ) {
            this.sortMergeJoinFunction.close();
        }
    }

    private void closeHashTable() {
        if (this.table != null) {
            this.table.close();
            this.table.free();
            this.table = null;
        }
    }

    private void fallbackSMJProcessPartition() throws Exception {
        if (!this.table.getPartitionsPendingForSMJ().isEmpty()) {
            this.table.releaseMemoryCacheForSMJ();
            LOG.info("Fallback to sort merge join to process spilled partitions.");
            this.initialSortMergeJoinFunction();
            this.fallbackSMJ = true;
            for (BinaryHashPartition p : this.table.getPartitionsPendingForSMJ()) {
                BinaryRowData probeNext;
                RowIterator buildSideIter = this.table.getSpilledPartitionBuildSideIter(p);
                while (buildSideIter.advanceNext()) {
                    this.processSortMergeJoinElement1((RowData)buildSideIter.getRow());
                }
                ProbeIterator probeIter = this.table.getSpilledPartitionProbeSideIter(p);
                while ((probeNext = probeIter.next()) != null) {
                    this.processSortMergeJoinElement2((RowData)probeNext);
                }
            }
            this.closeHashTable();
            this.sortMergeJoinFunction.endInput(1);
            this.sortMergeJoinFunction.endInput(2);
            LOG.info("Finish sort merge join for spilled partitions.");
        }
    }

    private void initialSortMergeJoinFunction() throws Exception {
        this.sortMergeJoinFunction.open(true, this.getContainingTask(), this.getOperatorConfig(), (StreamRecordCollector)this.collector, this.computeMemorySize(), (RuntimeContext)this.getRuntimeContext(), this.getMetricGroup());
    }

    private void processSortMergeJoinElement1(RowData rowData) throws Exception {
        if (this.leftIsBuild) {
            this.sortMergeJoinFunction.processElement1(rowData);
        } else {
            this.sortMergeJoinFunction.processElement2(rowData);
        }
    }

    private void processSortMergeJoinElement2(RowData rowData) throws Exception {
        if (this.leftIsBuild) {
            this.sortMergeJoinFunction.processElement2(rowData);
        } else {
            this.sortMergeJoinFunction.processElement1(rowData);
        }
    }

    public static HashJoinOperator newHashJoinOperator(HashJoinType type, boolean leftIsBuild, boolean compressionEnable, int compressionBlockSize, GeneratedJoinCondition condFuncCode, boolean reverseJoinFunction, boolean[] filterNullKeys, GeneratedProjection buildProjectionCode, GeneratedProjection probeProjectionCode, boolean tryDistinctBuildRow, int buildRowSize, long buildRowCount, long probeRowCount, RowType keyType, SortMergeJoinFunction sortMergeJoinFunction) {
        HashJoinParameter parameter = new HashJoinParameter(type, leftIsBuild, compressionEnable, compressionBlockSize, condFuncCode, reverseJoinFunction, filterNullKeys, buildProjectionCode, probeProjectionCode, tryDistinctBuildRow, buildRowSize, buildRowCount, probeRowCount, keyType, sortMergeJoinFunction);
        switch (type) {
            case INNER: {
                return new InnerHashJoinOperator(parameter);
            }
            case BUILD_OUTER: {
                return new BuildOuterHashJoinOperator(parameter);
            }
            case PROBE_OUTER: {
                return new ProbeOuterHashJoinOperator(parameter);
            }
            case FULL_OUTER: {
                return new FullOuterHashJoinOperator(parameter);
            }
            case SEMI: {
                return new SemiHashJoinOperator(parameter);
            }
            case ANTI: {
                return new AntiHashJoinOperator(parameter);
            }
            case BUILD_LEFT_SEMI: 
            case BUILD_LEFT_ANTI: {
                return new BuildLeftSemiOrAntiHashJoinOperator(parameter);
            }
        }
        throw new IllegalArgumentException("invalid: " + type);
    }

    private static class BuildLeftSemiOrAntiHashJoinOperator
    extends HashJoinOperator {
        BuildLeftSemiOrAntiHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
            block4: {
                if (!buildIter.advanceNext()) break block4;
                if (probeRow != null) {
                    while (buildIter.advanceNext()) {
                    }
                } else {
                    this.collector.collect((Object)buildIter.getRow());
                    while (buildIter.advanceNext()) {
                        this.collector.collect((Object)buildIter.getRow());
                    }
                }
            }
        }
    }

    private static class AntiHashJoinOperator
    extends HashJoinOperator {
        AntiHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
            Preconditions.checkNotNull((Object)probeRow);
            if (!buildIter.advanceNext()) {
                this.collector.collect((Object)probeRow);
            }
        }
    }

    private static class SemiHashJoinOperator
    extends HashJoinOperator {
        SemiHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
            Preconditions.checkNotNull((Object)probeRow);
            if (buildIter.advanceNext()) {
                this.collector.collect((Object)probeRow);
            }
        }
    }

    private static class FullOuterHashJoinOperator
    extends HashJoinOperator {
        FullOuterHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
            if (buildIter.advanceNext()) {
                if (probeRow != null) {
                    this.innerJoin(buildIter, probeRow);
                } else {
                    this.buildOuterJoin(buildIter);
                }
            } else if (probeRow != null) {
                this.collect(this.buildSideNullRow, probeRow);
            }
        }
    }

    private static class ProbeOuterHashJoinOperator
    extends HashJoinOperator {
        ProbeOuterHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
            if (buildIter.advanceNext()) {
                if (probeRow != null) {
                    this.innerJoin(buildIter, probeRow);
                }
            } else if (probeRow != null) {
                this.collect(this.buildSideNullRow, probeRow);
            }
        }
    }

    private static class BuildOuterHashJoinOperator
    extends HashJoinOperator {
        BuildOuterHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
            if (buildIter.advanceNext()) {
                if (probeRow != null) {
                    this.innerJoin(buildIter, probeRow);
                } else {
                    this.buildOuterJoin(buildIter);
                }
            }
        }
    }

    private static class InnerHashJoinOperator
    extends HashJoinOperator {
        InnerHashJoinOperator(HashJoinParameter parameter) {
            super(parameter);
        }

        @Override
        public void join(RowIterator<BinaryRowData> buildIter, RowData probeRow) throws Exception {
            if (buildIter.advanceNext() && probeRow != null) {
                this.innerJoin(buildIter, probeRow);
            }
        }
    }

    static class HashJoinParameter
    implements Serializable {
        HashJoinType type;
        boolean leftIsBuild;
        boolean compressionEnabled;
        int compressionBlockSize;
        GeneratedJoinCondition condFuncCode;
        boolean reverseJoinFunction;
        boolean[] filterNullKeys;
        GeneratedProjection buildProjectionCode;
        GeneratedProjection probeProjectionCode;
        boolean tryDistinctBuildRow;
        int buildRowSize;
        long buildRowCount;
        long probeRowCount;
        RowType keyType;
        SortMergeJoinFunction sortMergeJoinFunction;

        HashJoinParameter(HashJoinType type, boolean leftIsBuild, boolean compressionEnabled, int compressionBlockSize, GeneratedJoinCondition condFuncCode, boolean reverseJoinFunction, boolean[] filterNullKeys, GeneratedProjection buildProjectionCode, GeneratedProjection probeProjectionCode, boolean tryDistinctBuildRow, int buildRowSize, long buildRowCount, long probeRowCount, RowType keyType, SortMergeJoinFunction sortMergeJoinFunction) {
            this.type = type;
            this.leftIsBuild = leftIsBuild;
            this.compressionEnabled = compressionEnabled;
            this.compressionBlockSize = compressionBlockSize;
            this.condFuncCode = condFuncCode;
            this.reverseJoinFunction = reverseJoinFunction;
            this.filterNullKeys = filterNullKeys;
            this.buildProjectionCode = buildProjectionCode;
            this.probeProjectionCode = probeProjectionCode;
            this.tryDistinctBuildRow = tryDistinctBuildRow;
            this.buildRowSize = buildRowSize;
            this.buildRowCount = buildRowCount;
            this.probeRowCount = probeRowCount;
            this.keyType = keyType;
            this.sortMergeJoinFunction = sortMergeJoinFunction;
        }
    }
}

