package io.trino.operator;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import io.airlift.slice.SizeOf;
import io.trino.array.LongBigArray;
import io.trino.spi.Page;
import io.trino.util.HeapTraversal;
import io.trino.util.LongBigArrayFIFOQueue;
import jakarta.annotation.Nullable;
import java.util.Objects;
import java.util.function.LongConsumer;

/**
 * Keeps, for every group ID, a bounded max heap of row IDs ordered by a {@link RowIdComparisonStrategy}, retaining at
 * most {@code topN} rows per group.
 * <p>
 * Each heap is stored as an explicit binary tree. Every node occupies a fixed-width slot in a shared
 * {@link LongBigArray}; the slot holds the node's row ID and the slot indexes of its left and right children. Freed
 * slots are recycled through a FIFO queue. The root of a heap is the greatest row according to {@code strategy}
 * (see the "Max heap invariant" check in {@link #verifyIntegrity()}). It is therefore the row displaced when a smaller
 * row arrives for a full group. Displaced row IDs are reported to {@code rowIdEvictionListener}.
 */
public class GroupedTopNRowNumberAccumulator {
    private static final long INSTANCE_SIZE = SizeOf.instanceSize(GroupedTopNRowNumberAccumulator.class);
    /** Sentinel node index meaning "no node": an empty heap's root, or an absent child. */
    private static final long UNKNOWN_INDEX = -1;

    private final GroupIdToHeapBuffer groupIdToHeapBuffer = new GroupIdToHeapBuffer();
    private final HeapNodeBuffer heapNodeBuffer = new HeapNodeBuffer();
    private final HeapTraversal heapTraversal = new HeapTraversal();
    private final RowIdComparisonStrategy strategy;
    private final int topN;
    private final LongConsumer rowIdEvictionListener;

    /**
     * Per-group bookkeeping, indexed by group ID: the heap's root node index and the number of rows in the heap.
     */
    public static class GroupIdToHeapBuffer {
        private static final long INSTANCE_SIZE = SizeOf.instanceSize(GroupIdToHeapBuffer.class);
        // Constructor argument is assumed to be the fill value for newly allocated capacity (TODO confirm against
        // LongBigArray): unseen groups read as "no root" and "size 0".
        private final LongBigArray heapIndexBuffer = new LongBigArray(UNKNOWN_INDEX);
        private final LongBigArray sizeBuffer = new LongBigArray(0);
        private int totalGroups;

        private GroupIdToHeapBuffer() {
        }

        /**
         * Ensures bookkeeping exists for every group ID in {@code [0, groupId]}.
         */
        public void allocateGroupIfNeeded(int groupId) {
            if (totalGroups > groupId) {
                return;
            }
            totalGroups = groupId + 1;
            heapIndexBuffer.ensureCapacity(totalGroups);
            sizeBuffer.ensureCapacity(totalGroups);
        }

        public int getTotalGroups() {
            return totalGroups;
        }

        public long getHeapRootNodeIndex(int groupId) {
            return heapIndexBuffer.get(groupId);
        }

        public void setHeapRootNodeIndex(int groupId, long heapNodeIndex) {
            heapIndexBuffer.set(groupId, heapNodeIndex);
        }

        public long getHeapSize(int groupId) {
            return sizeBuffer.get(groupId);
        }

        public void setHeapSize(int groupId, long count) {
            sizeBuffer.set(groupId, count);
        }

        public void addHeapSize(int groupId, long delta) {
            sizeBuffer.add(groupId, delta);
        }

        public void incrementHeapSize(int groupId) {
            sizeBuffer.increment(groupId);
        }

        /** Returns the retained size in bytes of this buffer. */
        public long sizeOf() {
            return INSTANCE_SIZE + heapIndexBuffer.sizeOf() + sizeBuffer.sizeOf();
        }
    }

