package io.trino.sql.routine;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.MoreCollectors;
import com.google.common.primitives.Primitives;
import io.airlift.bytecode.Access;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.DynamicClassLoader;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.ParameterizedType;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.DoWhileLoop;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.control.WhileLoop;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.airlift.bytecode.instruction.Constant;
import io.airlift.bytecode.instruction.LabelNode;
import io.trino.metadata.FunctionManager;
import io.trino.operator.scalar.SpecializedSqlScalarFunction;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.ScalarFunctionAdapter;
import io.trino.spi.function.ScalarFunctionImplementation;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.gen.BytecodeUtils;
import io.trino.sql.gen.CachedInstanceBinder;
import io.trino.sql.gen.CallSiteBinder;
import io.trino.sql.gen.LambdaBytecodeGenerator;
import io.trino.sql.gen.LambdaExpressionExtractor;
import io.trino.sql.gen.RowExpressionCompiler;
import io.trino.sql.relational.CallExpression;
import io.trino.sql.relational.ConstantExpression;
import io.trino.sql.relational.InputReferenceExpression;
import io.trino.sql.relational.LambdaDefinitionExpression;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.RowExpressionVisitor;
import io.trino.sql.relational.SpecialForm;
import io.trino.sql.relational.VariableReferenceExpression;
import io.trino.sql.routine.ir.DefaultIrNodeVisitor;
import io.trino.sql.routine.ir.IrBlock;
import io.trino.sql.routine.ir.IrBreak;
import io.trino.sql.routine.ir.IrContinue;
import io.trino.sql.routine.ir.IrIf;
import io.trino.sql.routine.ir.IrLabel;
import io.trino.sql.routine.ir.IrLoop;
import io.trino.sql.routine.ir.IrNode;
import io.trino.sql.routine.ir.IrNodeVisitor;
import io.trino.sql.routine.ir.IrRepeat;
import io.trino.sql.routine.ir.IrReturn;
import io.trino.sql.routine.ir.IrRoutine;
import io.trino.sql.routine.ir.IrSet;
import io.trino.sql.routine.ir.IrStatement;
import io.trino.sql.routine.ir.IrVariable;
import io.trino.sql.routine.ir.IrWhile;
import io.trino.util.CompilerUtils;
import io.trino.util.Reflection;
import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

/* loaded from: input_file:io/trino/sql/routine/SqlRoutineCompiler.class */
/**
 * Compiles an {@link IrRoutine} into a generated JVM class and exposes it as a
 * {@link SpecializedSqlScalarFunction}.
 *
 * <p>The generated class has a public no-arg constructor and a single public {@code run} method.
 * {@code run} takes a {@link ConnectorSession} followed by one boxed argument per routine
 * parameter. It returns a boxed value of the routine's return type. Arguments use the
 * {@code BOXED_NULLABLE} convention and the result uses {@code NULLABLE_RETURN}.
 */
public final class SqlRoutineCompiler {
    private final FunctionManager functionManager;

    /**
     * Translates IR nodes into the bytecode of the generated {@code run} method.
     *
     * <p>All routine values live in boxed (wrapper-typed) locals, so they are nullable. Labeled
     * blocks and loops register a "continue" and a "break" {@link LabelNode}, and
     * {@link IrContinue} / {@link IrBreak} statements emit a jump to them.
     */
    public class BytecodeVisitor implements IrNodeVisitor<Scope, BytecodeNode> {
        private final CachedInstanceBinder cachedInstanceBinder;
        private final Map<LambdaDefinitionExpression, LambdaBytecodeGenerator.CompiledLambda> compiledLambdaMap;
        private final Map<IrVariable, Variable> variables;
        // IR labels are looked up by identity (IdentityHashMap), not by equals()
        private final Map<IrLabel, LabelNode> continueLabels = new IdentityHashMap<>();
        private final Map<IrLabel, LabelNode> breakLabels = new IdentityHashMap<>();

