package io.trino.sql.routine;

import com.google.common.base.Throwables;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.Session;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.FunctionManager;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.ScalarFunctionImplementation;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.PlannerContext;
import io.trino.sql.parser.SqlParser;
import io.trino.sql.planner.TestingPlannerContext;
import io.trino.sql.tree.FunctionSpecification;
import io.trino.testing.TestingSession;
import io.trino.testing.TransactionBuilder;
import io.trino.transaction.InMemoryTransactionManager;
import io.trino.transaction.TransactionManager;
import io.trino.type.UnknownType;
import java.lang.invoke.MethodHandle;
import java.util.concurrent.atomic.AtomicLong;
import org.assertj.core.api.Assertions;
import org.assertj.core.api.ThrowingConsumer;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/sql/routine/TestSqlFunctions.class */
/**
 * Tests for SQL routines. Each test parses a {@code FUNCTION} specification, analyzes, plans and
 * compiles it with {@link SqlRoutineCompiler}, and then invokes the resulting method handle directly
 * inside a single-statement transaction.
 */
class TestSqlFunctions {
    private static final SqlParser SQL_PARSER = new SqlParser();
    private static final TransactionManager TRANSACTION_MANAGER = InMemoryTransactionManager.createTestTransactionManager();
    private static final PlannerContext PLANNER_CONTEXT = TestingPlannerContext.plannerContextBuilder().withTransactionManager(TRANSACTION_MANAGER).build();
    private static final Session SESSION = TestingSession.testSessionBuilder().build();

    // Supplies unique names for the functions generated by testSingleExpression.
    private final AtomicLong nextId = new AtomicLong();

