package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.FunctionType;
import com.facebook.presto.common.type.IntegerType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.OriginalExpressionUtils;
import com.google.common.collect.ImmutableList;
import java.util.Optional;
import org.testng.Assert;
import org.testng.annotations.Test;

/* loaded from: input_file:com/facebook/presto/sql/planner/iterative/rule/TestTranslateExpressions.class */
/**
 * Tests for the aggregation rewrite rule exposed by {@link TranslateExpressions}.
 *
 * <p>Each test feeds the rule a {@code reduce_agg} aggregation whose arguments and filter are
 * built from parsed SQL text via {@code OriginalExpressionUtils.castToRowExpression}, and then
 * checks that the aggregation coming out of the rule equals one assembled directly from
 * {@code RowExpression} building blocks, with nothing left that
 * {@code OriginalExpressionUtils.isExpression} still recognises.
 */
public class TestTranslateExpressions extends BaseRuleTest {
    private static final Metadata METADATA = MetadataManager.createTestMetadataManager();
    private static final FunctionAndTypeManager FUNCTION_MANAGER = METADATA.getFunctionAndTypeManager();
    private static final FunctionResolution FUNCTION_RESOLUTION = new FunctionResolution(FUNCTION_MANAGER);

    /** Handle of {@code reduce_agg(integer, integer, (integer, integer) -> integer, (integer, integer) -> integer)}. */
    private static final FunctionHandle REDUCE_AGG = FUNCTION_MANAGER.lookupFunction(
            "reduce_agg",
            TypeSignatureProvider.fromTypes(new Type[]{
                    IntegerType.INTEGER,
                    IntegerType.INTEGER,
                    integerBinaryFunctionType(),
                    integerBinaryFunctionType()}));

    public TestTranslateExpressions() {
        super(new Plugin[0]);
    }

    /** Four-argument form: input, initial state, input lambda and combine lambda. */
    @Test
    public void testTranslateAggregationWithLambda() {
        AggregationNode.Aggregation actual = rewriteReduceAgg("input", "0", "(x,y) -> x*y", "(a,b) -> a*b");

        AggregationNode.Aggregation expected = reduceAgg(
                ImmutableList.of(
                        Expressions.variable("input", IntegerType.INTEGER),
                        Expressions.constant(0L, IntegerType.INTEGER),
                        integerProductLambda("x", "y"),
                        integerProductLambda("a", "b")),
                inputGreaterThanTen());

        Assert.assertEquals(actual, expected);
        Assert.assertFalse(isUntranslated(actual));
    }

    /** Three-argument form: same as above but without the initial-state argument. */
    @Test
    public void testTranslateIntermediateAggregationWithLambda() {
        AggregationNode.Aggregation actual = rewriteReduceAgg("input", "(x,y) -> x*y", "(a,b) -> a*b");

        AggregationNode.Aggregation expected = reduceAgg(
                ImmutableList.of(
                        Expressions.variable("input", IntegerType.INTEGER),
                        integerProductLambda("x", "y"),
                        integerProductLambda("a", "b")),
                inputGreaterThanTen());

        Assert.assertEquals(actual, expected);
        Assert.assertFalse(isUntranslated(actual));
    }

    /**
     * Runs the aggregation rewrite rule over a global {@code reduce_agg} aggregation on top of a
     * single-column {@code input} values node, and returns the aggregation stored under the
     * {@code reduce_agg} output variable.
     *
     * @param argumentSql SQL text of each call argument, in call order; the filter is always
     *                    {@code input > 10}
     */
    private AggregationNode.Aggregation rewriteReduceAgg(String... argumentSql) {
        return (AggregationNode.Aggregation) tester().assertThat(new TranslateExpressions(METADATA, new SqlParser()).aggregationRowExpressionRewriteRule()).on(planBuilder -> {
            return planBuilder.aggregation(aggregationBuilder -> {
                // The SQL fragments are parsed here, inside the builder callback, arguments first and filter last.
                aggregationBuilder.globalGrouping()
                        .addAggregation(
                                Expressions.variable("reduce_agg", IntegerType.INTEGER),
                                reduceAgg(parseAll(argumentSql), parse("input > 10")))
                        .source(planBuilder.values(planBuilder.variable("input", IntegerType.INTEGER)));
            });
        }).get().getAggregations().get(Expressions.variable("reduce_agg", IntegerType.INTEGER));
    }

