package org.apache.flink.runtime.state.heap;

import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.flink.runtime.state.InternalPriorityQueue;
import org.apache.flink.runtime.state.KeyExtractorFunction;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
import org.apache.flink.runtime.state.PriorityComparator;
import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.IOUtils;
import org.apache.flink.util.Preconditions;

/* JADX WARN: Incorrect field signature: [TPQ; */
/* loaded from: input_file:org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueue.class */
/**
 * An {@link InternalPriorityQueue} whose elements are partitioned by key group.
 *
 * <p>Every key group of the {@link KeyGroupRange} given to the constructor gets its own sub-queue,
 * created through a {@link PartitionQueueSetFactory}. A second heap ({@code heapOfKeyGroupedHeaps})
 * orders those sub-queues by their current head element, so the head of this queue is the head of
 * the sub-queue at the top of that heap. Whenever a sub-queue is modified, its position in the outer
 * heap is re-adjusted.
 *
 * @param <T> type of the elements in the queue.
 * @param <PQ> type of the per-key-group sub-queues; they are also {@link HeapPriorityQueueElement}s
 *     so that they can be stored in the outer {@link HeapPriorityQueue}.
 */
public class KeyGroupPartitionedPriorityQueue<
                T, PQ extends InternalPriorityQueue<T> & HeapPriorityQueueElement>
        implements InternalPriorityQueue<T>, KeyGroupedInternalPriorityQueue<T> {

    /** Heap of the per-key-group sub-queues, ordered by each sub-queue's head element. */
    @Nonnull
    private final HeapPriorityQueue<PQ> heapOfKeyGroupedHeaps;

    /** The per-key-group sub-queues, indexed by key group relative to {@code firstKeyGroup}. */
    @Nonnull
    private final PQ[] keyGroupedHeaps;

    /** Extracts the key of an element; the key determines the element's key group. */
    @Nonnull
    private final KeyExtractorFunction<T> keyExtractor;

    /** Total number of key groups, passed to {@link KeyGroupRangeAssignment#assignToKeyGroup}. */
    @Nonnegative
    private final int totalKeyGroups;

    /** Global id of the first key group covered by this queue. */
    @Nonnegative
    private final int firstKeyGroup;

    /**
     * Compares sub-queues by their head elements.
     *
     * <p>A sub-queue whose {@code peek()} is {@code null} (empty) compares as greater than any
     * non-empty sub-queue; two empty sub-queues compare as equal.
     */
    private static final class InternalPriorityQueueComparator<
                    T, Q extends InternalPriorityQueue<T>>
            implements PriorityComparator<Q> {

        /** Comparator for the elements, used to compare the heads of two sub-queues. */
        @Nonnull
        private final PriorityComparator<T> elementPriorityComparator;

        InternalPriorityQueueComparator(@Nonnull PriorityComparator<T> elementPriorityComparator) {
            this.elementPriorityComparator = elementPriorityComparator;
        }

        @Override
        public int comparePriority(Q left, Q right) {
            final T leftHead = left.peek();
            final T rightHead = right.peek();
            if (leftHead == null) {
                return rightHead == null ? 0 : 1;
            }
            if (rightHead == null) {
                return -1;
            }
            return elementPriorityComparator.comparePriority(leftHead, rightHead);
        }
    }

    /**
     * Iterator that concatenates the iterators of the given sub-queues, in array order.
     *
     * <p>Advancing to the next sub-queue happens in {@link #hasNext()}, which quietly closes the
     * exhausted sub-iterator before opening the next one. {@link #next()} only delegates to the
     * current sub-iterator, so callers should call {@code hasNext()} before each {@code next()}.
     * {@link #close()} closes only the sub-iterator that is currently open.
     */
    private static final class KeyGroupConcatenationIterator<
                    T, PQS extends InternalPriorityQueue<T> & HeapPriorityQueueElement>
            implements CloseableIterator<T> {

        /** The sub-queues to iterate, in order. */
        @Nonnull
        private final PQS[] keyGroupLists;

        /** Index of the next sub-queue whose iterator has not been opened yet. */
        @Nonnegative
        private int index = 0;

        /** Iterator of the sub-queue currently being traversed. */
        @Nonnull
        private CloseableIterator<T> current = CloseableIterator.empty();

        private KeyGroupConcatenationIterator(@Nonnull PQS[] keyGroupLists) {
            this.keyGroupLists = keyGroupLists;
        }

        @Override
        public boolean hasNext() {
            boolean currentHasNext = current.hasNext();
            // Skip exhausted or empty sub-queues until one has elements or none are left.
            while (!currentHasNext && index < keyGroupLists.length) {
                IOUtils.closeQuietly(current);
                current = keyGroupLists[index++].iterator();
                currentHasNext = current.hasNext();
            }
            return currentHasNext;
        }

        @Override
        public T next() {
            return current.next();
        }

        @Override
        public void close() throws Exception {
            current.close();
        }
    }

    /**
     * Factory for the per-key-group sub-queues.
     *
     * @param <T> type of the elements in the sub-queues.
     * @param <PQS> type of the created sub-queues.
     */
    public interface PartitionQueueSetFactory<
            T, PQS extends InternalPriorityQueue<T> & HeapPriorityQueueElement> {

        /**
         * Creates the sub-queue for one key group.
         *
         * @param keyGroupId global id of the key group the sub-queue is created for.
         * @param numKeyGroups total number of key groups.
         * @param keyExtractorFunction extracts the key from an element.
         * @param elementPriorityComparator priority comparator for the elements.
         * @return the new sub-queue.
         */
        @Nonnull
        PQS create(
                @Nonnegative int keyGroupId,
                @Nonnegative int numKeyGroups,
                @Nonnull KeyExtractorFunction<T> keyExtractorFunction,
                @Nonnull PriorityComparator<T> elementPriorityComparator);
    }

    /**
     * Creates a queue with one sub-queue per key group of {@code keyGroupRange}.
     *
     * @param keyExtractor extracts the key of an element, used to find its key group.
     * @param elementPriorityComparator priority comparator for the elements.
     * @param partitionQueueSetFactory creates the sub-queue for each key group.
     * @param keyGroupRange the key groups covered by this queue.
     * @param totalKeyGroups total number of key groups, used for key-group assignment.
     */
    public KeyGroupPartitionedPriorityQueue(
            @Nonnull KeyExtractorFunction<T> keyExtractor,
            @Nonnull PriorityComparator<T> elementPriorityComparator,
            @Nonnull PartitionQueueSetFactory<T, PQ> partitionQueueSetFactory,
            @Nonnull KeyGroupRange keyGroupRange,
            @Nonnegative int totalKeyGroups) {
        this.keyExtractor = keyExtractor;
        this.totalKeyGroups = totalKeyGroups;
        this.firstKeyGroup = keyGroupRange.getStartKeyGroup();
        final int numberOfKeyGroups = keyGroupRange.getNumberOfKeyGroups();

        // Safe: PQ erases to InternalPriorityQueue, so this array's runtime type matches PQ[].
        @SuppressWarnings("unchecked")
        final PQ[] subQueues = (PQ[]) new InternalPriorityQueue[numberOfKeyGroups];
        this.keyGroupedHeaps = subQueues;
        this.heapOfKeyGroupedHeaps =
                new HeapPriorityQueue<>(
                        new InternalPriorityQueueComparator<>(elementPriorityComparator),
                        numberOfKeyGroups);
        for (int i = 0; i < numberOfKeyGroups; i++) {
            final PQ subQueue =
                    partitionQueueSetFactory.create(
                            firstKeyGroup + i,
                            totalKeyGroups,
                            keyExtractor,
                            elementPriorityComparator);
            keyGroupedHeaps[i] = subQueue;
            heapOfKeyGroupedHeaps.add(subQueue);
        }
    }

    /**
     * Removes and returns the head of the queue.
     *
     * @return the head element, or {@code null} if there is none.
     */
    @Override
    @Nullable
    public T poll() {
        final PQ headQueue = heapOfKeyGroupedHeaps.peek();
        if (headQueue == null) {
            // The outer heap never shrinks after construction, so it only lacks a head when the
            // key-group range had no key groups at all.
            return null;
        }
        final T head = headQueue.poll();
        // The sub-queue's head changed; restore its position in the outer heap.
        heapOfKeyGroupedHeaps.adjustModifiedElement(headQueue);
        return head;
    }

    /**
     * Returns the head of the queue without removing it.
     *
     * @return the head element, or {@code null} if there is none.
     */
    @Override
    @Nullable
    public T peek() {
        final PQ headQueue = heapOfKeyGroupedHeaps.peek();
        // Same guard as in poll(): no sub-queues exist for an empty key-group range.
        return headQueue == null ? null : headQueue.peek();
    }

    /**
     * Adds an element to the sub-queue of its key group.
     *
     * @param toAdd the element to add.
     * @return {@code true} if the sub-queue accepted the element and the queue's head is now equal
     *     (per {@code equals}) to it; {@code false} otherwise.
     * @throws IllegalArgumentException if the element's key group is outside this queue's range.
     */
    @Override
    public boolean add(@Nonnull T toAdd) {
        final PQ subQueue = getKeyGroupSubHeapForElement(toAdd);
        if (!subQueue.add(toAdd)) {
            // The sub-queue reported no change, so the overall head cannot have changed.
            return false;
        }
        heapOfKeyGroupedHeaps.adjustModifiedElement(subQueue);
        return toAdd.equals(peek());
    }

    /**
     * Removes an element from the sub-queue of its key group.
     *
     * @param toRemove the element to remove.
     * @return {@code true} if the sub-queue removed the element and the element was equal (per
     *     {@code equals}) to the queue's head before removal; {@code false} otherwise.
     * @throws IllegalArgumentException if the element's key group is outside this queue's range.
     */
    @Override
    public boolean remove(@Nonnull T toRemove) {
        final PQ subQueue = getKeyGroupSubHeapForElement(toRemove);
        // Capture the head before removal so we can tell whether the head was removed.
        final T oldHead = peek();
        if (!subQueue.remove(toRemove)) {
            return false;
        }
        heapOfKeyGroupedHeaps.adjustModifiedElement(subQueue);
        return toRemove.equals(oldHead);
    }

    /**
     * Returns whether the queue has no elements.
     *
     * <p>Empty sub-queues compare as greater than non-empty ones, so the head sub-queue only
     * reports a {@code null} head when all sub-queues are empty.
     */
    @Override
    public boolean isEmpty() {
        return peek() == null;
    }

    /** Returns the total number of elements, summed over all sub-queues. */
    @Override
    public int size() {
        int total = 0;
        for (PQ subQueue : keyGroupedHeaps) {
            total += subQueue.size();
        }
        return total;
    }

    /**
     * Adds every element of {@code toAdd} via {@link #add(Object)}; a {@code null} collection is
     * ignored.
     */
    @Override
    public void addAll(@Nullable Collection<? extends T> toAdd) {
        if (toAdd == null) {
            return;
        }
        for (T element : toAdd) {
            add(element);
        }
    }

    /**
     * Returns an iterator over all elements, sub-queue by sub-queue in key-group order. No global
     * priority order is implied.
     */
    @Override
    @Nonnull
    public CloseableIterator<T> iterator() {
        return new KeyGroupConcatenationIterator<>(keyGroupedHeaps);
    }

    /** Returns the sub-queue responsible for the key group of {@code element}. */
    private PQ getKeyGroupSubHeapForElement(T element) {
        return keyGroupedHeaps[computeKeyGroupIndex(element)];
    }

    /** Computes the local sub-queue index for {@code element} from its key. */
    private int computeKeyGroupIndex(T element) {
        final Object key = keyExtractor.extractKeyFromElement(element);
        final int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(key, totalKeyGroups);
        return globalKeyGroupToLocalIndex(keyGroup);
    }

    /**
     * Converts a global key-group id into an index into {@code keyGroupedHeaps}.
     *
     * @throws IllegalArgumentException (via {@link Preconditions#checkArgument}) if the key group
     *     is not covered by this queue.
     */
    private int globalKeyGroupToLocalIndex(int keyGroup) {
        final int localIndex = keyGroup - firstKeyGroup;
        // Note: the upper bound printed in the message is exclusive.
        Preconditions.checkArgument(
                localIndex >= 0 && localIndex < keyGroupedHeaps.length,
                "key group from %s to %s does not contain %s",
                firstKeyGroup,
                firstKeyGroup + keyGroupedHeaps.length,
                keyGroup);
        return localIndex;
    }

    /**
     * Returns a new mutable set containing the elements of one key group.
     *
     * @param keyGroupId global id of the key group.
     * @return a {@link HashSet} with the key group's elements.
     * @throws IllegalArgumentException if the key group is not covered by this queue.
     * @throws FlinkRuntimeException if iterating or closing the sub-queue's iterator fails.
     */
    @Override
    @Nonnull
    public Set<T> getSubsetForKeyGroup(int keyGroupId) {
        // Resolve (and validate) the sub-queue outside the try so that an invalid key group is
        // reported as such instead of being wrapped as an iteration failure.
        final PQ subQueue = keyGroupedHeaps[globalKeyGroupToLocalIndex(keyGroupId)];
        final Set<T> result = new HashSet<>();
        // try-with-resources guarantees the iterator is closed even if iteration throws.
        try (CloseableIterator<T> iterator = subQueue.iterator()) {
            while (iterator.hasNext()) {
                result.add(iterator.next());
            }
        } catch (Exception e) {
            throw new FlinkRuntimeException("Exception while iterating key group.", e);
        }
        return result;
    }
}
