package io.trino.cost;

import io.trino.sql.planner.Symbol;
import org.junit.jupiter.api.Test;

/**
 * Tests the semi-join and anti-join estimates produced by {@link SemiJoinStatsCalculator}.
 *
 * <p>Every scenario passes {@link #inputStatistics} as both the source and the filtering-source
 * statistics and varies only the two join symbols. In each call, the assertions expect the first
 * symbol's statistics to change. The second symbol's statistics, and those of the uninvolved
 * symbol {@code z}, are expected to come through unchanged.
 */
public class TestSemiJoinStatsCalculator
{
    // Per-symbol statistics, in argument order: average row size, NDV, low value, high value, nulls fraction.
    private final SymbolStatsEstimate uStats = estimate(8.0, 300.0, 0.0, 20.0, 0.1);
    private final SymbolStatsEstimate wStats = estimate(8.0, 30.0, 0.0, 20.0, 0.1);
    private final SymbolStatsEstimate xStats = estimate(4.0, 40.0, -10.0, 10.0, 0.25);
    private final SymbolStatsEstimate yStats = estimate(4.0, 20.0, 0.0, 5.0, 0.5);
    private final SymbolStatsEstimate zStats = estimate(4.0, 5.0, -100.0, 100.0, 0.1);
    private final SymbolStatsEstimate leftOpenStats = estimate(4.0, 50.0, Double.NEGATIVE_INFINITY, 15.0, 0.1);
    private final SymbolStatsEstimate rightOpenStats = estimate(4.0, 50.0, -15.0, Double.POSITIVE_INFINITY, 0.1);
    private final SymbolStatsEstimate unknownRangeStats = estimate(4.0, 50.0, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0.1);
    // Zero distinct values, with NaN for the range bounds and the nulls fraction.
    private final SymbolStatsEstimate emptyRangeStats = estimate(4.0, 0.0, Double.NaN, Double.NaN, Double.NaN);
    // NDV below one. The low/high values are never set, so they keep whatever the builder defaults to.
    private final SymbolStatsEstimate fractionalNdvStats = SymbolStatsEstimate.builder()
            .setAverageRowSize(Double.NaN)
            .setDistinctValuesCount(0.1)
            .setNullsFraction(0.0)
            .build();

    private final Symbol u = new Symbol("u");
    private final Symbol w = new Symbol("w");
    private final Symbol x = new Symbol("x");
    private final Symbol y = new Symbol("y");
    private final Symbol z = new Symbol("z");
    private final Symbol leftOpen = new Symbol("leftOpen");
    private final Symbol rightOpen = new Symbol("rightOpen");
    private final Symbol unknownRange = new Symbol("unknownRange");
    private final Symbol emptyRange = new Symbol("emptyRange");
    private final Symbol unknown = new Symbol("unknown");
    private final Symbol fractionalNdv = new Symbol("fractionalNdv");

    // A 1000-row input carrying every symbol above. Both sides of every join in this class use it.
    private final PlanNodeStatsEstimate inputStatistics = PlanNodeStatsEstimate.builder()
            .addSymbolStatistics(u, uStats)
            .addSymbolStatistics(w, wStats)
            .addSymbolStatistics(x, xStats)
            .addSymbolStatistics(y, yStats)
            .addSymbolStatistics(z, zStats)
            .addSymbolStatistics(leftOpen, leftOpenStats)
            .addSymbolStatistics(rightOpen, rightOpenStats)
            .addSymbolStatistics(unknownRange, unknownRangeStats)
            .addSymbolStatistics(emptyRange, emptyRangeStats)
            .addSymbolStatistics(unknown, SymbolStatsEstimate.unknown())
            .addSymbolStatistics(fractionalNdv, fractionalNdvStats)
            .setOutputRowCount(1000.0)
            .build();

    @Test
    public void testSemiJoin()
    {
        // w has fewer distinct values (30) than x (40). x's NDV is expected to drop to w's, and
        // the row count to scale by the NDV ratio on top of x's values fraction.
        assertSemiJoin(x, w)
                .symbolStats(x, stats -> stats
                        .lowValue(xStats.getLowValue())
                        .highValue(xStats.getHighValue())
                        .nullsFraction(0.0)
                        .distinctValuesCount(wStats.getDistinctValuesCount()))
                .symbolStats(w, stats -> stats.isEqualTo(wStats))
                .symbolStats(z, stats -> stats.isEqualTo(zStats))
                .outputRowsCount(inputStatistics.getOutputRowCount() * xStats.getValuesFraction() * (wStats.getDistinctValuesCount() / xStats.getDistinctValuesCount()));

        // u has more distinct values (300) than x (40). x keeps its NDV, and the row count scales
        // only by x's values fraction.
        assertSemiJoin(x, u)
                .symbolStats(x, stats -> stats
                        .lowValue(xStats.getLowValue())
                        .highValue(xStats.getHighValue())
                        .nullsFraction(0.0)
                        .distinctValuesCount(xStats.getDistinctValuesCount()))
                .symbolStats(u, stats -> stats.isEqualTo(uStats))
                .symbolStats(z, stats -> stats.isEqualTo(zStats))
                .outputRowsCount(inputStatistics.getOutputRowCount() * xStats.getValuesFraction());

        // Unknown statistics on the first symbol: NDV, range and row count are all expected to be unknown.
        assertSemiJoin(unknown, u)
                .symbolStats(unknown, stats -> stats
                        .nullsFraction(0.0)
                        .distinctValuesCountUnknown()
                        .unknownRange())
                .symbolStats(u, stats -> stats.isEqualTo(uStats))
                .symbolStats(z, stats -> stats.isEqualTo(zStats))
                .outputRowsCountUnknown();

        // Unknown statistics on the second symbol: x keeps its range, but its NDV and the row count become unknown.
        assertSemiJoin(x, unknown)
                .symbolStats(x, stats -> stats
                        .nullsFraction(0.0)
                        .lowValue(xStats.getLowValue())
                        .highValue(xStats.getHighValue())
                        .distinctValuesCountUnknown())
                .symbolStatsUnknown(unknown)
                .symbolStats(z, stats -> stats.isEqualTo(zStats))
                .outputRowsCountUnknown();

        // A symbol with zero distinct values is expected to produce no rows.
        assertSemiJoin(emptyRange, emptyRange)
                .outputRowsCount(0.0);

        // A fractional NDV on both sides: all 1000 rows are kept and the NDV of 0.1 is preserved.
        assertSemiJoin(fractionalNdv, fractionalNdv)
                .outputRowsCount(1000.0)
                .symbolStats(fractionalNdv, stats -> stats
                        .nullsFraction(0.0)
                        .distinctValuesCount(0.1));
    }

