package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.connector.MockConnectorFactory;
import io.trino.metadata.AbstractMockMetadata;
import io.trino.metadata.TableHandle;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.RowChangeParadigm;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.TestingColumnHandle;
import io.trino.spi.expression.Constant;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.iterative.rule.test.RuleTester;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.TableUpdateNode;
import io.trino.sql.planner.plan.TableWriterNode;
import io.trino.sql.tree.ArithmeticBinaryExpression;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.StringLiteral;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.Map;
import java.util.Optional;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TestPushMergeWriterUpdateIntoConnector.class */
public class TestPushMergeWriterUpdateIntoConnector {
    private static final String TEST_SCHEMA = "test_schema";
    private static final String TEST_TABLE = "test_table";
    private static final SchemaTableName SCHEMA_TABLE_NAME = new SchemaTableName("test_schema", TEST_TABLE);

    @Test
    public void testPushUpdateIntoConnector() {
        ImmutableList of = ImmutableList.of("column_1", "column_2");
        RuleTester build = RuleTester.builder().withDefaultCatalogConnectorFactory(MockConnectorFactory.builder().build()).build();
        try {
            build.assertThat(createRule(build)).on(planBuilder -> {
                Symbol symbol = planBuilder.symbol("merge_row");
                Symbol symbol2 = planBuilder.symbol("row_id");
                return planBuilder.tableFinish(planBuilder.merge((PlanNode) planBuilder.mergeProcessor(SCHEMA_TABLE_NAME, planBuilder.project(new Assignments(Map.of(symbol, new Row(ImmutableList.of(planBuilder.symbol("column_1").toSymbolReference(), new LongLiteral("1"), new BooleanLiteral("true"), new LongLiteral("1"), new LongLiteral("1"))))), planBuilder.tableScan(tableScanBuilder -> {
                    tableScanBuilder.setAssignments(ImmutableMap.of()).setSymbols(ImmutableList.of()).setTableHandle(build.getCurrentCatalogTableHandle("test_schema", TEST_TABLE)).build();
                })), symbol, symbol2, ImmutableList.of(), ImmutableList.of(), ImmutableList.of()), planBuilder.mergeTarget(SCHEMA_TABLE_NAME, new TableWriterNode.MergeParadigmAndTypes(Optional.of(RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW), ImmutableList.of(), of, IntegerType.INTEGER)), symbol, symbol2, (List<Symbol>) ImmutableList.of()), planBuilder.mergeTarget(SCHEMA_TABLE_NAME), planBuilder.symbol("row_count"));
            }).matches(PlanMatchPattern.node(TableUpdateNode.class, new PlanMatchPattern[0]));
            if (build != null) {
                build.close();
            }
        } catch (Throwable th) {
            if (build != null) {
                try {
                    build.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    @Test
    public void testPushUpdateIntoConnectorArithmeticExpression() {
        ImmutableList of = ImmutableList.of("column_1", "column_2");
        RuleTester build = RuleTester.builder().withDefaultCatalogConnectorFactory(MockConnectorFactory.builder().build()).build();
        try {
            build.assertThat(createRule(build)).on(planBuilder -> {
                Symbol symbol = planBuilder.symbol("merge_row");
                Symbol symbol2 = planBuilder.symbol("row_id");
                return planBuilder.tableFinish(planBuilder.merge((PlanNode) planBuilder.mergeProcessor(SCHEMA_TABLE_NAME, planBuilder.project(new Assignments(Map.of(symbol, new Row(ImmutableList.of(planBuilder.symbol("column_1").toSymbolReference(), new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, planBuilder.symbol("col1").toSymbolReference(), new LongLiteral("5")))))), planBuilder.tableScan(tableScanBuilder -> {
                    tableScanBuilder.setAssignments(ImmutableMap.of()).setSymbols(ImmutableList.of()).setTableHandle(build.getCurrentCatalogTableHandle("test_schema", TEST_TABLE)).build();
                })), symbol, symbol2, ImmutableList.of(), ImmutableList.of(), ImmutableList.of()), planBuilder.mergeTarget(SCHEMA_TABLE_NAME, new TableWriterNode.MergeParadigmAndTypes(Optional.of(RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW), ImmutableList.of(), of, IntegerType.INTEGER)), symbol, symbol2, (List<Symbol>) ImmutableList.of()), planBuilder.mergeTarget(SCHEMA_TABLE_NAME), planBuilder.symbol("row_count"));
            }).doesNotFire();
            if (build != null) {
                build.close();
            }
        } catch (Throwable th) {
            if (build != null) {
                try {
                    build.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    @Test
    public void testPushUpdateIntoConnectorUpdateAll() {
        ImmutableList of = ImmutableList.of("column_1", "column_2");
        RuleTester build = RuleTester.builder().withDefaultCatalogConnectorFactory(MockConnectorFactory.builder().build()).build();
        try {
            build.assertThat(createRule(build)).on(planBuilder -> {
                Symbol symbol = planBuilder.symbol("merge_row");
                Symbol symbol2 = planBuilder.symbol("row_id");
                return planBuilder.tableFinish(planBuilder.merge((PlanNode) planBuilder.mergeProcessor(SCHEMA_TABLE_NAME, planBuilder.project(new Assignments(Map.of(symbol, new Row(ImmutableList.of(new FunctionCall(build.getMetadata().resolveBuiltinFunction("from_base64", TypeSignatureProvider.fromTypes(new Type[]{VarcharType.VARCHAR})).toQualifiedName(), ImmutableList.of(new StringLiteral(""))))))), planBuilder.tableScan(tableScanBuilder -> {
                    tableScanBuilder.setAssignments(ImmutableMap.of()).setSymbols(ImmutableList.of()).setTableHandle(build.getCurrentCatalogTableHandle("test_schema", TEST_TABLE)).build();
                })), symbol, symbol2, ImmutableList.of(), ImmutableList.of(), ImmutableList.of()), planBuilder.mergeTarget(SCHEMA_TABLE_NAME, new TableWriterNode.MergeParadigmAndTypes(Optional.of(RowChangeParadigm.DELETE_ROW_AND_INSERT_ROW), ImmutableList.of(), of, IntegerType.INTEGER)), symbol, symbol2, (List<Symbol>) ImmutableList.of()), planBuilder.mergeTarget(SCHEMA_TABLE_NAME), planBuilder.symbol("row_count"));
            }).doesNotFire();
            if (build != null) {
                build.close();
            }
        } catch (Throwable th) {
            if (build != null) {
                try {
                    build.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    private static PushMergeWriterUpdateIntoConnector createRule(RuleTester ruleTester) {
        return new PushMergeWriterUpdateIntoConnector(ruleTester.getPlannerContext(), ruleTester.getTypeAnalyzer(), new AbstractMockMetadata() { // from class: io.trino.sql.planner.iterative.rule.TestPushMergeWriterUpdateIntoConnector.1
            @Override // io.trino.metadata.AbstractMockMetadata
            public Optional<TableHandle> applyUpdate(Session session, TableHandle tableHandle, Map<ColumnHandle, Constant> map) {
                return Optional.of(tableHandle);
            }

            @Override // io.trino.metadata.AbstractMockMetadata
            public Map<String, ColumnHandle> getColumnHandles(Session session, TableHandle tableHandle) {
                return Map.of("column_1", new TestingColumnHandle("column_1"), "column_2", new TestingColumnHandle("column_2"));
            }
        });
    }
}
