package io.trino.operator.aggregation.state;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.SizeOf;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import io.trino.array.BlockBigArray;
import io.trino.array.BooleanBigArray;
import io.trino.array.ByteBigArray;
import io.trino.array.DoubleBigArray;
import io.trino.array.IntBigArray;
import io.trino.array.LongBigArray;
import io.trino.array.ReferenceCountMap;
import io.trino.array.SliceBigArray;
import io.trino.block.BlockAssertions;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.BlockBuilderStatus;
import io.trino.spi.block.MapBlockBuilder;
import io.trino.spi.block.RowBlockBuilder;
import io.trino.spi.block.VariableWidthBlockBuilder;
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AccumulatorStateFactory;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.function.GroupedAccumulatorState;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.MapType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarbinaryType;
import io.trino.spi.type.VarcharType;
import io.trino.util.Reflection;
import io.trino.util.StructuralTestUtil;
import org.testng.Assert;
import org.testng.annotations.Test;

import java.lang.reflect.Field;
import java.util.Map;

/* loaded from: input_file:io/trino/operator/aggregation/state/TestStateCompiler.class */
/**
 * Tests for {@link StateCompiler}.
 *
 * <p>Each serialization test builds a state with a generated factory and writes it with the
 * generated serializer. It then reads the result back into a fresh state and compares field by
 * field. The estimated-size test checks the size accounting of a generated grouped state against
 * a size computed by reflection.
 */
public class TestStateCompiler {

    /** Number of groups written (and later overwritten) by {@link #testComplexStateEstimatedSize()}. */
    private static final int GROUP_COUNT = 1000;

    /** Field types whose {@code sizeOf()} is summed by {@link #getComplexStateRetainedSize}. */
    private static final ImmutableList<Class<?>> BIG_ARRAY_TYPES = ImmutableList.of(
            BlockBigArray.class,
            BooleanBigArray.class,
            SliceBigArray.class,
            ByteBigArray.class,
            DoubleBigArray.class,
            LongBigArray.class,
            IntBigArray.class);

    /** State with a single boolean field, exposed through an {@code is}-prefixed getter. */
    public interface BooleanState extends AccumulatorState {
        boolean isBoolean();

        void setBoolean(boolean value);
    }

    /** State with a single byte field. */
    public interface ByteState extends AccumulatorState {
        byte getByte();

        void setByte(byte value);
    }

    /** State with a single (nullable) slice field. */
    public interface SliceState extends AccumulatorState {
        Slice getSlice();

        void setSlice(Slice value);
    }

    /**
     * State mixing primitive, slice and block fields.
     *
     * <p>The element types of the two block fields are supplied to {@link StateCompiler}
     * through a field-name to type map ({@code "Block"} and {@code "AnotherBlock"}).
     */
    public interface TestComplexState extends AccumulatorState {
        double getDouble();

        void setDouble(double value);

        boolean getBoolean();

        void setBoolean(boolean value);

        long getLong();

        void setLong(long value);

        byte getByte();

        void setByte(byte value);

        int getInt();

        void setInt(int value);

        Slice getSlice();

        void setSlice(Slice value);

        Slice getAnotherSlice();

        void setAnotherSlice(Slice value);

        Slice getYetAnotherSlice();

        void setYetAnotherSlice(Slice value);

        Block getBlock();

        void setBlock(Block value);

        Block getAnotherBlock();

        void setAnotherBlock(Block value);
    }

    @Test
    public void testPrimitiveNullableLongSerialization() {
        AccumulatorStateFactory<NullableLongState> factory = StateCompiler.generateStateFactory(NullableLongState.class);
        AccumulatorStateSerializer<NullableLongState> serializer = StateCompiler.generateStateSerializer(NullableLongState.class);
        NullableLongState state = factory.createSingleState();
        NullableLongState deserializedState = factory.createSingleState();

        state.setValue(2L);
        state.setNull(false);

        BlockBuilder builder = BigintType.BIGINT.createBlockBuilder(null, 2);
        serializer.serialize(state, builder);
        // the same state marked null must serialize as a SQL NULL at the next position
        state.setNull(true);
        serializer.serialize(state, builder);
        Block block = builder.build();

        Assert.assertFalse(block.isNull(0));
        Assert.assertEquals(BigintType.BIGINT.getLong(block, 0), state.getValue());
        serializer.deserialize(block, 0, deserializedState);
        Assert.assertEquals(deserializedState.getValue(), state.getValue());
        Assert.assertTrue(block.isNull(1));
    }