    @Test
    void testConstantReturn() {
        @Language("SQL") String sql = """
                FUNCTION answer()
                RETURNS BIGINT
                RETURN 42
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke()).isEqualTo(42L);
        });
    }

    @Test
    void testSimpleReturn() {
        @Language("SQL") String sql = """
                FUNCTION hello(s VARCHAR)
                RETURNS VARCHAR
                RETURN 'Hello, ' || s || '!'
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke(Slices.utf8Slice("world"))).isEqualTo(Slices.utf8Slice("Hello, world!"));
            Assertions.assertThat(handle.invoke(Slices.utf8Slice("WORLD"))).isEqualTo(Slices.utf8Slice("Hello, WORLD!"));
        });
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("foo"), VarcharType.VARCHAR, "Hello, foo!", "'Hello, ' || p || '!'");
    }

    @Test
    void testSimpleExpression() {
        @Language("SQL") String sql = """
                FUNCTION test(a bigint)
                RETURNS bigint
                BEGIN
                  DECLARE x bigint DEFAULT CAST(99 AS bigint);
                  RETURN x * a;
                END
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke(0L)).isEqualTo(0L);
            Assertions.assertThat(handle.invoke(1L)).isEqualTo(99L);
            Assertions.assertThat(handle.invoke(42L)).isEqualTo(4158L);
            Assertions.assertThat(handle.invoke(123L)).isEqualTo(12177L);
        });
    }

    @Test
    void testSimpleCase() {
        // The WHEN operands deliberately mix BIGINT, DECIMAL and DOUBLE literals against a BIGINT operand.
        @Language("SQL") String sql = """
                FUNCTION simple_case(a bigint)
                RETURNS varchar
                BEGIN
                  CASE a
                    WHEN 0 THEN RETURN 'zero';
                    WHEN 1 THEN RETURN 'one';
                    WHEN DECIMAL '10.0' THEN RETURN 'ten';
                    WHEN 20.0E0 THEN RETURN 'twenty';
                    ELSE RETURN 'other';
                  END CASE;
                  RETURN NULL;
                END
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke(0L)).isEqualTo(Slices.utf8Slice("zero"));
            Assertions.assertThat(handle.invoke(1L)).isEqualTo(Slices.utf8Slice("one"));
            Assertions.assertThat(handle.invoke(10L)).isEqualTo(Slices.utf8Slice("ten"));
            Assertions.assertThat(handle.invoke(20L)).isEqualTo(Slices.utf8Slice("twenty"));
            Assertions.assertThat(handle.invoke(42L)).isEqualTo(Slices.utf8Slice("other"));
        });
    }

    @Test
    void testSearchCase() {
        @Language("SQL") String sql = """
                FUNCTION search_case(a bigint, b bigint)
                RETURNS varchar
                BEGIN
                  CASE
                    WHEN a = 0 THEN RETURN 'zero';
                    WHEN b = 1 THEN RETURN 'one';
                    WHEN a = DECIMAL '10.0' THEN RETURN 'ten';
                    WHEN b = 20.0E0 THEN RETURN 'twenty';
                    ELSE RETURN 'other';
                  END CASE;
                  RETURN NULL;
                END
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke(0L, 42L)).isEqualTo(Slices.utf8Slice("zero"));
            Assertions.assertThat(handle.invoke(42L, 1L)).isEqualTo(Slices.utf8Slice("one"));
            Assertions.assertThat(handle.invoke(10L, 42L)).isEqualTo(Slices.utf8Slice("ten"));
            Assertions.assertThat(handle.invoke(42L, 20L)).isEqualTo(Slices.utf8Slice("twenty"));
            Assertions.assertThat(handle.invoke(42L, 42L)).isEqualTo(Slices.utf8Slice("other"));

            // When several WHEN clauses match, the first one in source order must win.
            Assertions.assertThat(handle.invoke(0L, 1L)).isEqualTo(Slices.utf8Slice("zero"));
            Assertions.assertThat(handle.invoke(10L, 1L)).isEqualTo(Slices.utf8Slice("one"));
            Assertions.assertThat(handle.invoke(10L, 20L)).isEqualTo(Slices.utf8Slice("ten"));
            Assertions.assertThat(handle.invoke(0L, 20L)).isEqualTo(Slices.utf8Slice("zero"));
        });
    }

    @Test
    void testFibonacciWhileLoop() {
        @Language("SQL") String sql = """
                FUNCTION fib(n bigint)
                RETURNS bigint
                BEGIN
                  DECLARE a, b bigint DEFAULT 1;
                  DECLARE c bigint;
                  IF n <= 2 THEN
                    RETURN 1;
                  END IF;
                  WHILE n > 2 DO
                    SET n = n - 1;
                    SET c = a + b;
                    SET a = b;
                    SET b = c;
                  END WHILE;
                  RETURN c;
                END
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke(1L)).isEqualTo(1L);
            Assertions.assertThat(handle.invoke(2L)).isEqualTo(1L);
            Assertions.assertThat(handle.invoke(3L)).isEqualTo(2L);
            Assertions.assertThat(handle.invoke(4L)).isEqualTo(3L);
            Assertions.assertThat(handle.invoke(5L)).isEqualTo(5L);
            Assertions.assertThat(handle.invoke(6L)).isEqualTo(8L);
            Assertions.assertThat(handle.invoke(7L)).isEqualTo(13L);
            Assertions.assertThat(handle.invoke(8L)).isEqualTo(21L);
        });
    }

    @Test
    void testBreakContinue() {
        @Language("SQL") String sql = """
                FUNCTION test()
                RETURNS bigint
                BEGIN
                  DECLARE a, b int DEFAULT 0;
                  top: WHILE a < 10 DO
                    SET a = a + 1;
                    IF a < 3 THEN
                      ITERATE top;
                    END IF;
                    SET b = b + 1;
                    IF a > 6 THEN
                      LEAVE top;
                    END IF;
                  END WHILE;
                  RETURN b;
                END
                """;
        // b is incremented only for a = 3..7: ITERATE skips a = 1, 2 and LEAVE exits after a = 7.
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke()).isEqualTo(5L);
        });
    }

    @Test
    void testRepeat() {
        @Language("SQL") String sql = """
                FUNCTION test_repeat(a bigint)
                RETURNS bigint
                BEGIN
                  REPEAT
                    SET a = a + 1;
                  UNTIL a >= 10 END REPEAT;
                  RETURN a;
                END
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke(0L)).isEqualTo(10L);
            // REPEAT always runs its body at least once, even when the condition already holds.
            Assertions.assertThat(handle.invoke(100L)).isEqualTo(101L);
        });
    }

    @Test
    void testRepeatContinue() {
        @Language("SQL") String sql = """
                FUNCTION test_repeat_continue()
                RETURNS bigint
                BEGIN
                  DECLARE a int DEFAULT 0;
                  DECLARE b int DEFAULT 0;
                  top: REPEAT
                    SET a = a + 1;
                    IF a <= 3 THEN
                      ITERATE top;
                    END IF;
                    SET b = b + 1;
                  UNTIL a >= 10 END REPEAT;
                  RETURN b;
                END
                """;
        // b is incremented only for a = 4..10.
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke()).isEqualTo(7L);
        });
    }

    @Test
    void testReuseLabels() {
        @Language("SQL") String sql = """
                FUNCTION test()
                RETURNS int
                BEGIN
                  DECLARE r int DEFAULT 0;
                  abc: LOOP
                    SET r = r + 1;
                    LEAVE abc;
                  END LOOP;
                  abc: LOOP
                    SET r = r + 1;
                    LEAVE abc;
                  END LOOP;
                  RETURN r;
                END
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke()).isEqualTo(2L);
        });
    }

    @Test
    void testReuseVariables() {
        @Language("SQL") String sql = """
                FUNCTION test()
                RETURNS bigint
                BEGIN
                  DECLARE r bigint DEFAULT 0;
                  BEGIN
                    DECLARE x varchar DEFAULT 'hello';
                    SET r = r + length(x);
                  END;
                  BEGIN
                    DECLARE x array(int) DEFAULT array[1, 2, 3];
                    SET r = r + cardinality(x);
                  END;
                  RETURN r;
                END
                """;
        // length('hello') + cardinality(array[1, 2, 3]) = 5 + 3
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke()).isEqualTo(8L);
        });
    }

    @Test
    void testAssignParameter() {
        @Language("SQL") String sql = """
                FUNCTION test(x int)
                RETURNS int
                BEGIN
                  SET x = x * 3;
                  RETURN x;
                END
                """;
        assertFunction(sql, handle -> {
            Assertions.assertThat(handle.invoke(2L)).isEqualTo(6L);
        });
    }

    @Test
    void testCall() {
        testSingleExpression(BigintType.BIGINT, -123L, BigintType.BIGINT, 123L, "abs(p)");
    }

    @Test
    void testCallNested() {
        testSingleExpression(BigintType.BIGINT, -123L, BigintType.BIGINT, 123L, "abs(ceiling(p))");
        testSingleExpression(BigintType.BIGINT, 42L, DoubleType.DOUBLE, 42.0, "to_unixTime(from_unixtime(p))");
    }

    @Test
    void testArray() {
        testSingleExpression(BigintType.BIGINT, 3L, BigintType.BIGINT, 5L, "array[3,4,5,6,7][p]");
        testSingleExpression(BigintType.BIGINT, 0L, BigintType.BIGINT, 0L, "array_sort(array[3,2,4,5,1,p])[1]");
    }

    @Test
    void testRow() {
        testSingleExpression(BigintType.BIGINT, 8L, BigintType.BIGINT, 8L, "ROW(1, 'a', p)[3]");
    }

    @Test
    void testLambda() {
        testSingleExpression(BigintType.BIGINT, 3L, BigintType.BIGINT, 9L, "(transform(ARRAY [5, 6], x -> x + p)[2])", false);
    }

    @Test
    void testTry() {
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("42"), BigintType.BIGINT, 42L, "try(cast(p AS bigint))");
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("abc"), BigintType.BIGINT, null, "try(cast(p AS bigint))");
    }

    @Test
    void testTryCast() {
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("42"), BigintType.BIGINT, 42L, "try_cast(p AS bigint)");
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("abc"), BigintType.BIGINT, null, "try_cast(p AS bigint)");
    }

    @Test
    void testNonCanonical() {
        testSingleExpression(BigintType.BIGINT, 100000L, BigintType.BIGINT, 1970L, "EXTRACT(YEAR FROM from_unixtime(p))");
    }

    @Test
    void testAtTimeZone() {
        testSingleExpression(UnknownType.UNKNOWN, null, VarcharType.VARCHAR, "2012-10-30 18:00:00 America/Los_Angeles", "CAST(TIMESTAMP '2012-10-31 01:00 UTC' AT TIME ZONE 'America/Los_Angeles' AS VARCHAR)");
    }

    @Test
    void testSession() {
        // Session-dependent functions must observe the session the routine is executed with.
        testSingleExpression(UnknownType.UNKNOWN, null, DoubleType.DOUBLE, Math.floor(SESSION.getStart().toEpochMilli() / 1000.0), "floor(to_unixtime(localtimestamp))");
        testSingleExpression(UnknownType.UNKNOWN, null, VarcharType.VARCHAR, SESSION.getUser(), "current_user");
    }

    @Test
    void testSpecialType() {
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("abc"), BooleanType.BOOLEAN, true, "(p LIKE '%bc')");
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("xb"), BooleanType.BOOLEAN, false, "(p LIKE '%bc')");
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("abc"), BooleanType.BOOLEAN, false, "regexp_like(p, '\\d')");
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("123"), BooleanType.BOOLEAN, true, "regexp_like(p, '\\d')");
        testSingleExpression(VarcharType.VARCHAR, Slices.utf8Slice("[4,5,6]"), VarcharType.VARCHAR, "6", "json_extract_scalar(p, '$[2]')");
    }

    /**
     * Same as {@link #testSingleExpression(Type, Object, Type, Object, String, boolean)} with a
     * {@code DETERMINISTIC} function.
     */
    private void testSingleExpression(Type inputType, Object input, Type outputType, Object expected, String expression) {
        testSingleExpression(inputType, input, outputType, expected, expression, true);
    }

    /**
     * Wraps {@code expression} in a one-parameter function {@code RETURN}ing it, invokes the compiled
     * function with {@code input} and checks the result.
     *
     * @param inputType SQL type of the single parameter, which the expression refers to as {@code p}
     * @param input value passed for {@code p}
     * @param outputType declared return type of the generated function
     * @param expected expected result; VARCHAR results are compared as Java {@code String}s
     * @param expression SQL expression used as the function body
     * @param deterministic whether the function is declared {@code DETERMINISTIC}
     */
    private void testSingleExpression(Type inputType, Object input, Type outputType, Object expected, String expression, boolean deterministic) {
        String sql = """
                FUNCTION %s(p %s)
                RETURNS %s
                %s
                RETURN %s""".formatted(
                "test" + nextId.incrementAndGet(),
                inputType.getTypeSignature(),
                outputType.getTypeSignature(),
                deterministic ? "DETERMINISTIC" : "NOT DETERMINISTIC",
                expression);

        assertFunction(sql, handle -> {
            Object actual = handle.invoke(input);
            // VARCHAR values come back as Slice; convert so expectations can be plain strings.
            if (outputType instanceof VarcharType && actual instanceof Slice slice) {
                actual = slice.toStringUtf8();
            }
            Assertions.assertThat(actual).isEqualTo(expected);
        });
    }

    /**
     * Compiles {@code sql} inside a single-statement transaction and hands the resulting method
     * handle to {@code assertions}.
     *
     * @param sql a {@code FUNCTION} specification
     * @param assertions receives the compiled handle, which takes only the SQL arguments
     */
    private static void assertFunction(@Language("SQL") String sql, ThrowingConsumer<MethodHandle> assertions) {
        TransactionBuilder.transaction(TRANSACTION_MANAGER, PLANNER_CONTEXT.getMetadata(), new AllowAllAccessControl())
                .singleStatement()
                .execute(SESSION, session -> {
                    ScalarFunctionImplementation implementation = compileFunction(sql, session);
                    // Pre-bind the leading instance and ConnectorSession parameters so tests pass only SQL arguments.
                    MethodHandle handle = implementation.getMethodHandle()
                            .bindTo(getInstance(implementation))
                            .bindTo(session.toConnectorSession());
                    assertions.accept(handle);
                });
    }

    /**
     * Creates the per-function instance object by invoking the implementation's instance factory.
     *
     * @throws IllegalStateException if the implementation exposes no instance factory
     */
    private static Object getInstance(ScalarFunctionImplementation implementation) {
        MethodHandle factory = implementation.getInstanceFactory()
                .orElseThrow(() -> new IllegalStateException("Compiled SQL function does not provide an instance factory"));
        try {
            return factory.invoke();
        } catch (Throwable t) {
            Throwables.throwIfUnchecked(t);
            throw new RuntimeException(t);
        }
    }

    /**
     * Parses, analyzes, plans and compiles a SQL routine, returning the scalar implementation for
     * the invocation convention derived from the routine's declared nullability.
     */
    private static ScalarFunctionImplementation compileFunction(@Language("SQL") String sql, Session session) {
        FunctionSpecification function = SQL_PARSER.createFunctionSpecification(sql);
        FunctionMetadata metadata = SqlRoutineAnalyzer.extractFunctionMetadata(new FunctionId("test"), function);

        var analysis = new SqlRoutineAnalyzer(PLANNER_CONTEXT, WarningCollector.NOOP)
                .analyze(session, new AllowAllAccessControl(), function);
        var routine = new SqlRoutinePlanner(PLANNER_CONTEXT, WarningCollector.NOOP)
                .planSqlFunction(session, function, analysis);
        var compiled = new SqlRoutineCompiler(FunctionManager.createTestingFunctionManager())
                .compile(routine);

        return compiled.getScalarFunctionImplementation(invocationConvention(metadata));
    }

    /**
     * Builds the calling convention: nullable arguments are boxed, the others are passed as
     * never-null, and the return is nullable only if the function may return null. The two
     * trailing flags presumably request the session and instance-factory parameters that
     * {@code assertFunction} binds — TODO confirm against {@link InvocationConvention}.
     */
    private static InvocationConvention invocationConvention(FunctionMetadata metadata) {
        var nullability = metadata.getFunctionNullability();
        return new InvocationConvention(
                nullability.getArgumentNullable().stream()
                        .map(nullable -> nullable
                                ? InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE
                                : InvocationConvention.InvocationArgumentConvention.NEVER_NULL)
                        .toList(),
                nullability.isReturnNullable()
                        ? InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN
                        : InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL,
                true,
                true);
    }
}
