package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;

/**
 * Intermediate operation producing a tensor whose shape comes from a constant {@code 'shape'}
 * input and whose cells all hold the same constant value.
 *
 * <p>The fill value and the tensor value type are read from the optional {@code "value"}
 * attribute. When that attribute is absent, the result is a {@code DOUBLE}-valued tensor
 * filled with {@code 0.0}. NOTE(review): presumably this mirrors the ONNX
 * {@code ConstantOfShape} operator of the same name; confirm against the importer.
 */
public class ConstantOfShape extends IntermediateOperation {

    private final IntermediateOperation.AttributeMap attributeMap;
    private TensorType.Value valueTypeOfTensor = TensorType.Value.DOUBLE;
    private double valueToFillWith = 0.0;

    /**
     * Creates the operation and reads the fill value from the {@code "value"} attribute, if present.
     *
     * @param modelName    name of the model this operation belongs to
     * @param nodeName     name of this node in the model
     * @param inputs       operation inputs; the first one is the {@code 'shape'} input
     * @param attributeMap node attributes; an optional {@code "value"} tensor supplies the fill value
     * @throws IllegalArgumentException if a {@code "value"} attribute is present but has no cells
     */
    public ConstantOfShape(String modelName, String nodeName, List<IntermediateOperation> inputs,
                           IntermediateOperation.AttributeMap attributeMap) {
        super(modelName, nodeName, inputs);
        this.attributeMap = attributeMap;
        Optional<Value> value = attributeMap.get("value");
        if (value.isPresent()) {
            Tensor valueTensor = value.get().asTensor();
            Iterator<Double> cells = valueTensor.valueIterator();
            // Calling next() on an empty tensor would fail with an uninformative
            // NoSuchElementException; report the malformed attribute instead.
            if (!cells.hasNext()) {
                throw new IllegalArgumentException("ConstantOfShape: 'value' attribute must be a one-element tensor.");
            }
            this.valueTypeOfTensor = valueTensor.type().valueType();
            this.valueToFillWith = cells.next();
        }
    }

    /**
     * Derives the output type from the constant {@code 'shape'} input. Each shape entry becomes
     * an indexed dimension named {@code <vespaName>_<i>}, in order.
     *
     * @return the output type, or {@code null} if the input type is not yet available
     * @throws IllegalArgumentException if the shape input is not constant or has more than one dimension
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!allInputTypesPresent(1)) {
            return null;
        }
        // Read the constant value once rather than calling getConstantValue() twice.
        Tensor shape = this.inputs.get(0).getConstantValue()
                .orElseThrow(() -> new IllegalArgumentException("ConstantOfShape: 'shape' input must be a constant."))
                .asTensor();
        if (shape.type().dimensions().size() > 1) {
            throw new IllegalArgumentException("ConstantOfShape: 'shape' input must be a tensor with 0 or 1 dimensions.");
        }
        OrderedTensorType.Builder builder = new OrderedTensorType.Builder(this.valueTypeOfTensor);
        Iterator<Double> sizes = shape.valueIterator();
        for (int i = 0; sizes.hasNext(); i++) {
            builder.add(TensorType.Dimension.indexed(vespaName() + "_" + i, sizes.next().longValue()));
        }
        return builder.build();
    }

    /**
     * Builds a {@code Generate} function over this operation's type whose generator is the
     * scalar fill value.
     *
     * @return the tensor function, or {@code null} if the input type is not yet available
     */
    @Override
    protected TensorFunction<Reference> lazyGetFunction() {
        if (!allInputTypesPresent(1)) {
            return null;
        }
        ConstantNode fill = new ConstantNode(new DoubleValue(this.valueToFillWith));
        return Generate.bound(this.type.type(), TensorFunctionNode.wrapScalar(fill));
    }

    /** Adds dimension-name constraints derived from this operation's own type. */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer dimensionRenamer) {
        addConstraintsFrom(this.type, dimensionRenamer);
    }

    /**
     * Returns a copy of this operation with the given inputs, the same model name, node name
     * and attributes.
     */
    @Override
    public ConstantOfShape withInputs(List<IntermediateOperation> inputs) {
        return new ConstantOfShape(modelName(), name(), inputs, this.attributeMap);
    }

    @Override
    public String operationName() {
        return "ConstantOfShape";
    }
}