    @Test
    public void testPrimitiveLongSerialization() {
        AccumulatorStateFactory<LongState> factory = StateCompiler.generateStateFactory(LongState.class);
        AccumulatorStateSerializer<LongState> serializer = StateCompiler.generateStateSerializer(LongState.class);
        LongState state = factory.createSingleState();
        LongState deserializedState = factory.createSingleState();

        state.setValue(2L);

        BlockBuilder builder = BigintType.BIGINT.createBlockBuilder(null, 1);
        serializer.serialize(state, builder);
        Block block = builder.build();

        Assert.assertEquals(BigintType.BIGINT.getLong(block, 0), state.getValue());
        serializer.deserialize(block, 0, deserializedState);
        Assert.assertEquals(deserializedState.getValue(), state.getValue());
    }

    @Test
    public void testGetSerializedType() {
        AccumulatorStateSerializer<LongState> serializer = StateCompiler.generateStateSerializer(LongState.class);
        Assert.assertEquals(serializer.getSerializedType(), BigintType.BIGINT);
    }

    @Test
    public void testPrimitiveBooleanSerialization() {
        AccumulatorStateFactory<BooleanState> factory = StateCompiler.generateStateFactory(BooleanState.class);
        AccumulatorStateSerializer<BooleanState> serializer = StateCompiler.generateStateSerializer(BooleanState.class);
        BooleanState state = factory.createSingleState();
        BooleanState deserializedState = factory.createSingleState();

        state.setBoolean(true);

        BlockBuilder builder = BooleanType.BOOLEAN.createBlockBuilder(null, 1);
        serializer.serialize(state, builder);
        serializer.deserialize(builder.build(), 0, deserializedState);

        Assert.assertEquals(deserializedState.isBoolean(), state.isBoolean());
    }

    @Test
    public void testPrimitiveByteSerialization() {
        AccumulatorStateFactory<ByteState> factory = StateCompiler.generateStateFactory(ByteState.class);
        AccumulatorStateSerializer<ByteState> serializer = StateCompiler.generateStateSerializer(ByteState.class);
        ByteState state = factory.createSingleState();
        ByteState deserializedState = factory.createSingleState();

        state.setByte((byte) 3);

        BlockBuilder builder = TinyintType.TINYINT.createBlockBuilder(null, 1);
        serializer.serialize(state, builder);
        serializer.deserialize(builder.build(), 0, deserializedState);

        Assert.assertEquals(deserializedState.getByte(), state.getByte());
    }

    @Test
    public void testNonPrimitiveSerialization() {
        AccumulatorStateFactory<SliceState> factory = StateCompiler.generateStateFactory(SliceState.class);
        AccumulatorStateSerializer<SliceState> serializer = StateCompiler.generateStateSerializer(SliceState.class);
        SliceState state = factory.createSingleState();
        SliceState deserializedState = factory.createSingleState();

        // null slice round trip
        state.setSlice(null);
        VariableWidthBlockBuilder nullBuilder = VarcharType.VARCHAR.createBlockBuilder(null, 1);
        serializer.serialize(state, nullBuilder);
        serializer.deserialize(nullBuilder.build(), 0, deserializedState);
        Assert.assertEquals(deserializedState.getSlice(), state.getSlice());

        // non-null slice round trip
        state.setSlice(Slices.utf8Slice("test"));
        VariableWidthBlockBuilder valueBuilder = VarcharType.VARCHAR.createBlockBuilder(null, 1);
        serializer.serialize(state, valueBuilder);
        serializer.deserialize(valueBuilder.build(), 0, deserializedState);
        Assert.assertEquals(deserializedState.getSlice(), state.getSlice());
    }

    @Test
    public void testVarianceStateSerialization() {
        AccumulatorStateFactory<VarianceState> factory = StateCompiler.generateStateFactory(VarianceState.class);
        AccumulatorStateSerializer<VarianceState> serializer = StateCompiler.generateStateSerializer(VarianceState.class);
        VarianceState state = factory.createSingleState();
        VarianceState deserializedState = factory.createSingleState();

        state.setMean(1.0d);
        state.setCount(2L);
        state.setM2(3.0d);

        RowBlockBuilder builder = RowType.anonymous(ImmutableList.of(BigintType.BIGINT, DoubleType.DOUBLE, DoubleType.DOUBLE))
                .createBlockBuilder(null, 1);
        serializer.serialize(state, builder);
        serializer.deserialize(builder.build(), 0, deserializedState);

        Assert.assertEquals(deserializedState.getCount(), state.getCount());
        Assert.assertEquals(Double.valueOf(deserializedState.getMean()), Double.valueOf(state.getMean()));
        Assert.assertEquals(Double.valueOf(deserializedState.getM2()), Double.valueOf(state.getM2()));
    }

