package io.trino.operator.aggregation;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.trino.metadata.FunctionBinding;
import io.trino.metadata.SignatureBinder;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.ParametricFunctionHelpers;
import io.trino.operator.ParametricImplementationsGroup;
import io.trino.operator.aggregation.AggregationFromAnnotationsParser;
import io.trino.operator.aggregation.AggregationFunctionAdapter;
import io.trino.operator.annotations.ImplementationDependency;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.function.AggregationFunctionMetadata;
import io.trino.spi.function.AggregationImplementation;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionDependencies;
import io.trino.spi.function.FunctionDependencyDeclaration;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.FunctionNullability;
import io.trino.spi.function.Signature;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.StringJoiner;

/* loaded from: input_file:io/trino/operator/aggregation/ParametricAggregation.class */
/**
 * A {@link SqlAggregationFunction} backed by a group of {@link ParametricAggregationImplementation}s.
 *
 * <p>When the function is specialized for a {@link BoundSignature}, one implementation is chosen from the group
 * (an exact signature match first, otherwise the single generic implementation whose types are assignable) and its
 * input, optional remove-input, optional combine, and output functions are bound to their dependencies.
 */
public class ParametricAggregation extends SqlAggregationFunction {
    private final ParametricImplementationsGroup<ParametricAggregationImplementation> implementations;
    private final List<AggregationFromAnnotationsParser.AccumulatorStateDetails<?>> stateDetails;

    /**
     * Creates the aggregation and derives its function metadata from the header and the implementation group.
     *
     * @param signature the declared signature of the aggregation
     * @param aggregationHeader name, aliases, description and flags of the aggregation
     * @param list details of each accumulator state; an immutable copy is stored
     * @param parametricImplementationsGroup the candidate implementations
     * @throws NullPointerException if {@code list} or {@code parametricImplementationsGroup} is null
     * @throws IllegalArgumentException if the implementation group does not declare a nullable return
     */
    public ParametricAggregation(Signature signature, AggregationHeader aggregationHeader, List<AggregationFromAnnotationsParser.AccumulatorStateDetails<?>> list, ParametricImplementationsGroup<ParametricAggregationImplementation> parametricImplementationsGroup) {
        // Null checks are performed inside the super(...) arguments: both values are dereferenced while the
        // metadata is being built, so checking afterwards could never produce the descriptive message.
        super(
                createFunctionMetadata(signature, aggregationHeader, Objects.requireNonNull(parametricImplementationsGroup, "implementations is null").getFunctionNullability()),
                createAggregationFunctionMetadata(aggregationHeader, Objects.requireNonNull(list, "stateDetails is null")));
        Preconditions.checkArgument(parametricImplementationsGroup.getFunctionNullability().isReturnNullable(), "currently aggregates are required to be nullable");
        this.stateDetails = ImmutableList.copyOf(list);
        this.implementations = parametricImplementationsGroup;
    }

    /**
     * Builds the {@link FunctionMetadata} for an aggregation from its header and nullability.
     *
     * @param signature the declared signature
     * @param aggregationHeader supplies name, aliases, description, hidden and deprecated flags
     * @param functionNullability supplies return and argument nullability
     * @return the assembled function metadata
     */
    private static FunctionMetadata createFunctionMetadata(Signature signature, AggregationHeader aggregationHeader, FunctionNullability functionNullability) {
        FunctionMetadata.Builder metadata = FunctionMetadata.aggregateBuilder(aggregationHeader.getName()).signature(signature);
        aggregationHeader.getAliases().forEach(metadata::alias);
        // The builder is told explicitly when there is no description.
        aggregationHeader.getDescription().ifPresentOrElse(metadata::description, metadata::noDescription);
        if (aggregationHeader.isHidden()) {
            metadata.hidden();
        }
        if (aggregationHeader.isDeprecated()) {
            metadata.deprecated();
        }
        if (functionNullability.isReturnNullable()) {
            metadata.nullable();
        }
        metadata.argumentNullability(functionNullability.getArgumentNullable());
        return metadata.build();
    }