    /**
     * Slot storage for heap nodes shared by all groups. Slots released by {@link #deallocate(long)} are reused before
     * new capacity is added.
     */
    public static class HeapNodeBuffer {
        private static final long INSTANCE_SIZE = SizeOf.instanceSize(HeapNodeBuffer.class);
        // Slot layout: [rowId, leftChildNodeIndex, rightChildNodeIndex], POSITIONS_PER_ENTRY longs per node
        private static final int POSITIONS_PER_ENTRY = 3;
        private static final int LEFT_CHILD_HEAP_INDEX_OFFSET = 1;
        private static final int RIGHT_CHILD_HEAP_INDEX_OFFSET = 2;

        private final LongBigArray buffer = new LongBigArray();
        private final LongBigArrayFIFOQueue emptySlots = new LongBigArrayFIFOQueue();
        private long capacity;

        private HeapNodeBuffer() {
        }

        /**
         * Allocates a childless node holding {@code rowId}, reusing a freed slot when one is available.
         *
         * @return the new node's index
         */
        public long allocateNewNode(long rowId) {
            long newHeapIndex;
            if (emptySlots.isEmpty()) {
                newHeapIndex = capacity;
                capacity++;
                buffer.ensureCapacity(capacity * POSITIONS_PER_ENTRY);
            } else {
                newHeapIndex = emptySlots.dequeueLong();
            }
            setRowId(newHeapIndex, rowId);
            setLeftChildHeapIndex(newHeapIndex, UNKNOWN_INDEX);
            setRightChildHeapIndex(newHeapIndex, UNKNOWN_INDEX);
            return newHeapIndex;
        }

        /** Returns the slot to the free list. The slot contents are not cleared. */
        public void deallocate(long heapNodeIndex) {
            emptySlots.enqueue(heapNodeIndex);
        }

        /** Number of slots currently in use: every slot ever allocated minus the ones waiting for reuse. */
        public long getActiveNodeCount() {
            return capacity - emptySlots.longSize();
        }

        public long getRowId(long heapNodeIndex) {
            return buffer.get(heapNodeIndex * POSITIONS_PER_ENTRY);
        }

        public void setRowId(long heapNodeIndex, long rowId) {
            buffer.set(heapNodeIndex * POSITIONS_PER_ENTRY, rowId);
        }

        public long getLeftChildHeapIndex(long heapNodeIndex) {
            return buffer.get(heapNodeIndex * POSITIONS_PER_ENTRY + LEFT_CHILD_HEAP_INDEX_OFFSET);
        }

        public void setLeftChildHeapIndex(long heapNodeIndex, long childHeapIndex) {
            buffer.set(heapNodeIndex * POSITIONS_PER_ENTRY + LEFT_CHILD_HEAP_INDEX_OFFSET, childHeapIndex);
        }

        public long getRightChildHeapIndex(long heapNodeIndex) {
            return buffer.get(heapNodeIndex * POSITIONS_PER_ENTRY + RIGHT_CHILD_HEAP_INDEX_OFFSET);
        }

        public void setRightChildHeapIndex(long heapNodeIndex, long childHeapIndex) {
            buffer.set(heapNodeIndex * POSITIONS_PER_ENTRY + RIGHT_CHILD_HEAP_INDEX_OFFSET, childHeapIndex);
        }

        /** Returns the retained size in bytes of this buffer. */
        public long sizeOf() {
            return INSTANCE_SIZE + buffer.sizeOf() + emptySlots.sizeOf();
        }
    }

    /** Result of a recursive heap check: subtree depth and number of nodes in the subtree. */
    public static class IntegrityStats {
        private final long maxDepth;
        private final long nodeCount;

        public IntegrityStats(long maxDepth, long nodeCount) {
            this.maxDepth = maxDepth;
            this.nodeCount = nodeCount;
        }

        public long getMaxDepth() {
            return maxDepth;
        }

        public long getNodeCount() {
            return nodeCount;
        }
    }

