package io.trino.operator;

import com.google.common.collect.ImmutableList;
import io.airlift.concurrent.Threads;
import io.airlift.stats.TestingGcMonitor;
import io.airlift.testing.Assertions;
import io.airlift.units.DataSize;
import io.trino.RowPagesBuilder;
import io.trino.SessionTestUtils;
import io.trino.execution.StageId;
import io.trino.execution.TaskId;
import io.trino.memory.MemoryPool;
import io.trino.memory.QueryContext;
import io.trino.spi.Page;
import io.trino.spi.QueryId;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.spiller.SpillSpaceTracker;
import io.trino.testing.TestingTaskContext;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Function;

/* loaded from: input_file:io/trino/operator/GroupByHashYieldAssertion.class */
/**
 * Test utility that feeds pages to an operator while another task holds almost all of the
 * memory pool, and asserts how the operator's memory usage and hash capacity change when
 * it can or cannot obtain the memory it needs.
 *
 * <p>This class only exposes static helpers and cannot be instantiated.
 */
public final class GroupByHashYieldAssertion {
    private static final ExecutorService EXECUTOR = Executors.newCachedThreadPool(Threads.daemonThreadsNamed("GroupByHashYieldAssertion-%s"));
    private static final ScheduledExecutorService SCHEDULED_EXECUTOR = Executors.newScheduledThreadPool(2, Threads.daemonThreadsNamed("GroupByHashYieldAssertion-scheduledExecutor-%s"));

    /** Exclusive upper bound accepted for {@code additionalMemoryInBytes} (2 MiB). */
    private static final long ADDITIONAL_MEMORY_LIMIT_BYTES = 1L << 21;

    /** Tag passed to every reserve/free call this helper makes against the memory pool. */
    private static final String ALLOCATION_TAG = "test";

    /** Largest accepted ratio between the measured and the expected memory usage after a rehash. */
    private static final double MEMORY_USAGE_ERROR_UPPER_BOUND = 1.01d;

    /**
     * Size of the scratch buffer handed to {@link VariableWidthData#allocate}.
     * NOTE(review): presumably mirrors the pointer size defined by VariableWidthData - confirm.
     */
    private static final int VARIABLE_WIDTH_POINTER_SIZE = 12;

    /**
     * Immutable summary of one {@link #finishOperatorWithYieldingGroupByHash} run.
     */
    public static final class GroupByHashYieldResult {
        private final int yieldCount;
        private final long maxReservedBytes;
        private final List<Page> output;

        /**
         * @param yieldCount number of input pages after which the operator stopped accepting input
         * @param maxReservedBytes largest driver memory usage observed right after adding a page
         * @param output pages produced by the operator; must not be null
         * @throws NullPointerException if {@code output} is null
         */
        public GroupByHashYieldResult(int yieldCount, long maxReservedBytes, List<Page> output) {
            this.yieldCount = yieldCount;
            this.maxReservedBytes = maxReservedBytes;
            this.output = Objects.requireNonNull(output, "output is null");
        }

        /** @return the number of times the operator yielded */
        public int getYieldCount() {
            return this.yieldCount;
        }

        /** @return the maximum driver memory usage observed, in bytes */
        public long getMaxReservedBytes() {
            return this.maxReservedBytes;
        }

        /** @return the pages produced by the operator (the list passed to the constructor, not a copy) */
        public List<Page> getOutput() {
            return this.output;
        }
    }

    private GroupByHashYieldAssertion() {
    }

    /**
     * Builds {@code pageCount} sequence pages of a single column of the given type, with
     * hashing enabled on channel 0.
     *
     * <p>Page {@code k} is added with initial value {@code positionCountPerPage * k}, so
     * consecutive pages start where the previous one's position count ends.
     *
     * @param type type of the single column
     * @param pageCount number of pages to build
     * @param positionCountPerPage number of positions in each page
     * @return the built pages
     */
    public static List<Page> createPagesWithDistinctHashKeys(Type type, int pageCount, int positionCountPerPage) {
        RowPagesBuilder builder = RowPagesBuilder.rowPagesBuilder(true, ImmutableList.of(0), type);
        for (int pageIndex = 0; pageIndex < pageCount; pageIndex++) {
            builder.addSequencePage(positionCountPerPage, positionCountPerPage * pageIndex);
        }
        return builder.build();
    }

