package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Concat;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;

/**
 * Imports an ONNX concatenation node: joins all inputs along the dimension selected by the
 * required {@code axis} attribute. A negative axis counts from the last dimension.
 * All inputs must have the same rank and the same size in every non-concatenated dimension.
 */
public class OnnxConcat extends IntermediateOperation {

    /** Sentinel used for a dimension whose size is not known (an unbound dimension). */
    private static final long UNKNOWN_SIZE = -1L;

    private final IntermediateOperation.AttributeMap attributeMap;

    /** Name of the concatenated dimension; assigned when the type is resolved and updated on renaming. */
    private String concatDimensionName;

    /** The 'axis' attribute; normalized to a non-negative index once the input rank is known. */
    private int concatDimensionIndex;

    /**
     * @param modelName    name of the model this operation belongs to
     * @param name         name of this node
     * @param inputs       the tensors to concatenate, in order
     * @param attributeMap node attributes; must contain 'axis'
     * @throws IllegalArgumentException if the 'axis' attribute is missing
     */
    public OnnxConcat(String modelName, String name, List<IntermediateOperation> inputs,
                      IntermediateOperation.AttributeMap attributeMap) {
        super(modelName, name, inputs);
        this.attributeMap = attributeMap;
        this.concatDimensionIndex = (int) attributeMap.get("axis")
                .orElseThrow(() -> new IllegalArgumentException(
                        "OnnxConcat in " + this.name + ": Required attribute 'axis' is missing."))
                .asDouble();
    }

    /**
     * Computes the output type: the first input's type with the concat dimension's size replaced
     * by the sum over all inputs, or left unbound if any input's size along that axis is unknown.
     *
     * @return the result type, or null if some input type is not yet available
     * @throws IllegalArgumentException if the axis is out of range, ranks differ, or a
     *                                  non-concatenated dimension differs between inputs
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if ( ! allInputTypesPresent()) return null;

        OrderedTensorType first = inputs.get(0).type().get();
        // Normalize once so later calls (constraints, renaming) see a non-negative index.
        concatDimensionIndex = resolveAxis(first.rank());

        long concatDimSize = sizeOf(first, concatDimensionIndex);
        for (int i = 1; i < inputs.size(); i++) {
            OrderedTensorType other = inputs.get(i).type().get();
            if (other.rank() != first.rank())
                throw new IllegalArgumentException("OnnxConcat in " + name + ": Inputs must have the same rank.");

            for (int j = 0; j < first.rank(); j++) {
                long firstSize = sizeOf(first, j);
                long otherSize = sizeOf(other, j);
                if (j == concatDimensionIndex) {
                    // Any unknown size makes the total unknown; summing the -1 sentinel would
                    // yield a bogus bound size.
                    concatDimSize = (concatDimSize == UNKNOWN_SIZE || otherSize == UNKNOWN_SIZE)
                                    ? UNKNOWN_SIZE
                                    : concatDimSize + otherSize;
                } else if (firstSize != otherSize) {
                    throw new IllegalArgumentException("OnnxConcat in " + name + ": input dimension " + j + " differs in input tensors.");
                }
            }
        }

        OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
        List<TensorType.Dimension> dimensions = first.dimensions();
        for (int j = 0; j < dimensions.size(); j++) {
            TensorType.Dimension dimension = dimensions.get(j);
            if (j == concatDimensionIndex) {
                concatDimensionName = dimension.name();
                builder.add(concatDimSize == UNKNOWN_SIZE
                            ? TensorType.Dimension.indexed(concatDimensionName)
                            : TensorType.Dimension.indexed(concatDimensionName, concatDimSize));
            } else {
                builder.add(dimension);
            }
        }
        return builder.build();
    }

    /**
     * Builds a left-folded chain of tensor {@link Concat} functions over all inputs along the
     * concat dimension.
     *
     * @return the function, or null if some input function (or this node's type) is not yet available
     */
    @Override
    protected TensorFunction<Reference> lazyGetFunction() {
        if ( ! inputs.stream().allMatch(input -> input.function().isPresent())) return null;

        // The concat dimension name is assigned while resolving the type; ensure that has happened
        // rather than building a Concat over a null dimension.
        if (concatDimensionName == null) type();
        if (concatDimensionName == null) return null;

        TensorFunction<Reference> result = inputs.get(0).function().get();
        for (int i = 1; i < inputs.size(); i++) {
            result = new Concat<>(result, inputs.get(i).function().get(), concatDimensionName);
        }
        return result;
    }

    /**
     * Requires the concat dimension of every input to be renamed to the same name as the concat
     * dimension of the first input. Does nothing until all input types are available.
     */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        if ( ! allInputTypesPresent()) return;

        OrderedTensorType first = inputs.get(0).type().get();
        // Resolve here too: this may run before lazyGetType has normalized a negative axis.
        int axis = resolveAxis(first.rank());
        String firstDimension = first.dimensions().get(axis).name();
        for (int i = 1; i < inputs.size(); i++) {
            String otherDimension = inputs.get(i).type().get().dimensions().get(axis).name();
            renamer.addConstraint(firstDimension, otherDimension, DimensionRenamer.Constraint.equal(false), this);
        }
    }

    /** Applies the base renaming, then maps the stored concat dimension name (if assigned) to its new name. */
    @Override
    public void renameDimensions(DimensionRenamer renamer) {
        super.renameDimensions(renamer);
        if (concatDimensionName != null)
            concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName);
    }

    /** Returns a copy of this operation over new inputs, re-reading the axis from the original attributes. */
    @Override
    public OnnxConcat withInputs(List<IntermediateOperation> inputs) {
        return new OnnxConcat(modelName(), name(), inputs, attributeMap);
    }

    // NOTE(review): ONNX names this operator "Concat"; "ConcatV2" is the TensorFlow name.
    // Kept unchanged in case callers depend on it — confirm.
    @Override
    public String operationName() {
        return "ConcatV2";
    }

    /** Returns true if every input has a resolved type. */
    private boolean allInputTypesPresent() {
        return inputs.stream().allMatch(input -> input.type().isPresent());
    }

    /**
     * Maps the (possibly negative) axis attribute to a dimension index for inputs of the given rank.
     *
     * @throws IllegalArgumentException if the axis does not address a dimension of that rank
     */
    private int resolveAxis(int rank) {
        int axis = concatDimensionIndex < 0 ? rank + concatDimensionIndex : concatDimensionIndex;
        if (axis < 0 || axis >= rank)
            throw new IllegalArgumentException("OnnxConcat in " + name + ": axis " + concatDimensionIndex +
                                               " is out of range for inputs of rank " + rank + ".");
        return axis;
    }

    /** Returns the size of dimension {@code index} of {@code type}, or UNKNOWN_SIZE if unbound. */
    private static long sizeOf(OrderedTensorType type, int index) {
        return type.dimensions().get(index).size().orElse(UNKNOWN_SIZE);
    }
}