    /**
     * Builds the {@link AggregationFunctionMetadata} from the header flags and accumulator states.
     *
     * @param aggregationHeader supplies the order-sensitive and decomposable flags
     * @param stateDetails accumulator states; their serialized types become intermediate types when decomposable
     * @return the assembled aggregation metadata
     */
    private static AggregationFunctionMetadata createAggregationFunctionMetadata(AggregationHeader aggregationHeader, List<AggregationFromAnnotationsParser.AccumulatorStateDetails<?>> stateDetails) {
        AggregationFunctionMetadata.AggregationFunctionMetadataBuilder builder = AggregationFunctionMetadata.builder();
        if (aggregationHeader.isOrderSensitive()) {
            builder.orderSensitive();
        }
        // Intermediate types are only registered for decomposable aggregations.
        if (aggregationHeader.isDecomposable()) {
            for (AggregationFromAnnotationsParser.AccumulatorStateDetails<?> details : stateDetails) {
                builder.intermediateType(details.getSerializedType());
            }
        }
        return builder.build();
    }

    /**
     * Collects the dependencies declared by every implementation in the group (exact, specialized and generic)
     * and by every accumulator state.
     *
     * @return the combined dependency declaration
     */
    @Override
    public FunctionDependencyDeclaration getFunctionDependencies() {
        FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder builder = FunctionDependencyDeclaration.builder();
        declareDependencies(builder, this.implementations.getExactImplementations().values());
        declareDependencies(builder, this.implementations.getSpecializedImplementations());
        declareDependencies(builder, this.implementations.getGenericImplementations());
        for (AggregationFromAnnotationsParser.AccumulatorStateDetails<?> details : this.stateDetails) {
            for (ImplementationDependency dependency : details.getDependencies()) {
                dependency.declareDependencies(builder);
            }
        }
        return builder.build();
    }

    /**
     * Declares the input, combine and output dependencies of each given implementation on the builder.
     *
     * @param functionDependencyDeclarationBuilder the builder receiving the declarations
     * @param collection the implementations whose dependencies are declared
     */
    private static void declareDependencies(FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder functionDependencyDeclarationBuilder, Collection<ParametricAggregationImplementation> collection) {
        for (ParametricAggregationImplementation implementation : collection) {
            for (ImplementationDependency dependency : implementation.getInputDependencies()) {
                dependency.declareDependencies(functionDependencyDeclarationBuilder);
            }
            for (ImplementationDependency dependency : implementation.getCombineDependencies()) {
                dependency.declareDependencies(functionDependencyDeclarationBuilder);
            }
            for (ImplementationDependency dependency : implementation.getOutputDependencies()) {
                dependency.declareDependencies(functionDependencyDeclarationBuilder);
            }
        }
    }

