package io.trino.sql.planner.iterative.rule;

import io.trino.execution.BaseDataDefinitionTaskTest;
import io.trino.spi.Plugin;
import io.trino.spi.connector.CatalogHandle;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.WriterScalingOptions;
import io.trino.sql.planner.Partitioning;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.TableFinishNode;
import java.util.List;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TestRemoveEmptyMergeWriterRuleSet.class */
/**
 * Rule tests for {@link RemoveEmptyMergeWriterRuleSet}.
 *
 * <p>Every case builds a {@code TableFinishNode} above a merge whose input pipeline starts from
 * {@code values(...)} with no explicit rows, then checks that the rule under test rewrites the
 * plan into one matching {@code values("A")}.
 */
public class TestRemoveEmptyMergeWriterRuleSet extends BaseRuleTest {
    // Initialised in setup() and reused by every plan built in this class.
    private CatalogHandle catalogHandle;
    private SchemaTableName schemaTableName;

    public TestRemoveEmptyMergeWriterRuleSet() {
        // No additional plugins are installed for these rule tests.
        super(new Plugin[0]);
    }

    /**
     * Captures the tester's current catalog handle and the schema/table name that every
     * generated plan writes to.
     */
    @BeforeAll
    public void setup() {
        catalogHandle = tester().getCurrentCatalogHandle();
        schemaTableName = new SchemaTableName(BaseDataDefinitionTaskTest.SCHEMA, "table");
    }

    @Test
    public void testRemoveEmptyMergeRewrite() {
        assertEmptyMergeRemoved(RemoveEmptyMergeWriterRuleSet.removeEmptyMergeWriterRule(), false);
    }

    @Test
    public void testRemoveEmptyMergeRewriteWithExchange() {
        assertEmptyMergeRemoved(RemoveEmptyMergeWriterRuleSet.removeEmptyMergeWriterWithExchangeRule(), true);
    }

    /**
     * Builds {@code TableFinish <- [Exchange <-] Merge <- Exchange <- Project <- Values} and asserts
     * that {@code rule} turns it into a plan matching {@code values("A")}.
     *
     * @param rule the rule under test
     * @param planWithExchange when {@code true}, an extra single-distribution exchange carrying only
     *     {@code row_count} is placed between the merge and the table finish
     */
    private void assertEmptyMergeRemoved(Rule<TableFinishNode> rule, boolean planWithExchange) {
        tester().assertThat(rule)
                .on(p -> {
                    Symbol mergeRow = p.symbol("merge_row");
                    Symbol rowId = p.symbol("row_id");
                    Symbol rowCount = p.symbol("row_count");

                    // Identity projection over the three columns, gathered through a
                    // single-distribution exchange, feeds the merge.
                    ExchangeNode mergeSource = p.exchange(e -> e
                            .addSource(p.project(
                                    Assignments.builder()
                                            .putIdentity(mergeRow)
                                            .putIdentity(rowId)
                                            .putIdentity(rowCount)
                                            .build(),
                                    p.values(mergeRow, rowId, rowCount)))
                            .addInputsSet(mergeRow, rowId, rowCount)
                            .partitioningScheme(singleDistributionScheme(List.of(mergeRow, rowId, rowCount))));

                    PlanNode merge = p.merge(schemaTableName, mergeSource, mergeRow, rowId, List.of(rowCount));
                    PlanNode finishSource = planWithExchange ? withExchange(p, merge, rowCount) : merge;

                    return p.tableFinish(
                            finishSource,
                            p.createTarget(catalogHandle, schemaTableName, true, WriterScalingOptions.ENABLED),
                            rowCount);
                })
                .matches(PlanMatchPattern.values("A"));
    }

    /**
     * Wraps {@code source} in a single-distribution exchange whose only input and output is
     * {@code symbol}.
     */
    private ExchangeNode withExchange(PlanBuilder planBuilder, PlanNode source, Symbol symbol) {
        return planBuilder.exchange(e -> e
                .addSource(source)
                .addInputsSet(symbol)
                .partitioningScheme(singleDistributionScheme(List.of(symbol))));
    }

    /**
     * Creates a {@code SINGLE_DISTRIBUTION} partitioning scheme (no partitioning columns) with the
     * given output layout.
     */
    private static PartitioningScheme singleDistributionScheme(List<Symbol> outputLayout) {
        return new PartitioningScheme(
                Partitioning.create(SystemPartitioningHandle.SINGLE_DISTRIBUTION, List.of()),
                outputLayout);
    }
}
