package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.functions.Concat;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;

/* loaded from: input_file:ai/vespa/rankingexpression/importer/operations/ConcatReduce.class */
/**
 * An intermediate operation that combines all of its inputs by concatenating
 * them along a temporary dimension and then reducing over that dimension with
 * a configurable aggregator.
 *
 * <p>The resulting type is reported as the type of the first input.</p>
 */
public class ConcatReduce extends IntermediateOperation {

    /** Name of the scratch dimension the inputs are concatenated along and then reduced away. */
    private static final String TMP_DIMENSION_NAME = "__concat_reduce_tmp_dimension_name__";

    /** Aggregator applied over the temporary dimension. */
    private final Reduce.Aggregator aggregator;

    /**
     * Creates a concat-reduce operation.
     *
     * @param modelName  name of the model this operation belongs to
     * @param nodeName   name of this operation
     * @param inputs     the operations whose outputs are combined
     * @param aggregator the aggregator used when reducing over the temporary dimension
     */
    public ConcatReduce(String modelName,
                        String nodeName,
                        List<IntermediateOperation> inputs,
                        Reduce.Aggregator aggregator) {
        super(modelName, nodeName, inputs);
        this.aggregator = aggregator;
    }

    /**
     * Returns the type of the first input, or {@code null} while not all
     * input types are available.
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!allInputTypesPresent(inputs.size())) {
            return null;
        }
        return inputs.get(0).type().get();
    }

    /**
     * Builds the tensor function for this operation: a left-to-right chain of
     * concatenations along the temporary dimension, wrapped in a reduce over
     * that same dimension.
     *
     * @return the function, or {@code null} while not all input functions are available
     */
    @Override
    protected TensorFunction<Reference> lazyGetFunction() {
        if (!allInputFunctionsPresent(inputs.size())) {
            return null;
        }
        TensorFunction<Reference> result = inputs.get(0).function().get();
        for (int i = 1; i < inputs.size(); i++) {
            TensorFunction<Reference> next = inputs.get(i).function().get();
            result = new Concat<>(result, next, TMP_DIMENSION_NAME);
        }
        return new Reduce<>(result, aggregator, TMP_DIMENSION_NAME);
    }

    /**
     * Registers equality constraints between the dimension names of each pair
     * of consecutive inputs. The dimensions of the lower-rank input are matched
     * against the trailing dimensions of the higher-rank input.
     *
     * <p>Does nothing while not all input types are available.</p>
     *
     * @param dimensionRenamer the renamer collecting the constraints
     */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer dimensionRenamer) {
        if (!allInputTypesPresent(inputs.size())) {
            return;
        }
        OrderedTensorType previous = inputs.get(0).type().get();
        for (int i = 1; i < inputs.size(); i++) {
            OrderedTensorType current = inputs.get(i).type().get();
            OrderedTensorType largest = largestInput(previous, current);
            OrderedTensorType smallest = smallestInput(previous, current);

            // Offset that lines the smaller type's dimensions up with the tail of the larger one.
            int rankDifference = largest.rank() - smallest.rank();
            for (int j = 0; j < smallest.rank(); j++) {
                String largestDimension = largest.dimensions().get(j + rankDifference).name();
                String smallestDimension = smallest.dimensions().get(j).name();
                // NOTE(review): the 'false' argument presumably selects a hard (non-soft)
                // constraint - confirm against DimensionRenamer.Constraint.equal.
                dimensionRenamer.addConstraint(largestDimension, smallestDimension,
                                               DimensionRenamer.Constraint.equal(false), this);
            }
            // Constraints are chained pairwise: the next input is compared against this one.
            previous = current;
        }
    }

    /** Returns the type with the higher rank; the first argument wins ties. */
    private static OrderedTensorType largestInput(OrderedTensorType a, OrderedTensorType b) {
        return a.rank() >= b.rank() ? a : b;
    }

    /** Returns the type with the lower rank; the second argument wins ties. */
    private static OrderedTensorType smallestInput(OrderedTensorType a, OrderedTensorType b) {
        return a.rank() < b.rank() ? a : b;
    }

    /**
     * Returns a copy of this operation, with the same model name, name and
     * aggregator, operating on the given inputs.
     *
     * <p>The covariant return type is sufficient on its own: the compiler
     * generates the {@code IntermediateOperation}-returning bridge method, so
     * it must not be declared by hand (doing so is a same-erasure name clash).</p>
     */
    @Override
    public ConcatReduce withInputs(List<IntermediateOperation> inputs) {
        return new ConcatReduce(modelName(), name(), inputs, aggregator);
    }

    /** Returns the fixed operation name {@code "ConcatReduce"}. */
    @Override
    public String operationName() {
        return "ConcatReduce";
    }
}
