/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import io.trino.cost.PlanNodeStatsAssertion;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SemiJoinStatsCalculator;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.sql.planner.Symbol;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

public class TestSemiJoinStatsCalculator {
    private PlanNodeStatsEstimate inputStatistics;
    private SymbolStatsEstimate uStats;
    private SymbolStatsEstimate wStats;
    private SymbolStatsEstimate xStats;
    private SymbolStatsEstimate yStats;
    private SymbolStatsEstimate zStats;
    private SymbolStatsEstimate leftOpenStats;
    private SymbolStatsEstimate rightOpenStats;
    private SymbolStatsEstimate unknownRangeStats;
    private SymbolStatsEstimate emptyRangeStats;
    private SymbolStatsEstimate fractionalNdvStats;
    private Symbol u = new Symbol("u");
    private Symbol w = new Symbol("w");
    private Symbol x = new Symbol("x");
    private Symbol y = new Symbol("y");
    private Symbol z = new Symbol("z");
    private Symbol leftOpen = new Symbol("leftOpen");
    private Symbol rightOpen = new Symbol("rightOpen");
    private Symbol unknownRange = new Symbol("unknownRange");
    private Symbol emptyRange = new Symbol("emptyRange");
    private Symbol unknown = new Symbol("unknown");
    private Symbol fractionalNdv = new Symbol("fractionalNdv");

    @BeforeClass
    public void setUp() {
        this.uStats = SymbolStatsEstimate.builder().setAverageRowSize(8.0).setDistinctValuesCount(300.0).setLowValue(0.0).setHighValue(20.0).setNullsFraction(0.1).build();
        this.wStats = SymbolStatsEstimate.builder().setAverageRowSize(8.0).setDistinctValuesCount(30.0).setLowValue(0.0).setHighValue(20.0).setNullsFraction(0.1).build();
        this.xStats = SymbolStatsEstimate.builder().setAverageRowSize(4.0).setDistinctValuesCount(40.0).setLowValue(-10.0).setHighValue(10.0).setNullsFraction(0.25).build();
        this.yStats = SymbolStatsEstimate.builder().setAverageRowSize(4.0).setDistinctValuesCount(20.0).setLowValue(0.0).setHighValue(5.0).setNullsFraction(0.5).build();
        this.zStats = SymbolStatsEstimate.builder().setAverageRowSize(4.0).setDistinctValuesCount(5.0).setLowValue(-100.0).setHighValue(100.0).setNullsFraction(0.1).build();
        this.leftOpenStats = SymbolStatsEstimate.builder().setAverageRowSize(4.0).setDistinctValuesCount(50.0).setLowValue(Double.NEGATIVE_INFINITY).setHighValue(15.0).setNullsFraction(0.1).build();
        this.rightOpenStats = SymbolStatsEstimate.builder().setAverageRowSize(4.0).setDistinctValuesCount(50.0).setLowValue(-15.0).setHighValue(Double.POSITIVE_INFINITY).setNullsFraction(0.1).build();
        this.unknownRangeStats = SymbolStatsEstimate.builder().setAverageRowSize(4.0).setDistinctValuesCount(50.0).setLowValue(Double.NEGATIVE_INFINITY).setHighValue(Double.POSITIVE_INFINITY).setNullsFraction(0.1).build();
        this.emptyRangeStats = SymbolStatsEstimate.builder().setAverageRowSize(4.0).setDistinctValuesCount(0.0).setLowValue(Double.NaN).setHighValue(Double.NaN).setNullsFraction(Double.NaN).build();
        this.fractionalNdvStats = SymbolStatsEstimate.builder().setAverageRowSize(Double.NaN).setDistinctValuesCount(0.1).setNullsFraction(0.0).build();
        this.inputStatistics = PlanNodeStatsEstimate.builder().addSymbolStatistics(this.u, this.uStats).addSymbolStatistics(this.w, this.wStats).addSymbolStatistics(this.x, this.xStats).addSymbolStatistics(this.y, this.yStats).addSymbolStatistics(this.z, this.zStats).addSymbolStatistics(this.leftOpen, this.leftOpenStats).addSymbolStatistics(this.rightOpen, this.rightOpenStats).addSymbolStatistics(this.unknownRange, this.unknownRangeStats).addSymbolStatistics(this.emptyRange, this.emptyRangeStats).addSymbolStatistics(this.unknown, SymbolStatsEstimate.unknown()).addSymbolStatistics(this.fractionalNdv, this.fractionalNdvStats).setOutputRowCount(1000.0).build();
    }