    /**
     * @param strategy orders row IDs; the heap root is the greatest row under this ordering
     * @param topN maximum number of rows retained per group; must be positive
     * @param rowIdEvictionListener receives the row ID displaced whenever a full group accepts a new row
     */
    public GroupedTopNRowNumberAccumulator(RowIdComparisonStrategy strategy, int topN, LongConsumer rowIdEvictionListener) {
        this.strategy = Objects.requireNonNull(strategy, "strategy is null");
        Preconditions.checkArgument(topN > 0, "topN must be greater than zero");
        this.topN = topN;
        this.rowIdEvictionListener = Objects.requireNonNull(rowIdEvictionListener, "rowIdEvictionListener is null");
    }

    /** Returns the retained size in bytes of this accumulator and its buffers. */
    public long sizeOf() {
        return INSTANCE_SIZE + groupIdToHeapBuffer.sizeOf() + heapNodeBuffer.sizeOf() + heapTraversal.sizeOf();
    }

    /**
     * Finds the first position of {@code newPage} whose row could be accepted by {@link #add}.
     * <p>
     * A position qualifies if any of the following holds:
     * <ul>
     * <li>its group was not known before this call,</li>
     * <li>its group holds fewer than {@code topN} rows, or</li>
     * <li>{@code comparator} orders it strictly before the group's current heap root.</li>
     * </ul>
     * As a side effect, bookkeeping is allocated for group IDs {@code [0, groupCount]}.
     *
     * @param groupIds group ID of each position in {@code newPage}
     * @return the first qualifying position, or {@code -1} if no position qualifies
     */
    public int findFirstPositionToAdd(Page newPage, int groupCount, int[] groupIds, PageWithPositionComparator comparator, RowReferencePageManager pageManager) {
        // Snapshot before allocating, so groups first seen on this page are recognized as new
        int previousTotalGroups = groupIdToHeapBuffer.getTotalGroups();
        groupIdToHeapBuffer.allocateGroupIfNeeded(groupCount);

        for (int position = 0; position < newPage.getPositionCount(); position++) {
            int groupId = groupIds[position];
            if (groupId >= previousTotalGroups || calculateRootRowNumber(groupId) < topN) {
                return position;
            }
            long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
            if (heapRootNodeIndex == UNKNOWN_INDEX) {
                return position;
            }
            long rootRowId = heapNodeBuffer.getRowId(heapRootNodeIndex);
            // Strictly-less matches add(), which rejects rows comparing >= the root
            if (comparator.compareTo(newPage, position, pageManager.getPage(rootRowId), pageManager.getPosition(rootRowId)) < 0) {
                return position;
            }
        }
        return -1;
    }

    /**
     * Offers a row to the given group.
     * <p>
     * If the group has fewer than {@code topN} rows, the row is inserted. Otherwise the row replaces the current root
     * only if it compares strictly less than the root; the displaced root row ID goes to {@code rowIdEvictionListener}.
     * A row ID is obtained from {@code rowReference} only when the row is accepted.
     *
     * @return {@code true} if the row was accepted
     */
    public boolean add(int groupId, RowReference rowReference) {
        groupIdToHeapBuffer.allocateGroupIfNeeded(groupId);
        long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
        if (heapRootNodeIndex == UNKNOWN_INDEX || calculateRootRowNumber(groupId) < topN) {
            heapInsert(groupId, rowReference.allocateRowId());
            return true;
        }
        if (rowReference.compareTo(strategy, heapNodeBuffer.getRowId(heapRootNodeIndex)) >= 0) {
            return false;
        }
        heapPopAndInsert(groupId, rowReference.allocateRowId(), rowIdEvictionListener);
        return true;
    }

