/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.hive;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import io.trino.block.BlockAssertions;
import io.trino.plugin.hive.HiveBucketFunction;
import io.trino.plugin.hive.HivePartitionedBucketFunction;
import io.trino.plugin.hive.HiveType;
import io.trino.plugin.hive.util.HiveBucketing;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.connector.BucketFunction;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/**
 * Tests for {@link HivePartitionedBucketFunction}.
 *
 * <p>Each test builds a page whose first channel is a BIGINT bucketing column and whose second
 * channel is a BIGINT partition column. It then checks how positions are grouped by
 * {@link HivePartitionedBucketFunction#getBucket}. The expectations are:
 * <ul>
 *   <li>positions that share both a Hive bucket (as computed by {@link HiveBucketFunction}) and a
 *       partition value map to exactly one bucket;</li>
 *   <li>the whole page spans one bucket per (Hive bucket, partition value) pair.</li>
 * </ul>
 */
public class TestHivePartitionedBucketFunction {
    /** Partition value shared by every row in {@link #testSinglePartition}. */
    private static final int SINGLE_PARTITION_VALUE = 78758;

    /** Multiplier used to derive distinct partition values in {@link #testMultiplePartitions}. */
    private static final long PARTITION_VALUE_STRIDE = 348349L;

    @DataProvider(name = "hiveBucketingVersion")
    public static Object[][] hiveBucketingVersion() {
        return new Object[][] {
                {HiveBucketing.BucketingVersion.BUCKETING_V1},
                {HiveBucketing.BucketingVersion.BUCKETING_V2},
        };
    }

    /**
     * All rows share one partition value. Positions in the same Hive bucket must map to one
     * bucket, and the page as a whole must span exactly {@code numBuckets} buckets.
     */
    @Test(dataProvider = "hiveBucketingVersion")
    public void testSinglePartition(HiveBucketing.BucketingVersion hiveBucketingVersion) {
        int numValues = 1024;
        int numBuckets = 10;
        int bucketCount = 100;

        Block bucketColumn = createLongSequenceBlockWithNull(numValues);
        Page bucketedColumnPage = new Page(bucketColumn);
        Block partitionColumn = BlockAssertions.createLongRepeatBlock(SINGLE_PARTITION_VALUE, numValues);
        Page page = new Page(bucketColumn, partitionColumn);

        // Reference grouping: Hive bucket -> positions, computed by the plain Hive bucket function.
        BucketFunction hiveBucketFunction = bucketFunction(hiveBucketingVersion, numBuckets, ImmutableList.of(HiveType.HIVE_LONG));
        HashMultimap<Integer, Integer> bucketPositions = HashMultimap.create();
        for (int position = 0; position < numValues; position++) {
            bucketPositions.put(hiveBucketFunction.getBucket(bucketedColumnPage, position), position);
        }

        BucketFunction hivePartitionedBucketFunction = partitionedBucketFunction(
                hiveBucketingVersion,
                numBuckets,
                ImmutableList.of(HiveType.HIVE_LONG),
                ImmutableList.of(BigintType.BIGINT),
                bucketCount);

        for (Collection<Integer> positions : bucketPositions.asMap().values()) {
            assertBucketCount(hivePartitionedBucketFunction, page, positions, 1);
        }
        assertBucketCount(hivePartitionedBucketFunction, page, allPositions(numValues), numBuckets);
    }

    /**
     * Rows are split evenly across {@code numPartitions} partition values; the last partition is
     * NULL. Positions sharing a (partition, Hive bucket) pair must map to one bucket, and the page
     * must span {@code numBuckets * numPartitions} buckets.
     */
    @Test(dataProvider = "hiveBucketingVersion")
    public void testMultiplePartitions(HiveBucketing.BucketingVersion hiveBucketingVersion) {
        int numValues = 1024;
        int numBuckets = 10;
        int numPartitions = 8;
        int valuesPerPartition = numValues / numPartitions;
        int bucketCount = 4000;

        Block bucketColumn = createLongSequenceBlockWithNull(numValues);
        Page bucketedColumnPage = new Page(bucketColumn);
        BucketFunction hiveBucketFunction = bucketFunction(hiveBucketingVersion, numBuckets, ImmutableList.of(HiveType.HIVE_LONG));

        List<Long> partitionValues = new ArrayList<>(numValues);
        for (int partition = 0; partition < numPartitions - 1; partition++) {
            partitionValues.addAll(Collections.nCopies(valuesPerPartition, partition * PARTITION_VALUE_STRIDE));
        }
        // The final partition exercises NULL partition values.
        partitionValues.addAll(Collections.nCopies(valuesPerPartition, (Long) null));
        Block partitionColumn = BlockAssertions.createLongsBlock(partitionValues);
        Page page = new Page(bucketColumn, partitionColumn);

        // Reference grouping: partition value -> (Hive bucket -> positions). HashMap permits the NULL key.
        Map<Long, HashMultimap<Integer, Integer>> partitionedBucketPositions = new HashMap<>();
        for (int position = 0; position < numValues; position++) {
            int hiveBucket = hiveBucketFunction.getBucket(bucketedColumnPage, position);
            Long hivePartition = partitionValues.get(position);
            partitionedBucketPositions
                    .computeIfAbsent(hivePartition, ignored -> HashMultimap.create())
                    .put(hiveBucket, position);
        }

        BucketFunction hivePartitionedBucketFunction = partitionedBucketFunction(
                hiveBucketingVersion,
                numBuckets,
                ImmutableList.of(HiveType.HIVE_LONG),
                ImmutableList.of(BigintType.BIGINT),
                bucketCount);

        for (HashMultimap<Integer, Integer> bucketPositions : partitionedBucketPositions.values()) {
            for (Collection<Integer> positions : bucketPositions.asMap().values()) {
                assertBucketCount(hivePartitionedBucketFunction, page, positions, 1);
            }
        }
        assertBucketCount(hivePartitionedBucketFunction, page, allPositions(numValues), numBuckets * numPartitions);
    }