        /**
         * @param cachedInstanceBinder binder for call sites and cached instances of the generated class
         * @param compiledLambdaMap lambdas already generated as methods of the generated class
         * @param variables bytecode local backing each IR variable
         */
        public BytecodeVisitor(
                CachedInstanceBinder cachedInstanceBinder,
                Map<LambdaDefinitionExpression, LambdaBytecodeGenerator.CompiledLambda> compiledLambdaMap,
                Map<IrVariable, Variable> variables) {
            this.cachedInstanceBinder = Objects.requireNonNull(cachedInstanceBinder, "cachedInstanceBinder is null");
            this.compiledLambdaMap = Objects.requireNonNull(compiledLambdaMap, "compiledLambdaMap is null");
            this.variables = Objects.requireNonNull(variables, "variables is null");
        }

        @Override
        public BytecodeNode visitNode(IrNode node, Scope scope) {
            throw new VerifyException("Unsupported node: " + node.getClass().getSimpleName());
        }

        @Override
        public BytecodeNode visitRoutine(IrRoutine node, Scope scope) {
            return process(node.body(), scope);
        }

        @Override
        public BytecodeNode visitSet(IrSet node, Scope scope) {
            return new BytecodeBlock()
                    .append(compile(node.value(), scope))
                    .putVariable(variables.get(node.target()));
        }

        /**
         * Initializes the block's declared variables from their default values, then emits its
         * statements.
         *
         * <p>For a labeled block, the continue label is placed after the variable
         * initialization and before the first statement. The break label is placed after the
         * last statement.
         */
        @Override
        public BytecodeNode visitBlock(IrBlock node, Scope scope) {
            BytecodeBlock block = new BytecodeBlock();

            for (IrVariable variable : node.variables()) {
                block.append(compile(variable.defaultValue(), scope))
                        .putVariable(variables.get(variable));
            }

            LabelNode continueLabel = new LabelNode("continue");
            LabelNode breakLabel = new LabelNode("break");

            // labels must be registered before the statements are processed so nested
            // continue/break statements can resolve them
            Optional<IrLabel> label = node.label();
            if (label.isPresent()) {
                continueLabels.put(label.get(), continueLabel);
                breakLabels.put(label.get(), breakLabel);
                block.visitLabel(continueLabel);
            }

            for (IrStatement statement : node.statements()) {
                block.append(process(statement, scope));
            }

            if (label.isPresent()) {
                block.visitLabel(breakLabel);
            }

            return block;
        }

        @Override
        public BytecodeNode visitReturn(IrReturn node, Scope scope) {
            return new BytecodeBlock()
                    .append(compile(node.value(), scope))
                    .ret(Primitives.wrap(node.value().getType().getJavaType()));
        }

        @Override
        public BytecodeNode visitContinue(IrContinue node, Scope scope) {
            LabelNode label = continueLabels.get(node.target());
            Verify.verify(label != null, "continue target does not exist");
            return new BytecodeBlock().gotoLabel(label);
        }

        @Override
        public BytecodeNode visitBreak(IrBreak node, Scope scope) {
            LabelNode label = breakLabels.get(node.target());
            Verify.verify(label != null, "break target does not exist");
            return new BytecodeBlock().gotoLabel(label);
        }

        /** A NULL condition is treated as false (see {@link #compileBoolean}). */
        @Override
        public BytecodeNode visitIf(IrIf node, Scope scope) {
            IfStatement ifStatement = new IfStatement()
                    .condition(compileBoolean(node.condition(), scope))
                    .ifTrue(process(node.ifTrue(), scope));
            node.ifFalse().ifPresent(ifFalse -> ifStatement.ifFalse(process(ifFalse, scope)));
            return ifStatement;
        }

        @Override
        public BytecodeNode visitWhile(IrWhile node, Scope scope) {
            return compileLoop(scope, node.label(), interruptionCheck -> new WhileLoop()
                    .condition(compileBoolean(node.condition(), scope))
                    .body(new BytecodeBlock()
                            .append(interruptionCheck)
                            .append(process(node.body(), scope))));
        }