    /**
     * Removes every row ID of the group and writes it to {@code rowIdOutput[0 .. size)} in ascending order under
     * {@code strategy}.
     * <p>
     * The root, which is the max, is written to the last index, and then the heap is popped. Draining does not notify
     * {@code rowIdEvictionListener}.
     *
     * @return the number of row IDs written
     */
    public long drainTo(int groupId, LongBigArray rowIdOutput) {
        long heapSize = groupIdToHeapBuffer.getHeapSize(groupId);
        rowIdOutput.ensureCapacity(heapSize);
        for (long outputIndex = heapSize - 1; outputIndex >= 0; outputIndex--) {
            rowIdOutput.set(outputIndex, peekRootRowId(groupId));
            heapPop(groupId, null);
        }
        return heapSize;
    }

    /** The root is the greatest of the group's rows, so its row number equals the heap size. */
    private long calculateRootRowNumber(int groupId) {
        return groupIdToHeapBuffer.getHeapSize(groupId);
    }

    private long peekRootRowId(int groupId) {
        long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
        Preconditions.checkArgument(heapRootNodeIndex != UNKNOWN_INDEX, "No root to peek");
        return heapNodeBuffer.getRowId(heapRootNodeIndex);
    }

    private long getChildIndex(long heapNodeIndex, HeapTraversal.Child childPosition) {
        return childPosition == HeapTraversal.Child.LEFT
                ? heapNodeBuffer.getLeftChildHeapIndex(heapNodeIndex)
                : heapNodeBuffer.getRightChildHeapIndex(heapNodeIndex);
    }

    private void setChildIndex(long heapNodeIndex, HeapTraversal.Child childPosition, long newChildIndex) {
        if (childPosition == HeapTraversal.Child.LEFT) {
            heapNodeBuffer.setLeftChildHeapIndex(heapNodeIndex, newChildIndex);
        } else {
            heapNodeBuffer.setRightChildHeapIndex(heapNodeIndex, newChildIndex);
        }
    }

    /**
     * Removes the root row of the group.
     * <p>
     * The last-inserted leaf is detached and freed first. Its row ID is then re-inserted in place of the root by
     * sifting it down. If that leaf was the root itself, the heap becomes empty. Either way, the removed root row ID
     * is passed to {@code contextEvictionListener} when it is non-null.
     */
    private void heapPop(int groupId, @Nullable LongConsumer contextEvictionListener) {
        long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
        Preconditions.checkArgument(heapRootNodeIndex != UNKNOWN_INDEX, "Group ID has an empty heap");

        long lastNodeIndex = heapDetachLastInsertionLeaf(groupId);
        long lastRowId = heapNodeBuffer.getRowId(lastNodeIndex);
        heapNodeBuffer.deallocate(lastNodeIndex);

        if (lastNodeIndex != heapRootNodeIndex) {
            heapPopAndInsert(groupId, lastRowId, contextEvictionListener);
        } else if (contextEvictionListener != null) {
            // The root was the only node, so it is the row being removed
            contextEvictionListener.accept(lastRowId);
        }
    }

    /**
     * Unlinks the node at position {@code heapSize} from its parent, or clears the root if it is the only node.
     * It also decrements the group's heap size. The slot is not deallocated here.
     * <p>
     * The position is reached by following {@link HeapTraversal#resetWithPathTo}. This is presumably the
     * level-order position, which keeps the tree balanced (TODO confirm against HeapTraversal).
     *
     * @return the detached node's index
     */
    private long heapDetachLastInsertionLeaf(int groupId) {
        long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
        long heapSize = groupIdToHeapBuffer.getHeapSize(groupId);

        long previousNodeIndex = UNKNOWN_INDEX;
        HeapTraversal.Child childPosition = null;
        long currentNodeIndex = heapRootNodeIndex;

        heapTraversal.resetWithPathTo(heapSize);
        while (!heapTraversal.isTarget()) {
            previousNodeIndex = currentNodeIndex;
            childPosition = heapTraversal.nextChild();
            currentNodeIndex = getChildIndex(currentNodeIndex, childPosition);
            Verify.verify(currentNodeIndex != UNKNOWN_INDEX, "Target node must exist");
        }

        if (previousNodeIndex == UNKNOWN_INDEX) {
            // The last insertion leaf is the root itself, so the heap becomes empty
            groupIdToHeapBuffer.setHeapRootNodeIndex(groupId, UNKNOWN_INDEX);
            groupIdToHeapBuffer.setHeapSize(groupId, 0);
        } else {
            setChildIndex(previousNodeIndex, childPosition, UNKNOWN_INDEX);
            groupIdToHeapBuffer.addHeapSize(groupId, -1);
        }
        return currentNodeIndex;
    }

