package io.trino.cost;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Primitives;
import io.trino.spi.statistics.StatsUtil;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DateType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import java.time.LocalDate;
import java.util.Map;
import java.util.function.Function;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link StatsNormalizer}.
 *
 * <p>Covers the following behaviors:
 * <ul>
 *   <li>Statistics of symbols that are not requested as outputs are dropped.</li>
 *   <li>Distinct-values counts are capped by the output row count. The expected result also
 *       depends on the nulls fraction.</li>
 *   <li>Distinct-values counts are capped by the number of values that fit in the
 *       {@code [low, high]} range of the symbol's type.</li>
 * </ul>
 */
public class TestStatsNormalizer {
    private final StatsNormalizer normalizer = new StatsNormalizer();

    @Test
    public void testNoCapping() {
        Symbol a = new Symbol("a");
        PlanNodeStatsEstimate estimate = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(30.0)
                .addSymbolStatistics(a, SymbolStatsEstimate.builder().setDistinctValuesCount(20.0).build())
                .build();

        // 20 distinct values fit within 30 rows, so nothing should change
        assertNormalized(estimate)
                .symbolStats(a, symbolAssert -> symbolAssert.distinctValuesCount(20.0));
    }

    @Test
    public void testDropNonOutputSymbols() {
        Symbol a = new Symbol("a");
        Symbol b = new Symbol("b");
        Symbol c = new Symbol("c");
        PlanNodeStatsEstimate estimate = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(40.0)
                .addSymbolStatistics(a, SymbolStatsEstimate.builder().setDistinctValuesCount(20.0).build())
                .addSymbolStatistics(b, SymbolStatsEstimate.builder().setDistinctValuesCount(30.0).build())
                .addSymbolStatistics(c, SymbolStatsEstimate.unknown())
                .build();

        // 'a' is not an output, so its statistics must be dropped.
        // 'c' is an output, but its statistics are unknown.
        PlanNodeStatsAssertion.assertThat(normalizer.normalize(
                        estimate,
                        ImmutableList.of(b, c),
                        TypeProvider.copyOf(ImmutableMap.of(b, BigintType.BIGINT, c, BigintType.BIGINT))))
                .symbolsWithKnownStats(b)
                .symbolStats(b, symbolAssert -> symbolAssert.distinctValuesCount(30.0));
    }

    @Test
    public void testCapDistinctValuesByOutputRowCount() {
        Symbol a = new Symbol("a");
        Symbol b = new Symbol("b");
        Symbol c = new Symbol("c");
        PlanNodeStatsEstimate estimate = PlanNodeStatsEstimate.builder()
                .addSymbolStatistics(a, SymbolStatsEstimate.builder().setNullsFraction(0.0).setDistinctValuesCount(20.0).build())
                .addSymbolStatistics(b, SymbolStatsEstimate.builder().setNullsFraction(0.4).setDistinctValuesCount(20.0).build())
                .addSymbolStatistics(c, SymbolStatsEstimate.unknown())
                .setOutputRowCount(10.0)
                .build();

        assertNormalized(estimate)
                // no nulls: capped to the row count
                .symbolStats(a, symbolAssert -> symbolAssert.distinctValuesCount(10.0))
                // with a non-zero nulls fraction, the normalizer produces a different cap
                .symbolStats(b, symbolAssert -> symbolAssert.distinctValuesCount(8.0))
                // unknown stays unknown
                .symbolStats(c, symbolAssert -> symbolAssert.distinctValuesCountUnknown());
    }