    @Test
    public void testComplexSerialization() {
        ArrayType arrayType = new ArrayType(BigintType.BIGINT);
        MapType mapType = StructuralTestUtil.mapType(BigintType.BIGINT, VarcharType.VARCHAR);
        Map<String, Type> fieldTypes = ImmutableMap.of("Block", arrayType, "AnotherBlock", mapType);
        AccumulatorStateFactory<TestComplexState> factory = StateCompiler.generateStateFactory(TestComplexState.class, fieldTypes);
        AccumulatorStateSerializer<TestComplexState> serializer = StateCompiler.generateStateSerializer(TestComplexState.class, fieldTypes);
        TestComplexState state = factory.createSingleState();
        TestComplexState deserializedState = factory.createSingleState();

        state.setBoolean(true);
        state.setLong(1L);
        state.setDouble(2.0d);
        state.setByte((byte) 3);
        state.setInt(4);
        state.setSlice(Slices.utf8Slice("test"));
        state.setAnotherSlice(toSlice(1.0d, 2.0d, 3.0d));
        state.setYetAnotherSlice(null);
        state.setBlock(BlockAssertions.createLongsBlock(45));
        state.setAnotherBlock(StructuralTestUtil.mapBlockOf(BigintType.BIGINT, VarcharType.VARCHAR, ImmutableMap.of(123L, "testBlock")));

        // row fields listed in the order AnotherBlock, AnotherSlice, Block, Boolean, Byte, Double,
        // Int, Long, Slice, YetAnotherSlice (alphabetical by field name)
        RowBlockBuilder builder = RowType.anonymous(ImmutableList.of(
                        mapType,
                        VarbinaryType.VARBINARY,
                        arrayType,
                        BooleanType.BOOLEAN,
                        TinyintType.TINYINT,
                        DoubleType.DOUBLE,
                        IntegerType.INTEGER,
                        BigintType.BIGINT,
                        VarbinaryType.VARBINARY,
                        VarbinaryType.VARBINARY))
                .createBlockBuilder(null, 1);
        serializer.serialize(state, builder);
        serializer.deserialize(builder.build(), 0, deserializedState);

        Assert.assertEquals(deserializedState.getBoolean(), state.getBoolean());
        Assert.assertEquals(deserializedState.getLong(), state.getLong());
        Assert.assertEquals(Double.valueOf(deserializedState.getDouble()), Double.valueOf(state.getDouble()));
        Assert.assertEquals(deserializedState.getByte(), state.getByte());
        Assert.assertEquals(deserializedState.getInt(), state.getInt());
        Assert.assertEquals(deserializedState.getSlice(), state.getSlice());
        Assert.assertEquals(deserializedState.getAnotherSlice(), state.getAnotherSlice());
        Assert.assertEquals(deserializedState.getYetAnotherSlice(), state.getYetAnotherSlice());
        Assert.assertEquals(deserializedState.getBlock().getLong(0, 0), state.getBlock().getLong(0, 0));
        Assert.assertEquals(deserializedState.getAnotherBlock().getLong(0, 0), state.getAnotherBlock().getLong(0, 0));
        Assert.assertEquals(deserializedState.getAnotherBlock().getSlice(1, 0, 9), state.getAnotherBlock().getSlice(1, 0, 9));
    }

    /**
     * Computes the retained size of a generated state by reflection.
     *
     * <p>The result is the shallow instance size plus {@code sizeOf()} of every declared field
     * whose type is in {@link #BIG_ARRAY_TYPES}.
     */
    private static long getComplexStateRetainedSize(TestComplexState state) {
        long retainedSize = SizeOf.instanceSize(state.getClass());
        try {
            for (Field field : state.getClass().getDeclaredFields()) {
                Class<?> type = field.getType();
                field.setAccessible(true);
                if (BIG_ARRAY_TYPES.contains(type)) {
                    retainedSize += (Long) Reflection.methodHandle(type, "sizeOf").invokeWithArguments(field.get(state));
                }
            }
            return retainedSize;
        } catch (Throwable t) {
            throw new RuntimeException(t);
        }
    }

