package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.common.CatalogSchemaName;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.IntegerType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig;
import com.facebook.presto.functionNamespace.execution.NoopSqlFunctionExecutor;
import com.facebook.presto.functionNamespace.execution.SqlFunctionExecutors;
import com.facebook.presto.functionNamespace.testing.InMemoryFunctionNamespaceManager;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.FunctionImplementationType;
import com.facebook.presto.spi.function.FunctionVersion;
import com.facebook.presto.spi.function.Parameter;
import com.facebook.presto.spi.function.RoutineCharacteristics;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
import com.facebook.presto.sql.planner.PlanVariableAllocator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.iterative.rule.InlineSqlFunctions;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.testing.TestingSession;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import org.testng.Assert;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;

/**
 * Tests for {@link InlineSqlFunctions}.
 *
 * <p>Functions whose body is written in SQL (both the ones registered below and built-in SQL functions such as
 * {@code array_sum}) are expected to be replaced by their bodies. The {@code java}-language function registered
 * below is mapped to a THRIFT implementation and must be left as a plain call.
 */
public class TestInlineSqlFunctions {
    /** Session property that toggles SQL function inlining. */
    private static final String INLINE_SQL_FUNCTIONS = "inline_sql_functions";
    /** Catalog and schema in which every test function is registered. */
    private static final CatalogSchemaName UNITTEST_MEMORY = new CatalogSchemaName("unittest", "memory");
    private static final RoutineCharacteristics.Language JAVA = new RoutineCharacteristics.Language("java");

    private static final SqlInvokedFunction SQL_FUNCTION_SQUARE =
            unarySqlFunction("square", "integer", "square", "RETURN x * x");
    private static final SqlInvokedFunction THRIFT_FUNCTION_FOO = new SqlInvokedFunction(
            QualifiedObjectName.valueOf(UNITTEST_MEMORY, "foo"),
            ImmutableList.of(new Parameter("x", TypeSignature.parseTypeSignature("integer"))),
            TypeSignature.parseTypeSignature("integer"),
            "thrift function foo",
            RoutineCharacteristics.builder()
                    .setLanguage(JAVA)
                    .setDeterminism(RoutineCharacteristics.Determinism.DETERMINISTIC)
                    .setNullCallClause(RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT)
                    .build(),
            "",
            FunctionVersion.notVersioned());
    private static final SqlInvokedFunction SQL_FUNCTION_ADD_1_TO_INT_ARRAY =
            unarySqlFunction("add_1_int", "array(int)", "add 1 to all elements of array", "RETURN transform(x, x -> x + 1)");
    private static final SqlInvokedFunction SQL_FUNCTION_ADD_1_TO_BIGINT_ARRAY =
            unarySqlFunction("add_1_bigint", "array(bigint)", "add 1 to all elements of array", "RETURN transform(x, x -> x + 1)");

    private RuleTester tester;

    /**
     * Creates the rule tester and registers an in-memory {@code unittest} function namespace. The namespace
     * accepts {@code sql} and {@code java} functions, and all test functions are created in it.
     */
    @BeforeTest
    public void setup() {
        RuleTester ruleTester = new RuleTester();
        FunctionAndTypeManager functionAndTypeManager = ruleTester.getMetadata().getFunctionAndTypeManager();
        functionAndTypeManager.addFunctionNamespace(
                "unittest",
                new InMemoryFunctionNamespaceManager(
                        "unittest",
                        new SqlFunctionExecutors(
                                // SQL bodies can be inlined; java functions are executed through THRIFT.
                                ImmutableMap.of(
                                        RoutineCharacteristics.Language.SQL, FunctionImplementationType.SQL,
                                        JAVA, FunctionImplementationType.THRIFT),
                                new NoopSqlFunctionExecutor()),
                        new SqlInvokedFunctionNamespaceManagerConfig().setSupportedFunctionLanguages("sql,java")));
        for (SqlInvokedFunction function : ImmutableList.of(
                SQL_FUNCTION_SQUARE,
                THRIFT_FUNCTION_FOO,
                SQL_FUNCTION_ADD_1_TO_INT_ARRAY,
                SQL_FUNCTION_ADD_1_TO_BIGINT_ARRAY)) {
            functionAndTypeManager.createFunction(function, true);
        }
        this.tester = ruleTester;
    }

    @Test
    public void testInlineFunction() {
        assertInlined(tester, "unittest.memory.square(x)", "x * x", ImmutableMap.of("x", IntegerType.INTEGER));
    }

    @Test
    public void testInlineFunctionInsideFunction() {
        assertInlined(tester, "abs(unittest.memory.square(x))", "abs(x * x)", ImmutableMap.of("x", IntegerType.INTEGER));
    }

    @Test
    public void testInlineFunctionContainingLambda() {
        // The inlined lambda parameter is renamed (x -> "x$lambda") so it cannot clash with the outer variable x.
        assertInlined(
                tester,
                "unittest.memory.add_1_int(x)",
                "transform(x, \"x$lambda\" -> \"x$lambda\" + 1)",
                ImmutableMap.of("x", new ArrayType(IntegerType.INTEGER)));
    }

    @Test
    public void testInlineSqlFunctionCoercesConstantWithCast() {
        assertInlined(
                tester,
                "unittest.memory.add_1_bigint(x)",
                "transform(x, \"x$lambda\" -> \"x$lambda\" + CAST(1 AS bigint))",
                ImmutableMap.of("x", new ArrayType(BigintType.BIGINT)));
    }