    /**
     * Produces the concrete {@link AggregationImplementation} for a bound signature.
     *
     * @param boundSignature the signature the aggregation is being specialized for
     * @param functionDependencies resolved dependencies used when binding the implementation's method handles
     * @return the implementation with state descriptors and input/remove-input/combine/output functions set
     * @throws TrinoException if no implementation matches, or more than one generic implementation matches
     * @throws IllegalArgumentException if the aggregation is decomposable but the implementation has no combine
     *         method, or is not decomposable but the implementation has one
     */
    @Override
    public AggregationImplementation specialize(BoundSignature boundSignature, FunctionDependencies functionDependencies) {
        ParametricAggregationImplementation implementation = findMatchingImplementation(boundSignature);
        List<AggregationFunctionAdapter.AggregationParameterKind> inputParameterKinds = implementation.getInputParameterKinds();
        AggregationImplementation.Builder builder = AggregationImplementation.builder();
        FunctionMetadata functionMetadata = getFunctionMetadata();
        FunctionBinding functionBinding = SignatureBinder.bindFunction(functionMetadata.getFunctionId(), functionMetadata.getSignature(), boundSignature);

        builder.accumulatorStateDescriptors(this.stateDetails.stream()
                .map(details -> details.createAccumulatorStateDescriptor(functionBinding, functionDependencies))
                .collect(ImmutableList.toImmutableList()));

        builder.inputFunction(AggregationFunctionAdapter.normalizeInputMethod(
                ParametricFunctionHelpers.bindDependencies(implementation.getInputFunction(), implementation.getInputDependencies(), functionBinding, functionDependencies),
                boundSignature,
                inputParameterKinds));

        // The remove-input function is optional; when present it is bound and normalized like the input function.
        implementation.getRemoveInputFunction()
                .map(handle -> ParametricFunctionHelpers.bindDependencies(handle, implementation.getRemoveInputDependencies(), functionBinding, functionDependencies))
                .map(handle -> AggregationFunctionAdapter.normalizeInputMethod(handle, boundSignature, inputParameterKinds))
                .ifPresent(builder::removeInputFunction);

        if (getAggregationMetadata().isDecomposable()) {
            builder.combineFunction(ParametricFunctionHelpers.bindDependencies(
                    implementation.getCombineFunction().orElseThrow(() -> new IllegalArgumentException(String.format("Decomposable method %s does not have a combine method", boundSignature.getName()))),
                    implementation.getCombineDependencies(),
                    functionBinding,
                    functionDependencies));
        } else {
            // This check fails when a combine method IS present, so the message must say so.
            Preconditions.checkArgument(implementation.getCombineFunction().isEmpty(), "Non-decomposable method %s must not have a combine method", boundSignature.getName());
        }

        builder.outputFunction(ParametricFunctionHelpers.bindDependencies(implementation.getOutputFunction(), implementation.getOutputDependencies(), functionBinding, functionDependencies));
        return builder.build();
    }

    /**
     * Returns the accumulator state details held by this aggregation.
     *
     * @return the immutable list of state details
     */
    @VisibleForTesting
    public List<AggregationFromAnnotationsParser.AccumulatorStateDetails<?>> getStateDetails() {
        return this.stateDetails;
    }

    /**
     * Returns the group of candidate implementations.
     *
     * @return the implementation group passed at construction
     */
    @VisibleForTesting
    public ParametricImplementationsGroup<ParametricAggregationImplementation> getImplementations() {
        return this.implementations;
    }

    /**
     * Selects the implementation to use for a bound signature.
     *
     * <p>An exact-signature entry wins outright. Otherwise the generic implementations are scanned and exactly
     * one of them must report its types as assignable.
     *
     * @param boundSignature the signature being specialized
     * @return the single matching implementation
     * @throws TrinoException with {@code AMBIGUOUS_FUNCTION_CALL} if several generic implementations match, or
     *         {@code FUNCTION_IMPLEMENTATION_MISSING} if none matches
     */
    private ParametricAggregationImplementation findMatchingImplementation(BoundSignature boundSignature) {
        Signature signature = boundSignature.toSignature();
        if (this.implementations.getExactImplementations().containsKey(signature)) {
            // Optional.of preserves the original failure mode should the map ever hold a null value.
            return Optional.of(this.implementations.getExactImplementations().get(signature)).get();
        }

        ParametricAggregationImplementation match = null;
        for (ParametricAggregationImplementation candidate : this.implementations.getGenericImplementations()) {
            if (!candidate.areTypesAssignable(boundSignature)) {
                continue;
            }
            if (match != null) {
                throw new TrinoException(StandardErrorCode.AMBIGUOUS_FUNCTION_CALL, String.format("Ambiguous function call (%s) for %s", boundSignature, getFunctionMetadata().getSignature()));
            }
            match = candidate;
        }
        if (match == null) {
            throw new TrinoException(StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING, String.format("Unsupported type parameters (%s) for %s", boundSignature, getFunctionMetadata().getSignature()));
        }
        return match;
    }

    /**
     * Returns {@code ParametricAggregation[signature=...]} using the implementation group's signature.
     */
    @Override
    public String toString() {
        return new StringJoiner(", ", ParametricAggregation.class.getSimpleName() + "[", "]").add("signature=" + this.implementations.getSignature()).toString();
    }
}