        /**
         * Compiles {@code REPEAT ... UNTIL cond} as a do-while loop on {@code NOT cond}. A NULL
         * condition becomes false, so the loop keeps iterating.
         */
        @Override
        public BytecodeNode visitRepeat(IrRepeat node, Scope scope) {
            // NOTE(review): the continue label sits before the do-while, so a continue re-enters
            // the body without evaluating the UNTIL condition — confirm this is intended
            return compileLoop(scope, node.label(), interruptionCheck -> new DoWhileLoop()
                    .condition(not(compileBoolean(node.condition(), scope)))
                    .body(new BytecodeBlock()
                            .append(interruptionCheck)
                            .append(process(node.block(), scope))));
        }

        /** Compiles an unconditional loop; it can only be left by a break or a return. */
        @Override
        public BytecodeNode visitLoop(IrLoop node, Scope scope) {
            return compileLoop(scope, node.label(), interruptionCheck -> new WhileLoop()
                    .condition(Constant.loadBoolean(true))
                    .body(new BytecodeBlock()
                            .append(interruptionCheck)
                            .append(process(node.block(), scope))));
        }

        /**
         * Wraps a loop with label registration and a periodic thread-interruption check.
         *
         * @param scope scope of the generated method (used for a temp counter variable)
         * @param label optional IR label of the loop
         * @param loopFactory builds the loop given a block that must be executed once per
         *        iteration (it increments a counter and, every 1000 iterations, resets it and
         *        throws if the current thread is interrupted)
         */
        private BytecodeNode compileLoop(Scope scope, Optional<IrLabel> label, Function<BytecodeBlock, BytecodeNode> loopFactory) {
            BytecodeBlock block = new BytecodeBlock();

            // initialized before the continue label, so a continue does not reset the counter
            Variable interruptionCounter = scope.createTempVariable(int.class);
            block.putVariable(interruptionCounter, 0);

            BytecodeBlock interruptionCheck = new BytecodeBlock()
                    .append(interruptionCounter.increment())
                    .append(new IfStatement()
                            .condition(BytecodeExpressions.greaterThanOrEqual(interruptionCounter, BytecodeExpressions.constantInt(1000)))
                            .ifTrue(new BytecodeBlock()
                                    .append(interruptionCounter.set(BytecodeExpressions.constantInt(0)))
                                    .append(SqlRoutineCompiler.throwIfInterrupted())));

            LabelNode continueLabel = new LabelNode("continue");
            LabelNode breakLabel = new LabelNode("break");

            // register labels before building the loop body so nested continue/break resolve
            if (label.isPresent()) {
                continueLabels.put(label.get(), continueLabel);
                breakLabels.put(label.get(), breakLabel);
                block.visitLabel(continueLabel);
            }

            block.append(loopFactory.apply(interruptionCheck));

            if (label.isPresent()) {
                block.visitLabel(breakLabel);
            }

            return block;
        }

        /**
         * Compiles a row expression so that its result is left on the stack as a boxed
         * (wrapper-typed) value.
         */
        private BytecodeNode compile(RowExpression expression, Scope scope) {
            // routine parameters/variables are already boxed locals and can be loaded directly
            if (expression instanceof InputReferenceExpression input) {
                return scope.getVariable(SqlRoutineCompiler.name(input.getField()));
            }

            RowExpressionCompiler rowExpressionCompiler = new RowExpressionCompiler(
                    cachedInstanceBinder.getCallSiteBinder(),
                    cachedInstanceBinder,
                    FieldReferenceCompiler.INSTANCE,
                    SqlRoutineCompiler.this.functionManager,
                    compiledLambdaMap);

            // wasNull is reset before evaluation; it starts as true only for void-typed expressions
            return new BytecodeBlock()
                    .comment("boolean wasNull = false;")
                    .putVariable(scope.getVariable("wasNull"), expression.getType().getJavaType() == void.class)
                    .comment("expression: " + expression)
                    .append(rowExpressionCompiler.compile(expression, scope))
                    .append(BytecodeUtils.boxPrimitiveIfNecessary(scope, Primitives.wrap(expression.getType().getJavaType())));
        }