    @Test
    public void testInlineBuiltinSqlFunction() {
        assertInlined(
                tester,
                "array_sum(x)",
                "reduce(x, BIGINT '0', (\"s$lambda\", \"x$lambda\") -> \"s$lambda\" + COALESCE(\"x$lambda\", BIGINT '0'), \"s$lambda_0\" -> \"s$lambda_0\")",
                ImmutableMap.of("x", new ArrayType(IntegerType.INTEGER)));
    }

    @Test
    public void testNoInlineThriftFunction() {
        assertNotInlined(tester, "unittest.memory.foo(x)", ImmutableMap.of("x", IntegerType.INTEGER));
    }

    @Test
    public void testInlineFunctionIntoPlan() {
        // values() binds pattern alias "x" to output column 0, i.e. the plan variable "a".
        tester.assertThat(new InlineSqlFunctions(tester.getMetadata(), tester.getSqlParser()).projectExpressionRewrite())
                .on(planBuilder -> planBuilder.project(
                        PlanBuilder.assignment(planBuilder.variable("squared"), squareOfA()),
                        planBuilder.values(planBuilder.variable("a", IntegerType.INTEGER))))
                .matches(PlanMatchPattern.project(
                        ImmutableMap.of("squared", PlanMatchPattern.expression("x * x")),
                        PlanMatchPattern.values(ImmutableMap.of("x", 0))));
    }

    @Test
    public void testNoInlineIntoPlanWhenInlineIsDisabled() {
        tester.assertThat(new InlineSqlFunctions(tester.getMetadata(), tester.getSqlParser()).projectExpressionRewrite())
                .setSystemProperty(INLINE_SQL_FUNCTIONS, "false")
                .on(planBuilder -> planBuilder.project(
                        PlanBuilder.assignment(planBuilder.variable("squared"), squareOfA()),
                        planBuilder.values(planBuilder.variable("a", IntegerType.INTEGER))))
                .doesNotFire();
    }

    /**
     * Builds a deterministic, RETURNS NULL ON NULL INPUT SQL function in {@code unittest.memory}. The function
     * takes a single parameter {@code x} and returns the same type as that parameter.
     *
     * @param name unqualified function name
     * @param type type signature of both the parameter {@code x} and the return value
     * @param description function description
     * @param body SQL routine body
     */
    private static SqlInvokedFunction unarySqlFunction(String name, String type, String description, String body) {
        return new SqlInvokedFunction(
                QualifiedObjectName.valueOf(UNITTEST_MEMORY, name),
                ImmutableList.of(new Parameter("x", TypeSignature.parseTypeSignature(type))),
                TypeSignature.parseTypeSignature(type),
                description,
                RoutineCharacteristics.builder()
                        .setDeterminism(RoutineCharacteristics.Determinism.DETERMINISTIC)
                        .setNullCallClause(RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT)
                        .build(),
                body,
                FunctionVersion.notVersioned());
    }

    /** Builds the call {@code unittest.memory.square(a)} that is projected in the plan tests. */
    private static Expression squareOfA() {
        return new FunctionCall(QualifiedName.of("unittest", "memory", "square"), ImmutableList.of(new SymbolReference("a")));
    }

    /**
     * Asserts that rewriting {@code inputSql} with inlining enabled yields exactly {@code expectedSql}.
     *
     * @param ruleTester supplies the metadata and SQL parser
     * @param inputSql expression to rewrite
     * @param expectedSql expected expression after inlining
     * @param variableTypes types of the free variables referenced by {@code inputSql}
     */
    private void assertInlined(RuleTester ruleTester, String inputSql, String expectedSql, Map<String, Type> variableTypes) {
        Session session = TestingSession.testSessionBuilder()
                .setSystemProperty(INLINE_SQL_FUNCTIONS, "true")
                .build();
        Metadata metadata = ruleTester.getMetadata();
        Expression inputExpression = PlanBuilder.expression(inputSql);

        // Seed the allocator with the input variables; typed collection replaces the former raw-type cast.
        Collection<VariableReferenceExpression> variables = variableTypes.entrySet().stream()
                .map(entry -> new VariableReferenceExpression(Optional.empty(), entry.getKey(), entry.getValue()))
                .collect(ImmutableList.toImmutableList());

        Expression rewritten = InlineSqlFunctions.InlineSqlFunctionsRewriter.rewrite(
                inputExpression,
                session,
                metadata,
                new PlanVariableAllocator(variables),
                ExpressionAnalyzer.getExpressionTypes(
                        session,
                        metadata,
                        ruleTester.getSqlParser(),
                        TypeProvider.viewOf(variableTypes),
                        inputExpression,
                        ImmutableMap.of(),
                        WarningCollector.NOOP));

        // Parsed expected SQL uses SymbolReferences, so normalize identifiers in the rewritten tree before comparing.
        Assert.assertEquals(ExpressionUtils.rewriteIdentifiersToSymbolReferences(rewritten), PlanBuilder.expression(expectedSql));
    }

    /** Asserts that {@code sql} is left unchanged by the inliner, even with inlining enabled. */
    private void assertNotInlined(RuleTester ruleTester, String sql, Map<String, Type> variableTypes) {
        assertInlined(ruleTester, sql, sql, variableTypes);
    }
}
