package io.trino.operator.aggregation;

import io.trino.spi.block.ByteArrayBlock;
import io.trino.spi.block.IntArrayBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/operator/aggregation/TestAggregationMask.class */
/**
 * Tests for {@link AggregationMask}.
 *
 * <p>Covers three operations, each exercised on flat blocks and on run-length-encoded blocks:
 * <ul>
 *   <li>unselecting null positions of an input block;</li>
 *   <li>applying a boolean (byte) mask block;</li>
 *   <li>applying a mask block whose entries may themselves be null.</li>
 * </ul>
 */
public class TestAggregationMask {
    @Test
    public void testUnsetNulls() {
        AggregationMask mask = AggregationMask.createSelectAll(0);
        assertAggregationMaskAll(mask, 0);

        for (int positionCount = 7; positionCount < 10; positionCount++) {
            mask.reset(positionCount);
            assertAggregationMaskAll(mask, positionCount);

            // no null vector: nothing is unselected
            mask.unselectNullPositions(intBlock(positionCount, Optional.empty()));
            assertAggregationMaskAll(mask, positionCount);

            // null vector present but all false: nothing is unselected
            boolean[] isNull = new boolean[positionCount];
            mask.unselectNullPositions(intBlock(positionCount, Optional.of(isNull)));
            assertAggregationMaskAll(mask, positionCount);

            // only positions 1, 3 and 5 are non-null
            Arrays.fill(isNull, true);
            isNull[1] = false;
            isNull[3] = false;
            isNull[5] = false;
            mask.unselectNullPositions(intBlock(positionCount, Optional.of(isNull)));
            assertAggregationMaskPositions(mask, positionCount, 1, 3, 5);

            // position 3 becomes null as well
            isNull[3] = true;
            mask.unselectNullPositions(intBlock(positionCount, Optional.of(isNull)));
            assertAggregationMaskPositions(mask, positionCount, 1, 5);

            // every position is null
            isNull[1] = true;
            isNull[5] = true;
            mask.unselectNullPositions(intBlock(positionCount, Optional.of(isNull)));
            assertAggregationMaskPositions(mask, positionCount);

            // run-length-encoded input: the single underlying value decides for every position
            mask.reset(positionCount);
            assertAggregationMaskAll(mask, positionCount);
            mask.unselectNullPositions(RunLengthEncodedBlock.create(intBlock(1, Optional.empty()), positionCount));
            assertAggregationMaskAll(mask, positionCount);
            mask.unselectNullPositions(RunLengthEncodedBlock.create(intBlock(1, Optional.of(new boolean[] {false})), positionCount));
            assertAggregationMaskAll(mask, positionCount);
            mask.unselectNullPositions(RunLengthEncodedBlock.create(intBlock(1, Optional.of(new boolean[] {true})), positionCount));
            assertAggregationMaskPositions(mask, positionCount);
        }
    }

    @Test
    public void testApplyMask() {
        AggregationMask mask = AggregationMask.createSelectAll(0);
        assertAggregationMaskAll(mask, 0);

        for (int positionCount = 7; positionCount < 10; positionCount++) {
            mask.reset(positionCount);
            assertAggregationMaskAll(mask, positionCount);

            // all-true mask: every position stays selected
            byte[] maskValues = new byte[positionCount];
            Arrays.fill(maskValues, (byte) 1);
            mask.applyMaskBlock(byteBlock(Optional.empty(), maskValues));
            assertAggregationMaskAll(mask, positionCount);

            // only positions 1, 3 and 5 are true
            Arrays.fill(maskValues, (byte) 0);
            maskValues[1] = 1;
            maskValues[3] = 1;
            maskValues[5] = 1;
            mask.applyMaskBlock(byteBlock(Optional.empty(), maskValues));
            assertAggregationMaskPositions(mask, positionCount, 1, 3, 5);

            // position 3 becomes false as well
            maskValues[3] = 0;
            mask.applyMaskBlock(byteBlock(Optional.empty(), maskValues));
            assertAggregationMaskPositions(mask, positionCount, 1, 5);

            // every position is false
            maskValues[1] = 0;
            maskValues[5] = 0;
            mask.applyMaskBlock(byteBlock(Optional.empty(), maskValues));
            assertAggregationMaskPositions(mask, positionCount);

            // run-length-encoded mask: the single underlying value decides for every position
            mask.reset(positionCount);
            assertAggregationMaskAll(mask, positionCount);
            mask.applyMaskBlock(RunLengthEncodedBlock.create(byteBlock(Optional.empty(), new byte[] {1}), positionCount));
            assertAggregationMaskAll(mask, positionCount);
            mask.applyMaskBlock(RunLengthEncodedBlock.create(byteBlock(Optional.empty(), new byte[] {0}), positionCount));
            assertAggregationMaskPositions(mask, positionCount);
        }
    }