    @Test
    public void testAntiJoin()
    {
        // u (NDV 300) filtered by x (NDV 40): the expected NDV is 300 - 40, and the row count scales
        // by u's values fraction times (1 - 40/300).
        assertAntiJoin(u, x)
                .symbolStats(u, stats -> stats
                        .lowValue(uStats.getLowValue())
                        .highValue(uStats.getHighValue())
                        .nullsFraction(0.0)
                        .distinctValuesCount(uStats.getDistinctValuesCount() - xStats.getDistinctValuesCount()))
                .symbolStats(x, stats -> stats.isEqualTo(xStats))
                .symbolStats(z, stats -> stats.isEqualTo(zStats))
                .outputRowsCount(inputStatistics.getOutputRowCount() * uStats.getValuesFraction() * (1.0 - (xStats.getDistinctValuesCount() / uStats.getDistinctValuesCount())));

        // x (NDV 40) filtered by u (NDV 300): the test expects x's NDV and its non-null row count to be
        // halved. This presumably reflects a minimum retained fraction in the calculator; confirm there.
        assertAntiJoin(x, u)
                .symbolStats(x, stats -> stats
                        .lowValue(xStats.getLowValue())
                        .highValue(xStats.getHighValue())
                        .nullsFraction(0.0)
                        .distinctValuesCount(xStats.getDistinctValuesCount() * 0.5))
                .symbolStats(u, stats -> stats.isEqualTo(uStats))
                .symbolStats(z, stats -> stats.isEqualTo(zStats))
                .outputRowsCount(inputStatistics.getOutputRowCount() * xStats.getValuesFraction() * 0.5);

        // Unknown statistics on the first symbol: NDV, range and row count are all expected to be unknown.
        assertAntiJoin(unknown, u)
                .symbolStats(unknown, stats -> stats
                        .nullsFraction(0.0)
                        .distinctValuesCountUnknown()
                        .unknownRange())
                .symbolStats(u, stats -> stats.isEqualTo(uStats))
                .symbolStats(z, stats -> stats.isEqualTo(zStats))
                .outputRowsCountUnknown();

        // Unknown statistics on the second symbol: x keeps its range, but its NDV and the row count become unknown.
        assertAntiJoin(x, unknown)
                .symbolStats(x, stats -> stats
                        .nullsFraction(0.0)
                        .lowValue(xStats.getLowValue())
                        .highValue(xStats.getHighValue())
                        .distinctValuesCountUnknown())
                .symbolStatsUnknown(unknown)
                .symbolStats(z, stats -> stats.isEqualTo(zStats))
                .outputRowsCountUnknown();

        // A symbol with zero distinct values is expected to produce no rows.
        assertAntiJoin(emptyRange, emptyRange)
                .outputRowsCount(0.0);

        // A fractional NDV on both sides: the test expects half the rows (500) and half the NDV (0.05).
        assertAntiJoin(fractionalNdv, fractionalNdv)
                .outputRowsCount(500.0)
                .symbolStats(fractionalNdv, stats -> stats
                        .nullsFraction(0.0)
                        .distinctValuesCount(0.05));
    }

    /**
     * Starts an assertion on {@link SemiJoinStatsCalculator#computeSemiJoin}, passing
     * {@link #inputStatistics} for both inputs.
     */
    private PlanNodeStatsAssertion assertSemiJoin(Symbol sourceJoinSymbol, Symbol filteringSourceJoinSymbol)
    {
        return PlanNodeStatsAssertion.assertThat(
                SemiJoinStatsCalculator.computeSemiJoin(inputStatistics, inputStatistics, sourceJoinSymbol, filteringSourceJoinSymbol));
    }

    /**
     * Starts an assertion on {@link SemiJoinStatsCalculator#computeAntiJoin}, passing
     * {@link #inputStatistics} for both inputs.
     */
    private PlanNodeStatsAssertion assertAntiJoin(Symbol sourceJoinSymbol, Symbol filteringSourceJoinSymbol)
    {
        return PlanNodeStatsAssertion.assertThat(
                SemiJoinStatsCalculator.computeAntiJoin(inputStatistics, inputStatistics, sourceJoinSymbol, filteringSourceJoinSymbol));
    }

    /**
     * Builds a {@link SymbolStatsEstimate}, setting the average row size, distinct-values count,
     * low value, high value and nulls fraction, in that order.
     */
    private static SymbolStatsEstimate estimate(double averageRowSize, double distinctValuesCount, double lowValue, double highValue, double nullsFraction)
    {
        return SymbolStatsEstimate.builder()
                .setAverageRowSize(averageRowSize)
                .setDistinctValuesCount(distinctValuesCount)
                .setLowValue(lowValue)
                .setHighValue(highValue)
                .setNullsFraction(nullsFraction)
                .build();
    }
}