    /**
     * Sums {@code sizeOf()} of every {@link ReferenceCountMap} field nested inside the state's
     * {@link BlockBigArray} and {@link SliceBigArray} fields, read by reflection.
     */
    private static long getReferenceCountMapOverhead(TestComplexState state) {
        long overhead = 0;
        try {
            for (Field field : state.getClass().getDeclaredFields()) {
                Class<?> type = field.getType();
                if (type != BlockBigArray.class && type != SliceBigArray.class) {
                    continue;
                }
                field.setAccessible(true);
                Object bigArray = field.get(state);
                for (Field innerField : type.getDeclaredFields()) {
                    if (innerField.getType() == ReferenceCountMap.class) {
                        innerField.setAccessible(true);
                        overhead += (Long) Reflection.methodHandle(ReferenceCountMap.class, "sizeOf").invokeWithArguments(innerField.get(bigArray));
                    }
                }
            }
            return overhead;
        } catch (Throwable t) {
            throw new RuntimeException(t);
        }
    }

    @Test(invocationCount = 100, successPercentage = 90)
    public void testComplexStateEstimatedSize() {
        Map<String, Type> fieldTypes = ImmutableMap.of(
                "Block", new ArrayType(BigintType.BIGINT),
                "AnotherBlock", StructuralTestUtil.mapType(BigintType.BIGINT, VarcharType.VARCHAR));
        AccumulatorStateFactory<TestComplexState> factory = StateCompiler.generateStateFactory(TestComplexState.class, fieldTypes);
        TestComplexState groupedState = factory.createGroupedState();

        long initialRetainedSize = getComplexStateRetainedSize(groupedState);
        Assert.assertEquals(groupedState.getEstimatedSize(), initialRetainedSize);
        // The reference-count maps inside the block/slice big arrays are measured separately on
        // every check. They are therefore excluded from the fixed baseline.
        long baselineSize = initialRetainedSize - getReferenceCountMapOverhead(groupedState);

        // first pass: every new group adds one more group's worth of values
        for (int groupId = 0; groupId < GROUP_COUNT; groupId++) {
            long groupSize = populateGroup(groupedState, groupId);
            Assert.assertEquals(
                    groupedState.getEstimatedSize(),
                    baselineSize + groupSize * (groupId + 1) + getReferenceCountMapOverhead(groupedState));
        }

        // second pass: overwriting existing groups is expected to keep the total at GROUP_COUNT groups
        for (int groupId = 0; groupId < GROUP_COUNT; groupId++) {
            long groupSize = populateGroup(groupedState, groupId);
            Assert.assertEquals(
                    groupedState.getEstimatedSize(),
                    baselineSize + groupSize * GROUP_COUNT + getReferenceCountMapOverhead(groupedState));
        }
    }

    /**
     * Selects {@code groupId} on a grouped state and sets every field to the same fixed values.
     *
     * @return the summed retained size of the slice and block values stored for the group
     */
    private static long populateGroup(TestComplexState state, int groupId) {
        ((GroupedAccumulatorState) state).setGroupId(groupId);
        state.setBoolean(true);
        state.setLong(1L);
        state.setDouble(2.0d);
        state.setByte((byte) 3);
        state.setInt(4);

        long retainedSize = 0;

        Slice slice = Slices.utf8Slice("test");
        retainedSize += slice.getRetainedSize();
        state.setSlice(slice);

        Slice anotherSlice = toSlice(1.0d, 2.0d, 3.0d);
        retainedSize += anotherSlice.getRetainedSize();
        state.setAnotherSlice(anotherSlice);

        state.setYetAnotherSlice(null);

        Block longsBlock = BlockAssertions.createLongsBlock(45);
        retainedSize += longsBlock.getRetainedSizeInBytes();
        state.setBlock(longsBlock);

        MapBlockBuilder mapBuilder = StructuralTestUtil.mapType(BigintType.BIGINT, VarcharType.VARCHAR).createBlockBuilder(null, 1);
        mapBuilder.buildEntry((keyBuilder, valueBuilder) -> {
            BigintType.BIGINT.writeLong(keyBuilder, 123L);
            VarcharType.VARCHAR.writeSlice(valueBuilder, Slices.utf8Slice("testBlock"));
        });
        Block mapBlock = mapBuilder.build();
        retainedSize += mapBlock.getRetainedSizeInBytes();
        state.setAnotherBlock(mapBlock);

        return retainedSize;
    }

    /** Packs the given doubles, in order, into a freshly allocated slice of {@code 8 * values.length} bytes. */
    private static Slice toSlice(double... values) {
        Slice slice = Slices.allocate(values.length * 8);
        SliceOutput output = slice.getOutput();
        for (double value : values) {
            output.writeDouble(value);
        }
        return slice;
    }
}