    @Test
    public void testSemiJoin() {
        PlanNodeStatsAssertion.assertThat(SemiJoinStatsCalculator.computeSemiJoin((PlanNodeStatsEstimate)this.inputStatistics, (PlanNodeStatsEstimate)this.inputStatistics, (Symbol)this.x, (Symbol)this.w)).symbolStats(this.x, stats -> stats.lowValue(this.xStats.getLowValue()).highValue(this.xStats.getHighValue()).nullsFraction(0.0).distinctValuesCount(this.wStats.getDistinctValuesCount())).symbolStats(this.w, stats -> stats.isEqualTo(this.wStats)).symbolStats(this.z, stats -> stats.isEqualTo(this.zStats)).outputRowsCount(this.inputStatistics.getOutputRowCount() * this.xStats.getValuesFraction() * (this.wStats.getDistinctValuesCount() / this.xStats.getDistinctValuesCount()));
        PlanNodeStatsAssertion.assertThat(SemiJoinStatsCalculator.computeSemiJoin((PlanNodeStatsEstimate)this.inputStatistics, (PlanNodeStatsEstimate)this.inputStatistics, (Symbol)this.x, (Symbol)this.u)).symbolStats(this.x, stats -> stats.lowValue(this.xStats.getLowValue()).highValue(this.xStats.getHighValue()).nullsFraction(0.0).distinctValuesCount(this.xStats.getDistinctValuesCount())).symbolStats(this.u, stats -> stats.isEqualTo(this.uStats)).symbolStats(this.z, stats -> stats.isEqualTo(this.zStats)).outputRowsCount(this.inputStatistics.getOutputRowCount() * this.xStats.getValuesFraction());
        PlanNodeStatsAssertion.assertThat(SemiJoinStatsCalculator.computeSemiJoin((PlanNodeStatsEstimate)this.inputStatistics, (PlanNodeStatsEstimate)this.inputStatistics, (Symbol)this.unknown, (Symbol)this.u)).symbolStats(this.unknown, stats -> stats.nullsFraction(0.0).distinctValuesCountUnknown().unknownRange()).symbolStats(this.u, stats -> stats.isEqualTo(this.uStats)).symbolStats(this.z, stats -> stats.isEqualTo(this.zStats)).outputRowsCountUnknown();
        PlanNodeStatsAssertion.assertThat(SemiJoinStatsCalculator.computeSemiJoin((PlanNodeStatsEstimate)this.inputStatistics, (PlanNodeStatsEstimate)this.inputStatistics, (Symbol)this.x, (Symbol)this.unknown)).symbolStats(this.x, stats -> stats.nullsFraction(0.0).lowValue(this.xStats.getLowValue()).highValue(this.xStats.getHighValue()).distinctValuesCountUnknown()).symbolStatsUnknown(this.unknown).symbolStats(this.z, stats -> stats.isEqualTo(this.zStats)).outputRowsCountUnknown();
        PlanNodeStatsAssertion.assertThat(SemiJoinStatsCalculator.computeSemiJoin((PlanNodeStatsEstimate)this.inputStatistics, (PlanNodeStatsEstimate)this.inputStatistics, (Symbol)this.emptyRange, (Symbol)this.emptyRange)).outputRowsCount(0.0);
        PlanNodeStatsAssertion.assertThat(SemiJoinStatsCalculator.computeSemiJoin((PlanNodeStatsEstimate)this.inputStatistics, (PlanNodeStatsEstimate)this.inputStatistics, (Symbol)this.fractionalNdv, (Symbol)this.fractionalNdv)).outputRowsCount(1000.0).symbolStats(this.fractionalNdv, stats -> stats.nullsFraction(0.0).distinctValuesCount(0.1));
    }