    /**
     * Feeds every input page to a freshly created operator while a second task reserves all
     * but {@code additionalMemoryInBytes} (plus the page's variable-width data size) of the
     * memory pool, asserting the operator's behaviour after each page, and finally finishes
     * the operator.
     *
     * <p>For each page whose resulting driver memory usage is at least 4 MB:
     * <ul>
     *   <li>if the operator still accepts input, it must not be waiting for memory, its hash
     *       capacity must be unchanged, and its memory growth must stay below
     *       {@code additionalMemoryInBytes};</li>
     *   <li>otherwise the page counts as a yield: the operator must be waiting for memory with
     *       unchanged capacity, its memory growth must match the expected hash table size, and
     *       once the pool is freed it must accept input again with a larger capacity and a
     *       memory usage within the accepted error bounds.</li>
     * </ul>
     *
     * @param input pages to feed to the operator, one at a time
     * @param hashKeyType type of the hash key; compared by identity against
     *        {@code VarcharType.VARCHAR} and {@code BigintType.BIGINT}
     * @param operatorFactory factory creating the operator under test
     * @param getHashCapacity reads the current hash capacity from the operator
     * @param additionalMemoryInBytes memory left free in the pool for the operator; must be
     *        smaller than 2 MiB
     * @return the yield count, the maximum observed driver memory usage and all produced pages
     */
    public static GroupByHashYieldResult finishOperatorWithYieldingGroupByHash(List<Page> input, Type hashKeyType, OperatorFactory operatorFactory, Function<Operator, Integer> getHashCapacity, long additionalMemoryInBytes) {
        Assertions.assertLessThan(additionalMemoryInBytes, ADDITIONAL_MEMORY_LIMIT_BYTES, "additionalMemoryInBytes should be a relatively small number");
        List<Page> result = new LinkedList<>();

        // The pool is shared with a second task whose reservations are used to squeeze the operator.
        QueryId queryId = new QueryId("test_query");
        TaskId anotherTaskId = new TaskId(new StageId("another_query", 0), 0, 0);
        MemoryPool memoryPool = new MemoryPool(DataSize.of(1L, DataSize.Unit.GIGABYTE));
        QueryContext queryContext = new QueryContext(
                queryId,
                DataSize.of(512L, DataSize.Unit.MEGABYTE),
                memoryPool,
                new TestingGcMonitor(),
                EXECUTOR,
                SCHEDULED_EXECUTOR,
                DataSize.of(512L, DataSize.Unit.MEGABYTE),
                new SpillSpaceTracker(DataSize.of(512L, DataSize.Unit.MEGABYTE)));
        DriverContext driverContext = TestingTaskContext.createTaskContext(queryContext, EXECUTOR, SessionTestUtils.TEST_SESSION)
                .addPipelineContext(0, true, true, false)
                .addDriverContext();
        Operator operator = operatorFactory.createOperator(driverContext);

        byte[] pointer = new byte[VARIABLE_WIDTH_POINTER_SIZE];
        VariableWidthData variableWidthData = new VariableWidthData();
        long minMemoryUsageToVerify = DataSize.of(4L, DataSize.Unit.MEGABYTE).toBytes();

        int yieldCount = 0;
        long maxReservedBytes = 0;
        for (Page page : input) {
            // For VARCHAR keys, replay the page's slice lengths through a local allocator to
            // measure how much variable-width memory this page accounts for.
            long pageVariableWidthSize = 0;
            if (hashKeyType == VarcharType.VARCHAR) {
                long sizeBeforePage = variableWidthData.getRetainedSizeBytes();
                for (int position = 0; position < page.getPositionCount(); position++) {
                    variableWidthData.allocate(pointer, 0, page.getBlock(0).getSliceLength(position));
                }
                pageVariableWidthSize = variableWidthData.getRetainedSizeBytes() - sizeBeforePage;
            }

            org.assertj.core.api.Assertions.assertThat(operator.needsInput()).isTrue();

            // Reserve everything except the allowance and the variable-width data for this page.
            memoryPool.reserve(anotherTaskId, ALLOCATION_TAG, (memoryPool.getFreeBytes() - additionalMemoryInBytes) - pageVariableWidthSize);

            long oldMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
            int oldCapacity = getHashCapacity.apply(operator).intValue();

            operator.addInput(page);
            collectOutput(operator, result);

            long newMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
            maxReservedBytes = Math.max(maxReservedBytes, newMemoryUsage);

            if (newMemoryUsage < minMemoryUsageToVerify) {
                // Below the threshold nothing is asserted; just release the pool for the next page.
                freeAllReserved(memoryPool, anotherTaskId);
                collectOutput(operator, result);
                continue;
            }

            long actualHashIncreased = (newMemoryUsage - oldMemoryUsage) - pageVariableWidthSize;

            if (operator.needsInput()) {
                // The page was absorbed without yielding: no blocking, no capacity change.
                org.assertj.core.api.Assertions.assertThat(operator.getOperatorContext().isWaitingForMemory().isDone()).isTrue();
                org.assertj.core.api.Assertions.assertThat(getHashCapacity.apply(operator).intValue()).isEqualTo(oldCapacity);
                Assertions.assertLessThan(actualHashIncreased, additionalMemoryInBytes);
                freeAllReserved(memoryPool, anotherTaskId);
                continue;
            }

            // The operator stopped accepting input, i.e. it yielded while waiting for memory.
            yieldCount++;
            org.assertj.core.api.Assertions.assertThat(operator.getOperatorContext().isWaitingForMemory().isDone()).isFalse();
            org.assertj.core.api.Assertions.assertThat(oldCapacity).isEqualTo(getHashCapacity.apply(operator).intValue());

            long expectedHashBytes;
            if (hashKeyType == BigintType.BIGINT) {
                expectedHashBytes = getHashTableSizeInBytes(hashKeyType, oldCapacity * 2);
            } else {
                expectedHashBytes = getHashTableSizeInBytes(hashKeyType, oldCapacity) + oldCapacity;
            }
            Assertions.assertBetweenInclusive(actualHashIncreased, expectedHashBytes, expectedHashBytes + additionalMemoryInBytes);

            // While blocked the operator must not produce output.
            org.assertj.core.api.Assertions.assertThat(operator.getOutput()).isNull();

            // Release the pool and poke the operator so it can continue.
            freeAllReserved(memoryPool, anotherTaskId);
            collectOutput(operator, result);
            org.assertj.core.api.Assertions.assertThat(operator.needsInput()).isTrue();
            Assertions.assertGreaterThan(getHashCapacity.apply(operator), oldCapacity);

            long rehashedMemoryUsage = operator.getOperatorContext().getDriverContext().getMemoryUsage();
            long expectedMemoryUsageAfterRehash = oldMemoryUsage + getHashTableSizeInBytes(hashKeyType, oldCapacity);
            double memoryUsageError = (rehashedMemoryUsage * 1.0d) / expectedMemoryUsageAfterRehash;
            if (memoryUsageError > MEMORY_USAGE_ERROR_UPPER_BOUND) {
                // Retry with the allowance added to the expectation and a looser lower bound.
                Assertions.assertBetweenInclusive(
                        (rehashedMemoryUsage * 1.0d) / (expectedMemoryUsageAfterRehash + additionalMemoryInBytes),
                        0.97d,
                        MEMORY_USAGE_ERROR_UPPER_BOUND,
                        "rehashedMemoryUsage " + rehashedMemoryUsage + ", expectedMemoryUsageAfterRehash: " + expectedMemoryUsageAfterRehash);
            } else {
                Assertions.assertBetweenInclusive(memoryUsageError, 0.99d, MEMORY_USAGE_ERROR_UPPER_BOUND);
            }

            org.assertj.core.api.Assertions.assertThat(operator.needsInput()).isTrue();
            org.assertj.core.api.Assertions.assertThat(operator.getOperatorContext().isWaitingForMemory().isDone()).isTrue();
        }

        result.addAll(OperatorAssertion.finishOperator(operator));
        return new GroupByHashYieldResult(yieldCount, maxReservedBytes, result);
    }

    /** Polls the operator once and appends the page to {@code sink} if one was produced. */
    private static void collectOutput(Operator operator, List<Page> sink) {
        Page output = operator.getOutput();
        if (output != null) {
            sink.add(output);
        }
    }

    /** Frees everything the pool currently reports as reserved by {@code taskId}. */
    private static void freeAllReserved(MemoryPool memoryPool, TaskId taskId) {
        long reservedBytes = memoryPool.getTaskMemoryReservations().get(taskId);
        memoryPool.free(taskId, ALLOCATION_TAG, reservedBytes);
    }

    /**
     * Returns the modelled hash table size for the given capacity: 18 bytes per entry for
     * {@code BigintType.BIGINT} keys and 46 bytes per entry for every other key type.
     *
     * <p>The multiplication is carried out in {@code long} so large capacities do not
     * overflow {@code int} before being widened.
     */
    private static long getHashTableSizeInBytes(Type hashKeyType, int capacity) {
        long bytesPerEntry = hashKeyType == BigintType.BIGINT ? 18L : 46L;
        return capacity * bytesPerEntry;
    }
}
