package io.trino.operator.scalar;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.metadata.InternalFunctionBundle;
import io.trino.metadata.SqlFunction;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.TimeZoneKey;
import io.trino.spi.type.VarcharType;
import io.trino.sql.query.QueryAssertions;
import io.trino.testing.assertions.TrinoExceptionAssert;
import io.trino.util.StructuralTestUtil;
import java.util.Objects;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

/**
 * Tests evaluation of SQL lambda expressions through {@link QueryAssertions}.
 *
 * <p>Lambdas are exercised via the {@code apply} and {@code invoke} helper functions,
 * which are registered with the shared {@link QueryAssertions} instance in {@link #init()}.
 * The instance is created once for the whole class (per-class lifecycle) and released
 * in {@link #teardown()}.
 */
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class TestLambdaExpression {
    private QueryAssertions assertions;

    /** Creates the shared assertions helper and registers the {@code apply} and {@code invoke} functions. */
    @BeforeAll
    public void init() {
        assertions = new QueryAssertions();
        assertions.addFunctions(new InternalFunctionBundle(ApplyFunction.APPLY_FUNCTION, InvokeFunction.INVOKE_FUNCTION));
    }

    /** Closes the shared assertions helper and drops the reference to it. */
    @AfterAll
    public void teardown() {
        // Guard against init() having failed before the field was assigned, so that
        // cleanup does not pile a NullPointerException on top of the original failure.
        if (assertions != null) {
            assertions.close();
            assertions = null;
        }
    }

    /** A simple lambda applied to a literal and to a non-literal argument expression. */
    @Test
    public void testBasic() {
        Assertions.assertThat(assertions.expression("apply(a, x -> x + 1)")
                        .binding("a", "5"))
                .isEqualTo(6);

        Assertions.assertThat(assertions.expression("apply(a, x -> x + 1)")
                        .binding("a", "5 + RANDOM(1)"))
                .isEqualTo(6);
    }

    /**
     * A lambda parameter may be a quoted identifier containing a dot, spaces, a semicolon,
     * a single quote, a newline, a backslash sequence and a double quote.
     */
    @Test
    public void testParameterName() {
        String parameter = quote("a.b c; d ' \n \\n \"");

        Assertions.assertThat(assertions.expression("apply(a, " + parameter + " -> " + parameter + " * 2)")
                        .binding("a", "5"))
                .isEqualTo(10);
    }

    /** NULL arguments (untyped and typed) flow through the lambda body. */
    @Test
    public void testNull() {
        Assertions.assertThat(assertions.expression("apply(a, x -> x + 1)")
                        .binding("a", "3"))
                .isEqualTo(4);

        Assertions.assertThat(assertions.expression("apply(a, x -> x + 1)")
                        .binding("a", "NULL"))
                .isNull(IntegerType.INTEGER);

        Assertions.assertThat(assertions.expression("apply(a, x -> x + 1)")
                        .binding("a", "CAST (NULL AS INTEGER)"))
                .isNull(IntegerType.INTEGER);

        Assertions.assertThat(assertions.expression("apply(a, x -> x IS NULL)")
                        .binding("a", "3"))
                .isEqualTo(false);

        Assertions.assertThat(assertions.expression("apply(a, x -> x IS NULL)")
                        .binding("a", "NULL"))
                .isEqualTo(true);

        Assertions.assertThat(assertions.expression("apply(a, x -> x IS NULL)")
                        .binding("a", "CAST (NULL AS INTEGER)"))
                .isEqualTo(true);
    }

    /** The lambda body is allowed to ignore its parameter. */
    @Test
    public void testUnreferencedLambdaArgument() {
        Assertions.assertThat(assertions.expression("apply(a, x -> 6)")
                        .binding("a", "5"))
                .isEqualTo(6);
    }

    /** A zero-parameter lambda can be passed to {@code invoke}. */
    @Test
    public void testLambdaWithoutArgument() {
        Assertions.assertThat(assertions.expression("invoke(() -> 42)"))
                .isEqualTo(42);
    }

    /** The lambda body calls {@code current_timezone()} and is evaluated under a session with an explicit time zone. */
    @Test
    public void testSessionDependent() {
        Assertions.assertThat(assertions.expression(
                                "apply(a, x -> x || current_timezone())",
                                assertions.sessionBuilder()
                                        .setTimeZoneKey(TimeZoneKey.getTimeZoneKey("Pacific/Kiritimati"))
                                        .build())
                        .binding("a", "'timezone: '"))
                .hasType(VarcharType.VARCHAR)
                .isEqualTo("timezone: Pacific/Kiritimati");
    }

    /** The lambda body calls {@code concat} on arrays. */
    @Test
    public void testInstanceFunction() {
        Assertions.assertThat(assertions.expression("apply(a, x -> concat(ARRAY [1], x))")
                        .binding("a", "ARRAY[2]"))
                .hasType(new ArrayType(IntegerType.INTEGER))
                .isEqualTo(ImmutableList.of(1, 2));
    }

    /** Lambdas nested three levels deep, first with distinct parameter names and then reusing {@code x} at every level. */
    @Test
    public void testNestedLambda() {
        Assertions.assertThat(assertions.expression("apply(a, x -> apply(x + 7, y -> apply(y * 3, z -> z * 5) + 1) * 2)")
                        .binding("a", "11"))
                .isEqualTo(542);

        Assertions.assertThat(assertions.expression("apply(a, x -> apply(x + 7, x -> apply(x * 3, x -> x * 5) + 1) * 2)")
                        .binding("a", "11"))
                .isEqualTo(542);
    }

    /** Row fields of the lambda parameter are accessed by subscript. */
    @Test
    public void testRowAccess() {
        Assertions.assertThat(assertions.expression("apply(CAST(a AS ROW(x INTEGER, y VARCHAR)), r -> r[1])")
                        .binding("a", "ROW(1, 'a')"))
                .isEqualTo(1);

        Assertions.assertThat(assertions.expression("apply(a, r -> r[2])")
                        .binding("a", "CAST(ROW(1, 'a') AS ROW(x INTEGER, y VARCHAR))"))
                .isEqualTo("a");
    }

    /** Lambdas wrapped in {@code "$internal$bind"} with one or two bound values, passed to {@code apply} and {@code invoke}. */
    @Test
    public void testBind() {
        Assertions.assertThat(assertions.expression("apply(a, \"$internal$bind\"(b, (x, y) -> x + y))")
                        .binding("a", "90")
                        .binding("b", "9"))
                .isEqualTo(99);

        Assertions.assertThat(assertions.expression("invoke(\"$internal$bind\"(a, x -> x + 1))")
                        .binding("a", "8"))
                .isEqualTo(9);

        Assertions.assertThat(assertions.expression("apply(a, \"$internal$bind\"(b, c, (x, y, z) -> x + y + z))")
                        .binding("a", "900")
                        .binding("b", "90")
                        .binding("c", "9"))
                .isEqualTo(999);

        Assertions.assertThat(assertions.expression("invoke(\"$internal$bind\"(a, b, (x, y) -> x + y))")
                        .binding("a", "90")
                        .binding("b", "9"))
                .isEqualTo(99);
    }

    /** Mixing integer and double operands inside the lambda yields a double result. */
    @Test
    public void testCoercion() {
        Assertions.assertThat(assertions.expression("apply(a, x -> x + 9.0E0)")
                        .binding("a", "90"))
                .isEqualTo(99.0);

        Assertions.assertThat(assertions.expression("apply(a, \"$internal$bind\"(b, (x, y) -> x + y))")
                        .binding("a", "90")
                        .binding("b", "9.0E0"))
                .isEqualTo(99.0);

        Assertions.assertThat(assertions.expression("invoke(\"$internal$bind\"(a, x -> x + 1.0E0))")
                        .binding("a", "8"))
                .isEqualTo(9.0);
    }

    /** Covers combinations of lambda argument and result types, grouped below by the kind of value bound to {@code a}. */
    @Test
    public void testTypeCombinations() {
        // a = 25
        Assertions.assertThat(assertions.expression("apply(a, x -> x + 1)")
                        .binding("a", "25"))
                .isEqualTo(26);

        Assertions.assertThat(assertions.expression("apply(a, x -> x + 1.0E0)")
                        .binding("a", "25"))
                .isEqualTo(26.0);

        Assertions.assertThat(assertions.expression("apply(a, x -> x = 25)")
                        .binding("a", "25"))
                .isEqualTo(true);

        Assertions.assertThat(assertions.expression("apply(a, x -> to_base(x, 16))")
                        .binding("a", "25"))
                .hasType(VarcharType.createVarcharType(64))
                .isEqualTo("19");

        Assertions.assertThat(assertions.expression("apply(a, x -> ARRAY[x + 1])")
                        .binding("a", "25"))
                .hasType(new ArrayType(IntegerType.INTEGER))
                .isEqualTo(ImmutableList.of(26));

        // a = 25.6E0
        Assertions.assertThat(assertions.expression("apply(a, x -> CAST(x AS BIGINT))")
                        .binding("a", "25.6E0"))
                .isEqualTo(26L);

        Assertions.assertThat(assertions.expression("apply(a, x -> x + 1.0E0)")
                        .binding("a", "25.6E0"))
                .isEqualTo(26.6);

        Assertions.assertThat(assertions.expression("apply(a, x -> x = 25.6E0)")
                        .binding("a", "25.6E0"))
                .isEqualTo(true);

        Assertions.assertThat(assertions.expression("apply(a, x -> CAST(x AS VARCHAR))")
                        .binding("a", "25.6E0"))
                .hasType(VarcharType.createUnboundedVarcharType())
                .isEqualTo("2.56E1");

        Assertions.assertThat(assertions.expression("apply(a, x -> MAP(ARRAY[x + 1], ARRAY[true]))")
                        .binding("a", "25.6E0"))
                .hasType(StructuralTestUtil.mapType(DoubleType.DOUBLE, BooleanType.BOOLEAN))
                .isEqualTo(ImmutableMap.of(26.6, true));

        // a = true / false
        Assertions.assertThat(assertions.expression("apply(a, x -> if(x, 25, 26))")
                        .binding("a", "true"))
                .isEqualTo(25);

        Assertions.assertThat(assertions.expression("apply(a, x -> if(x, 25.6E0, 28.9E0))")
                        .binding("a", "false"))
                .isEqualTo(28.9);

        Assertions.assertThat(assertions.expression("apply(a, x -> not x)")
                        .binding("a", "true"))
                .isEqualTo(false);

        Assertions.assertThat(assertions.expression("apply(a, x -> CAST(x AS VARCHAR))")
                        .binding("a", "false"))
                .hasType(VarcharType.createUnboundedVarcharType())
                .isEqualTo("false");

        Assertions.assertThat(assertions.expression("apply(a, x -> ARRAY[x])")
                        .binding("a", "true"))
                .hasType(new ArrayType(BooleanType.BOOLEAN))
                .isEqualTo(ImmutableList.of(true));

        // a = string literal
        Assertions.assertThat(assertions.expression("apply(a, x -> from_base(x, 16))")
                        .binding("a", "'41'"))
                .isEqualTo(65L);

        Assertions.assertThat(assertions.expression("apply(a, x -> CAST(x AS DOUBLE))")
                        .binding("a", "'25.6E0'"))
                .isEqualTo(25.6);

        Assertions.assertThat(assertions.expression("apply(a, x -> 'abc' = x)")
                        .binding("a", "'abc'"))
                .isEqualTo(true);

        Assertions.assertThat(assertions.expression("apply(a, x -> x || x)")
                        .binding("a", "'abc'"))
                .hasType(VarcharType.createUnboundedVarcharType())
                .isEqualTo("abcabc");

        Assertions.assertThat(assertions.expression("apply(a, x -> ROW(x, CAST(x AS INTEGER), x > '0'))")
                        .binding("a", "'123'"))
                .hasType(RowType.anonymous(ImmutableList.of(VarcharType.createVarcharType(3), IntegerType.INTEGER, BooleanType.BOOLEAN)))
                .isEqualTo(ImmutableList.of("123", 123, true));

        // a = array of strings with a NULL element
        Assertions.assertThat(assertions.expression("apply(a, x -> from_base(x[3], 10))")
                        .binding("a", "ARRAY['abc', NULL, '123']"))
                .isEqualTo(123L);

        Assertions.assertThat(assertions.expression("apply(a, x -> CAST(x[3] AS DOUBLE))")
                        .binding("a", "ARRAY['abc', NULL, '123']"))
                .isEqualTo(123.0);

        Assertions.assertThat(assertions.expression("apply(a, x -> x[2] IS NULL)")
                        .binding("a", "ARRAY['abc', NULL, '123']"))
                .isEqualTo(true);

        Assertions.assertThat(assertions.expression("apply(a, x -> x[2])")
                        .binding("a", "ARRAY['abc', NULL, '123']"))
                .isNull(VarcharType.createVarcharType(3));

        // a = map
        Assertions.assertThat(assertions.expression("apply(a, x -> map_keys(x))")
                        .binding("a", "MAP(ARRAY['abc', 'def'], ARRAY[123, 456])"))
                .hasType(new ArrayType(VarcharType.createVarcharType(3)))
                .isEqualTo(ImmutableList.of("abc", "def"));
    }

    /** Passing a lambda to a function with no lambda-accepting signature fails with {@code FUNCTION_NOT_FOUND}. */
    @Test
    public void testFunctionParameter() {
        assertFunctionNotFound(
                "count(x -> x)",
                "line 1:12: Unexpected parameters (<function>) for function count. Expected: count(), count(t) T");
        assertFunctionNotFound(
                "max(x -> x)",
                "line 1:12: Unexpected parameters (<function>) for function max. Expected: max(e, bigint) E:orderable, max(t) T:orderable");
        assertFunctionNotFound(
                "sqrt(x -> x)",
                "line 1:12: Unexpected parameters (<function>) for function sqrt. Expected: sqrt(double)");
        assertFunctionNotFound(
                "sqrt(x -> x, 123, x -> x)",
                "line 1:12: Unexpected parameters (<function>, integer, <function>) for function sqrt. Expected: sqrt(double)");
        assertFunctionNotFound(
                "pow(x -> x, 123)",
                "line 1:12: Unexpected parameters (<function>, integer) for function pow. Expected: pow(double, double)");
        assertFunctionNotFound(
                "pow(123, x -> x)",
                "line 1:12: Unexpected parameters (integer, <function>) for function pow. Expected: pow(double, double)");
    }

    /**
     * Asserts that evaluating {@code sql} throws a Trino exception carrying the
     * {@code FUNCTION_NOT_FOUND} error code and exactly {@code expectedMessage}.
     */
    private void assertFunctionNotFound(String sql, String expectedMessage) {
        TrinoExceptionAssert.assertTrinoExceptionThrownBy(assertions.expression(sql)::evaluate)
                .hasErrorCode(StandardErrorCode.FUNCTION_NOT_FOUND)
                .hasMessage(expectedMessage);
    }

    /** Wraps {@code str} in double quotes, doubling any embedded double quote, to form a delimited SQL identifier. */
    private static String quote(String str) {
        return "\"" + str.replace("\"", "\"\"") + "\"";
    }
}