    @Test
    public void testCapDistinctValuesByToDomainRangeLength() {
        testCapDistinctValuesByToDomainRangeLength(IntegerType.INTEGER, 15.0, 1L, 5L, 5.0);
        testCapDistinctValuesByToDomainRangeLength(IntegerType.INTEGER, 2.0E10, 1L, 1000000000L, 1.0E9);
        testCapDistinctValuesByToDomainRangeLength(IntegerType.INTEGER, 3.0, 1L, 5L, 3.0);
        testCapDistinctValuesByToDomainRangeLength(IntegerType.INTEGER, Double.NaN, 1L, 5L, Double.NaN);
        testCapDistinctValuesByToDomainRangeLength(BigintType.BIGINT, 15.0, 1L, 5L, 5.0);
        testCapDistinctValuesByToDomainRangeLength(SmallintType.SMALLINT, 15.0, 1L, 5L, 5.0);
        testCapDistinctValuesByToDomainRangeLength(TinyintType.TINYINT, 15.0, 1L, 5L, 5.0);
        testCapDistinctValuesByToDomainRangeLength(DecimalType.createDecimalType(10, 2), 11.0, 1L, 1L, 1.0);
        testCapDistinctValuesByToDomainRangeLength(DecimalType.createDecimalType(10, 2), 13.0, 101L, 103L, 3.0);
        testCapDistinctValuesByToDomainRangeLength(DecimalType.createDecimalType(10, 2), 10.0, 100L, 200L, 10.0);
        testCapDistinctValuesByToDomainRangeLength(DoubleType.DOUBLE, 42.0, 10.1, 10.2, 42.0);
        testCapDistinctValuesByToDomainRangeLength(DoubleType.DOUBLE, 42.0, 10.1, 10.1, 1.0);
        testCapDistinctValuesByToDomainRangeLength(BooleanType.BOOLEAN, 11.0, true, true, 1.0);
        testCapDistinctValuesByToDomainRangeLength(BooleanType.BOOLEAN, 12.0, false, true, 2.0);
        testCapDistinctValuesByToDomainRangeLength(
                DateType.DATE,
                12.0,
                LocalDate.of(2017, 8, 31).toEpochDay(),
                LocalDate.of(2017, 9, 2).toEpochDay(),
                3.0);
    }

    /**
     * Normalizes a single-symbol estimate with the given range and checks the resulting
     * distinct-values count.
     *
     * @param type symbol type; {@code low} and {@code high} must be instances of its (boxed) Java type
     * @param ndv distinct-values count before normalization
     * @param low native low value of the range
     * @param high native high value of the range
     * @param expectedNormalizedNdv distinct-values count expected after normalization
     */
    private void testCapDistinctValuesByToDomainRangeLength(Type type, double ndv, Object low, Object high, double expectedNormalizedNdv) {
        Class<?> javaType = Primitives.wrap(type.getJavaType());
        Preconditions.checkArgument(javaType.isInstance(low), "Incorrect class of low value for %s: %s", type, low.getClass());
        Preconditions.checkArgument(javaType.isInstance(high), "Incorrect class of high value for %s: %s", type, high.getClass());

        Symbol symbol = new Symbol("x");
        SymbolStatsEstimate symbolStats = SymbolStatsEstimate.builder()
                .setNullsFraction(0.0)
                .setDistinctValuesCount(ndv)
                .setLowValue(asStatsValue(low, type))
                .setHighValue(asStatsValue(high, type))
                .build();
        TypeProvider types = TypeProvider.copyOf(ImmutableMap.of(symbol, type));

        // The row count is far above every expected value above, so the range is the binding cap
        assertNormalized(
                        PlanNodeStatsEstimate.builder()
                                .setOutputRowCount(1.0E10)
                                .addSymbolStatistics(symbol, symbolStats)
                                .build(),
                        types)
                .symbolStats(symbol, symbolAssert -> symbolAssert.distinctValuesCount(expectedNormalizedNdv));

        // The same expectation must hold when no output row count is set on the builder
        assertNormalized(
                        PlanNodeStatsEstimate.builder()
                                .addSymbolStatistics(symbol, symbolStats)
                                .build(),
                        types)
                .symbolStats(symbol, symbolAssert -> symbolAssert.distinctValuesCount(expectedNormalizedNdv));
    }

    /**
     * Normalizes {@code estimate}, keeping every symbol with known statistics as an output and
     * typing each of them as {@code BIGINT}.
     */
    private PlanNodeStatsAssertion assertNormalized(PlanNodeStatsEstimate estimate) {
        TypeProvider types = TypeProvider.copyOf(estimate.getSymbolsWithKnownStatistics().stream()
                .collect(ImmutableMap.<Symbol, Symbol, Type>toImmutableMap(Function.identity(), symbol -> BigintType.BIGINT)));
        return assertNormalized(estimate, types);
    }

    /**
     * Normalizes {@code estimate} with the given types and wraps the result in an assertion.
     * Every symbol with known statistics is passed as an output.
     */
    private PlanNodeStatsAssertion assertNormalized(PlanNodeStatsEstimate estimate, TypeProvider types) {
        return PlanNodeStatsAssertion.assertThat(normalizer.normalize(estimate, estimate.getSymbolsWithKnownStatistics(), types));
    }

    /**
     * Converts a native value of {@code type} to its double stats representation. Returns
     * {@code NaN} when {@link StatsUtil#toStatsRepresentation} yields no value.
     */
    private double asStatsValue(Object value, Type type) {
        return StatsUtil.toStatsRepresentation(type, value).orElse(Double.NaN);
    }
}