        /**
         * Compiles a BOOLEAN expression to a primitive {@code boolean} on the stack, mapping a
         * {@code null} (boxed) result to {@code false}.
         *
         * @throws IllegalArgumentException if the expression is not of type BOOLEAN
         */
        private BytecodeNode compileBoolean(RowExpression expression, Scope scope) {
            Preconditions.checkArgument(expression.getType().equals(BooleanType.BOOLEAN), "type must be boolean");

            LabelNode notNull = new LabelNode("notNull");
            LabelNode done = new LabelNode("done");

            return new BytecodeBlock()
                    .append(compile(expression, scope))
                    .comment("if value is null, return false, otherwise unbox")
                    .dup()
                    .ifNotNullGoto(notNull)
                    .pop()
                    .push(false)
                    .gotoLabel(done)
                    .visitLabel(notNull)
                    .invokeVirtual(Boolean.class, "booleanValue", boolean.class)
                    .visitLabel(done);
        }

        /** Negates the primitive {@code boolean} left on the stack by {@code node}. */
        private static BytecodeNode not(BytecodeNode node) {
            LabelNode trueLabel = new LabelNode("true");
            LabelNode endLabel = new LabelNode("end");

            return new BytecodeBlock()
                    .append(node)
                    .comment("boolean not")
                    .ifTrueGoto(trueLabel)
                    .push(true)
                    .gotoLabel(endLabel)
                    .visitLabel(trueLabel)
                    .push(false)
                    .visitLabel(endLabel);
        }
    }

    /**
     * Field-reference compiler handed to {@link RowExpressionCompiler}. It supports only input
     * references, which load the corresponding boxed local and unbox it if the SQL type has a
     * primitive Java representation. Every other expression kind is rejected.
     */
    public static class FieldReferenceCompiler implements RowExpressionVisitor<BytecodeNode, Scope> {
        public static final FieldReferenceCompiler INSTANCE = new FieldReferenceCompiler();

        private FieldReferenceCompiler() {
        }

        @Override
        public BytecodeNode visitInputReference(InputReferenceExpression node, Scope scope) {
            return new BytecodeBlock()
                    .append(scope.getVariable(SqlRoutineCompiler.name(node.getField())))
                    .append(BytecodeUtils.unboxPrimitiveIfNecessary(scope, Primitives.wrap(node.getType().getJavaType())));
        }

        @Override
        public BytecodeNode visitCall(CallExpression call, Scope scope) {
            throw new UnsupportedOperationException();
        }

        @Override
        public BytecodeNode visitSpecialForm(SpecialForm specialForm, Scope scope) {
            throw new UnsupportedOperationException();
        }

        @Override
        public BytecodeNode visitConstant(ConstantExpression literal, Scope scope) {
            throw new UnsupportedOperationException();
        }

        @Override
        public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Scope scope) {
            throw new UnsupportedOperationException();
        }