    /**
     * Inserts {@code newRowId} into the group's heap at position {@code heapSize + 1}.
     * <p>
     * While walking from the root down to the new position, the larger of the carried row and each visited node's
     * row stays in that node. The smaller one is carried further down, which preserves the max-heap order.
     */
    private void heapInsert(int groupId, long newRowId) {
        long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
        if (heapRootNodeIndex == UNKNOWN_INDEX) {
            groupIdToHeapBuffer.setHeapRootNodeIndex(groupId, heapNodeBuffer.allocateNewNode(newRowId));
            groupIdToHeapBuffer.setHeapSize(groupId, 1);
            return;
        }

        long previousHeapNodeIndex = UNKNOWN_INDEX;
        HeapTraversal.Child childPosition = null;
        long currentHeapNodeIndex = heapRootNodeIndex;

        heapTraversal.resetWithPathTo(groupIdToHeapBuffer.getHeapSize(groupId) + 1);
        while (!heapTraversal.isTarget()) {
            long currentRowId = heapNodeBuffer.getRowId(currentHeapNodeIndex);
            if (strategy.compare(newRowId, currentRowId) > 0) {
                // Keep the larger row here and carry the smaller one down
                heapNodeBuffer.setRowId(currentHeapNodeIndex, newRowId);
                newRowId = currentRowId;
            }
            previousHeapNodeIndex = currentHeapNodeIndex;
            childPosition = heapTraversal.nextChild();
            currentHeapNodeIndex = getChildIndex(currentHeapNodeIndex, childPosition);
        }

        Verify.verify(previousHeapNodeIndex != UNKNOWN_INDEX && childPosition != null, "heap must have at least one node before starting traversal");
        Verify.verify(currentHeapNodeIndex == UNKNOWN_INDEX, "New child shouldn't exist yet");

        long newHeapNodeIndex = heapNodeBuffer.allocateNewNode(newRowId);
        setChildIndex(previousHeapNodeIndex, childPosition, newHeapNodeIndex);
        groupIdToHeapBuffer.incrementHeapSize(groupId);
    }

    /**
     * Replaces the root row with {@code newRowId} and sifts it down until it is at least as large as both children.
     * The node structure and heap size are unchanged. The previous root row ID is passed to
     * {@code contextEvictionListener} when it is non-null.
     */
    private void heapPopAndInsert(int groupId, long newRowId, @Nullable LongConsumer contextEvictionListener) {
        long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
        Preconditions.checkState(heapRootNodeIndex != UNKNOWN_INDEX, "popAndInsert() requires at least a root node");

        long poppedRowId = heapNodeBuffer.getRowId(heapRootNodeIndex);

        long currentNodeIndex = heapRootNodeIndex;
        while (true) {
            long maxChildNodeIndex = heapNodeBuffer.getLeftChildHeapIndex(currentNodeIndex);
            if (maxChildNodeIndex == UNKNOWN_INDEX) {
                // No left child means no right child either, so this is a leaf
                break;
            }
            long maxChildRowId = heapNodeBuffer.getRowId(maxChildNodeIndex);

            long rightChildNodeIndex = heapNodeBuffer.getRightChildHeapIndex(currentNodeIndex);
            if (rightChildNodeIndex != UNKNOWN_INDEX) {
                long rightRowId = heapNodeBuffer.getRowId(rightChildNodeIndex);
                if (strategy.compare(rightRowId, maxChildRowId) > 0) {
                    maxChildNodeIndex = rightChildNodeIndex;
                    maxChildRowId = rightRowId;
                }
            }

            if (strategy.compare(newRowId, maxChildRowId) >= 0) {
                // The new row dominates both children, so it belongs at this node
                break;
            }

            // Pull the larger child up and continue from the vacancy it leaves
            heapNodeBuffer.setRowId(currentNodeIndex, maxChildRowId);
            currentNodeIndex = maxChildNodeIndex;
        }

        heapNodeBuffer.setRowId(currentNodeIndex, newRowId);
        if (contextEvictionListener != null) {
            contextEvictionListener.accept(poppedRowId);
        }
    }