    /**
     * Within a single partition, the buckets produced for the Hive buckets must be all distinct
     * and form one contiguous range of {@code hiveBucketCount} numbers.
     */
    @Test(dataProvider = "hiveBucketingVersion")
    public void testConsecutiveBucketsWithinPartition(HiveBucketing.BucketingVersion hiveBucketingVersion) {
        int positionCount = 100;
        int hiveBucketCount = 10;
        int bucketCount = 4000;

        // Builders are sized for all rows written below, not just the Hive bucket count.
        BlockBuilder bucketColumn = BigintType.BIGINT.createFixedSizeBlockBuilder(positionCount);
        BlockBuilder partitionColumn = BigintType.BIGINT.createFixedSizeBlockBuilder(positionCount);
        for (int i = 0; i < positionCount; i++) {
            BigintType.BIGINT.writeLong(bucketColumn, i);
            BigintType.BIGINT.writeLong(partitionColumn, 42L);
        }
        Page page = new Page(bucketColumn.build(), partitionColumn.build());

        BucketFunction hivePartitionedBucketFunction = partitionedBucketFunction(
                hiveBucketingVersion,
                hiveBucketCount,
                ImmutableList.of(HiveType.HIVE_LONG),
                ImmutableList.of(BigintType.BIGINT),
                bucketCount);

        List<Integer> buckets = new ArrayList<>(positionCount);
        for (int position = 0; position < positionCount; position++) {
            buckets.add(hivePartitionedBucketFunction.getBucket(page, position));
        }
        int minBucket = Collections.min(buckets);
        int maxBucket = Collections.max(buckets);

        // Every Hive bucket must produce its own bucket number ...
        Assert.assertEquals(buckets.stream().distinct().count(), (long) hiveBucketCount);
        // ... and those bucket numbers must be consecutive.
        Assert.assertEquals(maxBucket - minBucket + 1, hiveBucketCount);
    }

    /** Asserts that {@code positions} of {@code page} map to exactly {@code bucketCount} distinct buckets. */
    private static void assertBucketCount(BucketFunction bucketFunction, Page page, Collection<Integer> positions, int bucketCount) {
        long distinctBuckets = positions.stream()
                .map(position -> bucketFunction.getBucket(page, position))
                .distinct()
                .count();
        Assert.assertEquals(distinctBuckets, (long) bucketCount);
    }

    /** Returns the positions {@code 0 .. positionCount - 1} as an immutable list. */
    private static List<Integer> allPositions(int positionCount) {
        return IntStream.range(0, positionCount)
                .boxed()
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Builds a BIGINT block of {@code numValues} positions.
     *
     * <p>The first {@code numValues - 1} positions hold consecutive values starting at 923402935.
     * The last position is NULL.
     */
    private static Block createLongSequenceBlockWithNull(int numValues) {
        BlockBuilder builder = BigintType.BIGINT.createFixedSizeBlockBuilder(numValues);
        int start = 923402935;
        // Exclusive upper bound: leaves exactly one slot for the trailing NULL.
        int end = start + numValues - 1;
        for (int value = start; value < end; value++) {
            BigintType.BIGINT.writeLong(builder, value);
        }
        builder.appendNull();
        return builder.build();
    }

    /** Creates the function under test with fresh {@link TypeOperators}. */
    private static BucketFunction partitionedBucketFunction(HiveBucketing.BucketingVersion hiveBucketingVersion, int hiveBucketCount, List<HiveType> hiveBucketTypes, List<Type> partitionColumnsTypes, int bucketCount) {
        return new HivePartitionedBucketFunction(hiveBucketingVersion, hiveBucketCount, hiveBucketTypes, partitionColumnsTypes, new TypeOperators(), bucketCount);
    }

    /** Creates the plain Hive bucket function used as the reference grouping. */
    private static BucketFunction bucketFunction(HiveBucketing.BucketingVersion hiveBucketingVersion, int hiveBucketCount, List<HiveType> hiveBucketTypes) {
        return new HiveBucketFunction(hiveBucketingVersion, hiveBucketCount, hiveBucketTypes);
    }
}