    /** Builds a non-distinct {@code reduce_agg} aggregation with the given arguments and filter, and no ordering or mask. */
    private static AggregationNode.Aggregation reduceAgg(ImmutableList<RowExpression> arguments, RowExpression filter) {
        CallExpression call = new CallExpression("reduce_agg", REDUCE_AGG, IntegerType.INTEGER, arguments);
        return new AggregationNode.Aggregation(call, Optional.of(filter), Optional.empty(), false, Optional.empty());
    }

    /** Parses one SQL fragment and wraps it so it can sit where a {@code RowExpression} is expected. */
    private static RowExpression parse(String sql) {
        return OriginalExpressionUtils.castToRowExpression(PlanBuilder.expression(sql));
    }

    /** Parses every fragment, left to right, into a list of wrapped expressions. */
    private static ImmutableList<RowExpression> parseAll(String... sqlFragments) {
        ImmutableList.Builder<RowExpression> wrapped = ImmutableList.builder();
        for (String sql : sqlFragments) {
            wrapped.add(parse(sql));
        }
        return wrapped.build();
    }

    /** The {@code (integer, integer) -> integer} function type used for both lambda arguments. */
    private static FunctionType integerBinaryFunctionType() {
        return new FunctionType(ImmutableList.of(IntegerType.INTEGER, IntegerType.INTEGER), IntegerType.INTEGER);
    }

    /** Expected translation of {@code (left, right) -> left * right} over two integer parameters. */
    private static LambdaDefinitionExpression integerProductLambda(String left, String right) {
        return new LambdaDefinitionExpression(
                Optional.empty(),
                ImmutableList.of(IntegerType.INTEGER, IntegerType.INTEGER),
                ImmutableList.of(left, right),
                multiply(Expressions.variable(left, IntegerType.INTEGER), Expressions.variable(right, IntegerType.INTEGER)));
    }

    /** Expected translation of the {@code input > 10} filter. */
    private static CallExpression inputGreaterThanTen() {
        return greaterThan(Expressions.variable("input", IntegerType.INTEGER), Expressions.constant(10L, IntegerType.INTEGER));
    }

    /** Builds a boolean {@code GREATER_THAN} call resolved for the operand types. */
    private static CallExpression greaterThan(RowExpression rowExpression, RowExpression rowExpression2) {
        FunctionHandle comparison = FUNCTION_RESOLUTION.comparisonFunction(OperatorType.GREATER_THAN, rowExpression.getType(), rowExpression2.getType());
        return Expressions.call("GREATER_THAN", comparison, BooleanType.BOOLEAN, ImmutableList.of(rowExpression, rowExpression2));
    }

    /** Builds a {@code MULTIPLY} call resolved for the operand types; the result takes the left operand's type. */
    private static CallExpression multiply(RowExpression rowExpression, RowExpression rowExpression2) {
        FunctionHandle product = FUNCTION_RESOLUTION.arithmeticFunction(OperatorType.MULTIPLY, rowExpression.getType(), rowExpression2.getType());
        return Expressions.call("MULTIPLY", product, rowExpression.getType(), ImmutableList.of(rowExpression, rowExpression2));
    }

    /**
     * Reports whether any call argument, or the filter when present, is still recognised by
     * {@code OriginalExpressionUtils.isExpression}.
     */
    private static boolean isUntranslated(AggregationNode.Aggregation aggregation) {
        if (aggregation.getCall().getArguments().stream().anyMatch(OriginalExpressionUtils::isExpression)) {
            return true;
        }
        // The filter is only inspected when no argument matched, and an absent filter counts as translated.
        return aggregation.getFilter().filter(OriginalExpressionUtils::isExpression).isPresent();
    }
}