    /**
     * Test hook that checks every group's heap. It verifies:
     * <ul>
     * <li>no heap holds more than {@code topN} rows,</li>
     * <li>the max-heap order and tree balance hold,</li>
     * <li>the recorded sizes match the actual node counts, and</li>
     * <li>no allocated node slot is leaked.</li>
     * </ul>
     */
    @VisibleForTesting
    void verifyIntegrity() {
        long totalHeapNodes = 0;
        for (int groupId = 0; groupId < groupIdToHeapBuffer.getTotalGroups(); groupId++) {
            long heapSize = groupIdToHeapBuffer.getHeapSize(groupId);
            long rootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
            Verify.verify(rootNodeIndex == UNKNOWN_INDEX || calculateRootRowNumber(groupId) <= topN, "Max heap has more values than needed");
            IntegrityStats integrityStats = verifyHeapIntegrity(rootNodeIndex);
            Verify.verify(integrityStats.getNodeCount() == heapSize, "Recorded heap size does not match actual heap size");
            totalHeapNodes += integrityStats.getNodeCount();
        }
        Verify.verify(totalHeapNodes == heapNodeBuffer.getActiveNodeCount(), "Failed to deallocate some unused nodes");
    }

    /**
     * Recursively verifies the subtree rooted at {@code heapNodeIndex}. Each parent must be at least as large as its
     * children, a right child must never exist without a left child, and sibling depths must differ by at most one.
     *
     * @return the subtree's depth and node count; an absent node yields (0, 0)
     */
    private IntegrityStats verifyHeapIntegrity(long heapNodeIndex) {
        if (heapNodeIndex == UNKNOWN_INDEX) {
            return new IntegrityStats(0, 0);
        }
        long rowId = heapNodeBuffer.getRowId(heapNodeIndex);
        long leftChildHeapIndex = heapNodeBuffer.getLeftChildHeapIndex(heapNodeIndex);
        long rightChildHeapIndex = heapNodeBuffer.getRightChildHeapIndex(heapNodeIndex);

        if (leftChildHeapIndex != UNKNOWN_INDEX) {
            Verify.verify(strategy.compare(rowId, heapNodeBuffer.getRowId(leftChildHeapIndex)) >= 0, "Max heap invariant violated");
        }
        if (rightChildHeapIndex != UNKNOWN_INDEX) {
            Verify.verify(leftChildHeapIndex != UNKNOWN_INDEX, "Left should always be inserted before right");
            Verify.verify(strategy.compare(rowId, heapNodeBuffer.getRowId(rightChildHeapIndex)) >= 0, "Max heap invariant violated");
        }

        IntegrityStats leftIntegrityStats = verifyHeapIntegrity(leftChildHeapIndex);
        IntegrityStats rightIntegrityStats = verifyHeapIntegrity(rightChildHeapIndex);

        Verify.verify(Math.abs(leftIntegrityStats.getMaxDepth() - rightIntegrityStats.getMaxDepth()) <= 1, "Heap not balanced");

        return new IntegrityStats(
                Math.max(leftIntegrityStats.getMaxDepth(), rightIntegrityStats.getMaxDepth()) + 1,
                leftIntegrityStats.getNodeCount() + rightIntegrityStats.getNodeCount() + 1);
    }
}