    @Test
    public void testApplyMaskNulls() {
        AggregationMask mask = AggregationMask.createSelectAll(0);
        assertAggregationMaskAll(mask, 0);

        for (int positionCount = 7; positionCount < 10; positionCount++) {
            mask.reset(positionCount);
            assertAggregationMaskAll(mask, positionCount);

            // all-true mask without a null vector: every position stays selected
            byte[] maskValues = new byte[positionCount];
            Arrays.fill(maskValues, (byte) 1);
            mask.applyMaskBlock(byteBlock(Optional.empty(), maskValues));
            assertAggregationMaskAll(mask, positionCount);

            // null vector present but all false: every position stays selected
            boolean[] isNull = new boolean[positionCount];
            mask.applyMaskBlock(byteBlock(Optional.of(isNull), maskValues));
            assertAggregationMaskAll(mask, positionCount);

            // mask values are all true, but only positions 1, 3 and 5 are non-null
            Arrays.fill(isNull, true);
            isNull[1] = false;
            isNull[3] = false;
            isNull[5] = false;
            mask.applyMaskBlock(byteBlock(Optional.of(isNull), maskValues));
            assertAggregationMaskPositions(mask, positionCount, 1, 3, 5);

            // position 3 becomes null as well
            isNull[3] = true;
            mask.applyMaskBlock(byteBlock(Optional.of(isNull), maskValues));
            assertAggregationMaskPositions(mask, positionCount, 1, 5);

            // every mask entry is null
            isNull[1] = true;
            isNull[5] = true;
            mask.applyMaskBlock(byteBlock(Optional.of(isNull), maskValues));
            assertAggregationMaskPositions(mask, positionCount);

            // run-length-encoded mask: a single true value, first non-null and then null
            mask.reset(positionCount);
            assertAggregationMaskAll(mask, positionCount);
            mask.applyMaskBlock(RunLengthEncodedBlock.create(byteBlock(Optional.empty(), new byte[] {1}), positionCount));
            assertAggregationMaskAll(mask, positionCount);
            mask.applyMaskBlock(RunLengthEncodedBlock.create(byteBlock(Optional.of(new boolean[] {false}), new byte[] {1}), positionCount));
            assertAggregationMaskAll(mask, positionCount);
            mask.applyMaskBlock(RunLengthEncodedBlock.create(byteBlock(Optional.of(new boolean[] {true}), new byte[] {1}), positionCount));
            assertAggregationMaskPositions(mask, positionCount);
        }
    }

    /**
     * Creates an {@link IntArrayBlock} of {@code positionCount} zero values.
     *
     * @param positionCount number of positions in the block
     * @param valueIsNull optional null vector, passed through unchanged to the block constructor
     */
    private static IntArrayBlock intBlock(int positionCount, Optional<boolean[]> valueIsNull) {
        return new IntArrayBlock(positionCount, valueIsNull, new int[positionCount]);
    }

    /**
     * Creates a {@link ByteArrayBlock} with one position per element of {@code values}.
     *
     * @param valueIsNull optional null vector, passed through unchanged to the block constructor
     * @param values the byte values of the block; its length is the position count
     */
    private static ByteArrayBlock byteBlock(Optional<boolean[]> valueIsNull, byte[] values) {
        return new ByteArrayBlock(values.length, valueIsNull, values);
    }

    /**
     * Asserts that {@code mask} is in "select all" state over {@code positionCount} positions.
     *
     * <p>In this state the mask reports every position as selected. It reports "select none" only
     * when there are zero positions. Asking it for an explicit position list must fail with
     * {@link IllegalStateException}.
     */
    private static void assertAggregationMaskAll(AggregationMask mask, int positionCount) {
        Assertions.assertThat(mask.isSelectAll()).isTrue();
        Assertions.assertThat(mask.isSelectNone()).isEqualTo(positionCount == 0);
        Assertions.assertThat(mask.getPositionCount()).isEqualTo(positionCount);
        Assertions.assertThat(mask.getSelectedPositionCount()).isEqualTo(positionCount);
        Assertions.assertThatThrownBy(mask::getSelectedPositions).isInstanceOf(IllegalStateException.class);
    }

    /**
     * Asserts that {@code mask} selects exactly {@code expectedPositions} out of {@code positionCount}.
     *
     * <p>Only a prefix of {@link AggregationMask#getSelectedPositions()} is compared. The returned
     * array may be longer than the selected count, and only the first
     * {@code getSelectedPositionCount()} entries are meaningful.
     */
    private static void assertAggregationMaskPositions(AggregationMask mask, int positionCount, int... expectedPositions) {
        Assertions.assertThat(mask.isSelectAll()).isFalse();
        Assertions.assertThat(mask.isSelectNone()).isEqualTo(expectedPositions.length == 0);
        Assertions.assertThat(mask.getPositionCount()).isEqualTo(positionCount);
        Assertions.assertThat(mask.getSelectedPositionCount()).isEqualTo(expectedPositions.length);
        if (expectedPositions.length > 0) {
            Assertions.assertThat(mask.getSelectedPositions()).startsWith(expectedPositions);
        }
    }
}