        @Override
        public BytecodeNode visitVariableReference(VariableReferenceExpression reference, Scope scope) {
            throw new UnsupportedOperationException();
        }
    }

    /** Collects every {@link IrVariable} node in an IR tree, in visit order (may contain duplicates). */
    public static class VariableExtractor extends DefaultIrNodeVisitor {
        private final List<IrVariable> variables = new ArrayList<>();

        private VariableExtractor() {
        }

        @Override
        public Void visitVariable(IrVariable node, Void context) {
            variables.add(node);
            return null;
        }

        /** Returns the variables found under {@code node}. */
        public static List<IrVariable> extract(IrNode node) {
            VariableExtractor extractor = new VariableExtractor();
            extractor.process(node, null);
            return extractor.variables;
        }
    }

    public SqlRoutineCompiler(FunctionManager functionManager) {
        this.functionManager = Objects.requireNonNull(functionManager, "functionManager is null");
    }

    /**
     * Compiles the routine and returns a function that adapts the generated {@code run} method
     * to whatever invocation convention the caller requests.
     */
    public SpecializedSqlScalarFunction compile(IrRoutine routine) {
        Type returnType = routine.returnType();
        List<Type> parameterTypes = routine.parameters().stream()
                .map(IrVariable::type)
                .collect(ImmutableList.toImmutableList());

        // convention of the generated method: boxed nullable arguments and a nullable return;
        // the two trailing flags are passed through unchanged to InvocationConvention
        InvocationConvention generatedConvention = new InvocationConvention(
                Collections.nCopies(parameterTypes.size(), InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE),
                InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN,
                true,
                true);

        Class<?> routineClass = compileClass(routine);
        MethodHandle runMethod = Arrays.stream(routineClass.getMethods())
                .filter(method -> method.getName().equals("run"))
                .map(Reflection::methodHandle)
                .collect(MoreCollectors.onlyElement());
        MethodHandle constructor = Reflection.constructorMethodHandle(routineClass);

        // erase the generated class from the handle types: the instance is passed as Object
        MethodHandle instanceMethod = runMethod.asType(runMethod.type().changeParameterType(0, Object.class));
        MethodHandle instanceFactory = constructor.asType(constructor.type().changeReturnType(Object.class));

        return requestedConvention -> ScalarFunctionImplementation.builder()
                .methodHandle(ScalarFunctionAdapter.adapt(instanceMethod, returnType, parameterTypes, generatedConvention, requestedConvention))
                .instanceFactory(instanceFactory)
                .build();
    }

    /** Generates and loads the class implementing the routine (lambda methods, {@code run}, constructor). */
    @VisibleForTesting
    public Class<?> compileClass(IrRoutine routine) {
        ClassDefinition classDefinition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.FINAL),
                CompilerUtils.makeClassName("SqlRoutine"),
                ParameterizedType.type(Object.class));

        CallSiteBinder callSiteBinder = new CallSiteBinder();
        CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder);

        // lambdas must be generated first: the run method references them
        Map<LambdaDefinitionExpression, LambdaBytecodeGenerator.CompiledLambda> compiledLambdaMap =
                generateMethodsForLambda(classDefinition, cachedInstanceBinder, routine);
        generateRunMethod(classDefinition, cachedInstanceBinder, compiledLambdaMap, routine);

        // the constructor is declared last so it initializes every cached instance registered above
        declareConstructor(classDefinition, cachedInstanceBinder);

        return CompilerUtils.defineClass(classDefinition, Object.class, callSiteBinder.getBindings(), new DynamicClassLoader(getClass().getClassLoader()));
    }

    /**
     * Generates one method ({@code lambda_0}, {@code lambda_1}, ...) per lambda found in the IR.
     * Each lambda is compiled with the lambdas generated before it available.
     */
    private Map<LambdaDefinitionExpression, LambdaBytecodeGenerator.CompiledLambda> generateMethodsForLambda(
            ClassDefinition classDefinition,
            CachedInstanceBinder cachedInstanceBinder,
            IrNode node) {
        ImmutableMap.Builder<LambdaDefinitionExpression, LambdaBytecodeGenerator.CompiledLambda> compiledLambdaMap = ImmutableMap.builder();
        int counter = 0;
        for (LambdaDefinitionExpression lambda : extractLambda(node)) {
            LambdaBytecodeGenerator.CompiledLambda compiledLambda = LambdaBytecodeGenerator.preGenerateLambdaExpression(
                    lambda,
                    "lambda_" + counter,
                    classDefinition,
                    compiledLambdaMap.buildOrThrow(),
                    cachedInstanceBinder.getCallSiteBinder(),
                    cachedInstanceBinder,
                    functionManager);
            compiledLambdaMap.put(lambda, compiledLambda);
            counter++;
        }
        return compiledLambdaMap.buildOrThrow();
    }

    /** Declares {@code run(ConnectorSession, v<field>...)} and compiles the routine body into it. */
    private void generateRunMethod(
            ClassDefinition classDefinition,
            CachedInstanceBinder cachedInstanceBinder,
            Map<LambdaDefinitionExpression, LambdaBytecodeGenerator.CompiledLambda> compiledLambdaMap,
            IrRoutine routine) {
        ImmutableList.Builder<Parameter> parameters = ImmutableList.builder();
        parameters.add(Parameter.arg("session", ConnectorSession.class));
        for (IrVariable parameter : routine.parameters()) {
            parameters.add(Parameter.arg(name(parameter), compilerType(parameter.type())));
        }

        MethodDefinition method = classDefinition.declareMethod(Access.a(Access.PUBLIC), "run", compilerType(routine.returnType()), parameters.build());

        Scope scope = method.getScope();
        scope.declareVariable(boolean.class, "wasNull");

        // parameters reuse their argument slot; other variables get a fresh local
        Map<IrVariable, Variable> variables = VariableExtractor.extract(routine).stream()
                .distinct()
                .collect(ImmutableMap.toImmutableMap(Function.identity(), variable -> getOrDeclareVariable(scope, variable)));

        BytecodeVisitor visitor = new BytecodeVisitor(cachedInstanceBinder, compiledLambdaMap, variables);
        method.getBody().append(visitor.process(routine, scope));
    }

    /** Bytecode that throws {@code RuntimeException("Thread interrupted")} if the current thread is interrupted. */
    private static BytecodeNode throwIfInterrupted() {
        return new IfStatement()
                .condition(BytecodeExpressions.invokeStatic(Thread.class, "currentThread", Thread.class)
                        .invoke("isInterrupted", boolean.class))
                .ifTrue(new BytecodeBlock()
                        .append(BytecodeExpressions.newInstance(RuntimeException.class, BytecodeExpressions.constantString("Thread interrupted")))
                        .throwObject());
    }

    /** Declares a public no-arg constructor that calls {@code Object()} and initializes cached instances. */
    private static void declareConstructor(ClassDefinition classDefinition, CachedInstanceBinder cachedInstanceBinder) {
        MethodDefinition constructor = classDefinition.declareConstructor(Access.a(Access.PUBLIC));
        BytecodeBlock body = constructor.getBody();
        body.append(constructor.getThis())
                .invokeConstructor(Object.class);
        cachedInstanceBinder.generateInitializations(constructor.getThis(), body);
        body.ret();
    }

    /** Returns the local named after {@code variable}, declaring it with the boxed type if absent. */
    public static Variable getOrDeclareVariable(Scope scope, IrVariable variable) {
        return getOrDeclareVariable(scope, compilerType(variable.type()), name(variable));
    }

    private static Variable getOrDeclareVariable(Scope scope, ParameterizedType type, String name) {
        try {
            return scope.getVariable(name);
        }
        catch (IllegalArgumentException ignored) {
            // relies on Scope#getVariable throwing IllegalArgumentException for an undeclared name
            return scope.declareVariable(type, name);
        }
    }

    /** All routine values are represented by the wrapper of the SQL type's Java type. */
    private static ParameterizedType compilerType(Type type) {
        return ParameterizedType.type(Primitives.wrap(type.getJavaType()));
    }

    private static String name(IrVariable variable) {
        return name(variable.field());
    }

    /** Local variable name for the IR field index, e.g. {@code v0}. */
    private static String name(int field) {
        return "v" + field;
    }

    /** Collects the distinct lambda expressions contained in any row expression of the IR tree. */
    private static Set<LambdaDefinitionExpression> extractLambda(IrNode node) {
        ImmutableSet.Builder<LambdaDefinitionExpression> lambdas = ImmutableSet.builder();
        node.accept(new DefaultIrNodeVisitor() {
            @Override
            public void visitRowExpression(RowExpression expression) {
                lambdas.addAll(LambdaExpressionExtractor.extractLambdaExpressions(expression));
            }
        }, null);
        return lambdas.build();
    }
}