    @Test
    public void testAntiJoin() {
        PlanNodeStatsAssertion.assertThat(SemiJoinStatsCalculator.computeAntiJoin((PlanNodeStatsEstimate)this.inputStatistics, (PlanNodeStatsEstimate)this.inputStatistics, (Symbol)this.u, (Symbol)this.x)).symbolStats(this.u, stats -> stats.lowValue(this.uStats.getLowValue()).highValue(this.uStats.getHighValue()).nullsFraction(0.0).distinctValuesCount(this.uStats.getDistinctValuesCount() - this.xStats.getDistinctValuesCount())).symbolStats(this.x, stats -> stats.isEqualTo(this.xStats)).symbolStats(this.z, stats -> stats.isEqualTo(this.zStats)).outputRowsCount(this.inputStatistics.getOutputRowCount() * this.uStats.getValuesFraction() * (1.0 - this.xStats.getDistinctValuesCount() / this.uStats.getDistinctValuesCount()));
        PlanNodeStatsAssertion.assertThat(SemiJoinStatsCalculator.computeAntiJoin((PlanNodeStatsEstimate)this.inputStatistics, (PlanNodeStatsEstimate)this.inputStatistics, (Symbol)this.x, (Symbol)this.u)).symbolStats(this.x, stats -> stats.lowValue(this.xStats.getLowValue()).highValue(this.xStats.getHighValue()).nullsFraction(0.0).distinctValuesCount(this.xStats.getDistinctValuesCount() * 0.5)).symbolStats(this.u, stats -> stats.isEqualTo(this.uStats)).symbolStats(this.z, stats -> stats.isEqualTo(this.zStats)).outputRowsCount(this.inputStatistics.getOutputRowCount() * this.xStats.getValuesFraction() * 0.5);
        PlanNodeStatsAssertion.assertThat(SemiJoinStatsCalculator.computeAntiJoin((PlanNodeStatsEstimate)this.inputStatistics, (PlanNodeStatsEstimate)this.inputStatistics, (Symbol)this.unknown, (Symbol)this.u)).symbolStats(this.unknown, stats -> stats.nullsFraction(0.0).distinctValuesCountUnknown().unknownRange()).symbolStats(this.u, stats -> stats.isEqualTo(this.uStats)).symbolStats(this.z, stats -> stats.isEqualTo(this.zStats)).outputRowsCountUnknown();
        PlanNodeStatsAssertion.assertThat(SemiJoinStatsCalculator.computeAntiJoin((PlanNodeStatsEstimate)this.inputStatistics, (PlanNodeStatsEstimate)this.inputStatistics, (Symbol)this.x, (Symbol)this.unknown)).symbolStats(this.x, stats -> stats.nullsFraction(0.0).lowValue(this.xStats.getLowValue()).highValue(this.xStats.getHighValue()).distinctValuesCountUnknown()).symbolStatsUnknown(this.unknown).symbolStats(this.z, stats -> stats.isEqualTo(this.zStats)).outputRowsCountUnknown();
        PlanNodeStatsAssertion.assertThat(SemiJoinStatsCalculator.computeAntiJoin((PlanNodeStatsEstimate)this.inputStatistics, (PlanNodeStatsEstimate)this.inputStatistics, (Symbol)this.emptyRange, (Symbol)this.emptyRange)).outputRowsCount(0.0);
        PlanNodeStatsAssertion.assertThat(SemiJoinStatsCalculator.computeAntiJoin((PlanNodeStatsEstimate)this.inputStatistics, (PlanNodeStatsEstimate)this.inputStatistics, (Symbol)this.fractionalNdv, (Symbol)this.fractionalNdv)).outputRowsCount(500.0).symbolStats(this.fractionalNdv, stats -> stats.nullsFraction(0.0).distinctValuesCount(0.05));
    }
}

