package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/* loaded from: input_file:ai/vespa/rankingexpression/importer/operations/IntermediateOperation.class */
/**
 * Base class for a node in the intermediate graph built when importing a model.
 * <p>
 * An operation knows both its inputs and its outputs. The constructor, {@link #insert} and
 * {@link #uninsert} keep the two directions of those links in sync. Subclasses supply the tensor
 * type and tensor function through {@link #lazyGetType()} and {@link #lazyGetFunction()}. Both are
 * computed on first access and then cached. An operation whose inputs are all constant is itself
 * constant, and can be evaluated with {@link #evaluateAsConstant(OrderedTensorType)}.
 * <p>
 * Instances are mutable, and the lazily computed caches are not synchronized.
 */
public abstract class IntermediateOperation {

    /** Prefix of the ranking-function names generated for exported non-constant operations. */
    public static final String FUNCTION_PREFIX = "imported_ml_function_";

    /** Characters that {@link #ensureValidAsDimensionName(String)} replaces with '_'. Compiled once. */
    private static final Pattern INVALID_DIMENSION_NAME_CHARS = Pattern.compile("[^\\w\\d\\$@_]");

    /** Name of this operation in the source model. */
    protected final String name;
    /** Model name, with every character other than ASCII letters, digits, '_', '$' and '@' replaced by '_'. */
    protected final String modelName;
    /** Operations producing this operation's arguments, indexed by input number. */
    protected final List<IntermediateOperation> inputs;
    /** Cached result of {@link #lazyGetType()}; replaced by {@link #renameDimensions}. */
    protected OrderedTensorType type;
    /** Cached result of {@link #function()}. */
    protected TensorFunction<Reference> function;
    /** Operations consuming this operation's result; may hold duplicates (see {@link #removeDuplicateOutputsTo}). */
    protected final List<IntermediateOperation> outputs = new ArrayList<>();
    /** The function body when this operation is exported as a separate ranking function, otherwise null. */
    protected TensorFunction<Reference> rankingExpressionFunction = null;
    /** When true, a non-constant operation is exported as a separate ranking function by {@link #function()}. */
    protected boolean exportAsRankingFunction = false;
    /** Set by {@link #renameDimensions}; {@link #evaluableCopy()} then evaluates this operation in place. */
    private boolean hasRenamedDimensions = false;
    /** Warnings collected while importing this operation. */
    private final List<String> importWarnings = new ArrayList<>();
    /** Explicitly set constant value; takes precedence over {@link #constantValueFunction}. */
    private Value constantValue = null;
    /** Control inputs, kept separately from {@link #inputs}. */
    private List<IntermediateOperation> controlInputs = List.of();
    /** Produces this operation's constant value for a given type; used when no explicit value is set. */
    protected Function<OrderedTensorType, Value> constantValueFunction = null;

    /**
     * Read access to the attributes of an operation in the source model.
     * NOTE(review): presumably the source node's attributes; confirm against implementations.
     */
    public interface AttributeMap {

        /** Returns the attribute with the given name, if present. */
        Optional<Value> get(String name);

        /** Returns the attribute with the given name, if present. The type is presumably used for conversion; TODO confirm. */
        Optional<Value> get(String name, OrderedTensorType type);

        /** Returns the list-valued attribute with the given name, if present. */
        Optional<List<Value>> getList(String name);
    }

    /**
     * Creates an operation and registers it as an output of each of its inputs.
     *
     * @param modelName the model name; sanitized by {@link #ensureValidAsDimensionName(String)} before it is stored
     * @param name      the operation name
     * @param inputs    the input operations; the list is copied
     */
    public IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) {
        this.name = name;
        this.modelName = ensureValidAsDimensionName(modelName);
        this.inputs = new ArrayList<>(inputs);
        // Keep the graph doubly linked: each input records this operation as one of its consumers.
        this.inputs.forEach(input -> input.outputs.add(this));
    }

    /** Computes the tensor type of this operation; called once by {@link #type()} and may return null. */
    protected abstract OrderedTensorType lazyGetType();

    /** Computes the tensor function of this operation; called from {@link #function()} and constant evaluation. */
    protected abstract TensorFunction<Reference> lazyGetFunction();

    /** Returns the sanitized model name. */
    public String modelName() {
        return modelName;
    }

    /** Returns the type of this operation, computing and caching it on first call. Empty if it cannot be determined. */
    public Optional<OrderedTensorType> type() {
        if (type == null) {
            type = lazyGetType();
        }
        return Optional.ofNullable(type);
    }

    /**
     * Returns the tensor function of this operation, computing and caching it on first call.
     * <ul>
     *   <li>Constant operations become a reference to {@code constant(<vespaName>)}.</li>
     *   <li>Operations with several outputs, or flagged with {@link #exportAsRankingFunction}, become a
     *       variable named {@link #rankingExpressionFunctionName()}. The real body is kept in
     *       {@link #rankingExpressionFunction}.</li>
     *   <li>All other operations use {@link #lazyGetFunction()} directly.</li>
     * </ul>
     *
     * @throws IllegalStateException if this operation must be exported but its type cannot be determined
     */
    public Optional<TensorFunction<Reference>> function() {
        if (function == null) {
            if (isConstant()) {
                function = new TensorFunctionNode.ExpressionTensorFunction(
                        new ReferenceNode(Reference.simple("constant", vespaName())));
            } else if (outputs.size() > 1 || exportAsRankingFunction) {
                rankingExpressionFunction = lazyGetFunction();
                // Go through type() so the type is computed if it was not requested yet.
                OrderedTensorType resolvedType = type().orElseThrow(() -> new IllegalStateException(
                        "Cannot export '" + name + "' as a ranking function: its type is unknown"));
                function = new VariableTensor(rankingExpressionFunctionName(), resolvedType.type());
            } else {
                function = lazyGetFunction();
            }
        }
        return Optional.ofNullable(function);
    }

    /** Returns the name of this operation in the source model. */
    public String name() {
        return name;
    }

    /** Returns the (mutable) list of input operations. */
    public List<IntermediateOperation> inputs() {
        return inputs;
    }

    /** Returns a read-only view of the operations consuming this one. */
    public List<IntermediateOperation> outputs() {
        return Collections.unmodifiableList(outputs);
    }

    /** Returns the separately exported function body, if {@link #function()} chose to export this operation. */
    public Optional<TensorFunction<Reference>> rankingExpressionFunction() {
        return Optional.ofNullable(rankingExpressionFunction);
    }

    /** Hook for subclasses to add dimension-name constraints to the renamer; does nothing by default. */
    public void addDimensionNameConstraints(DimensionRenamer dimensionRenamer) {
    }

    /**
     * Registers every dimension of the given type with the renamer. Also requires every pair of
     * distinct dimensions in the type to receive different names.
     */
    public void addConstraintsFrom(OrderedTensorType tensorType, DimensionRenamer dimensionRenamer) {
        List<String> dimensionNames = tensorType.dimensions().stream().map(dimension -> dimension.name()).toList();
        for (int first = 0; first < dimensionNames.size(); first++) {
            dimensionRenamer.addDimension(dimensionNames.get(first));
            for (int second = first + 1; second < dimensionNames.size(); second++) {
                dimensionRenamer.addConstraint(dimensionNames.get(first), dimensionNames.get(second),
                                               DimensionRenamer.Constraint.notEqual(false), this);
            }
        }
    }

    /** Renames the dimensions of this operation's (already computed) type using the renamer's solution. */
    public void renameDimensions(DimensionRenamer dimensionRenamer) {
        type = type.rename(dimensionRenamer);
        hasRenamedDimensions = true;
    }

    /** Returns whether this operation is a model input; false unless overridden. */
    public boolean isInput() {
        return false;
    }

    /**
     * Returns whether this operation is constant, i.e. whether all its inputs are constant.
     * An operation without inputs is therefore constant unless a subclass overrides this.
     */
    public boolean isConstant() {
        return inputs.stream().allMatch(IntermediateOperation::isConstant);
    }

    /** Sets an explicit constant value, which takes precedence over any constant value function. */
    public void setConstantValue(Value value) {
        constantValue = value;
    }

    /**
     * Returns the constant value of this operation. This is the explicitly set value if there is one,
     * otherwise the constant value function applied to {@link #type()} (or to null when the type is
     * unknown), otherwise empty.
     */
    public Optional<Value> getConstantValue() {
        if (constantValue != null) {
            return Optional.of(constantValue);
        }
        if (constantValueFunction != null) {
            return Optional.of(constantValueFunction.apply(type().orElse(null)));
        }
        return Optional.empty();
    }

    /** Sets the function used to produce this operation's constant value for a given type. */
    public void setConstantValueFunction(Function<OrderedTensorType, Value> function) {
        constantValueFunction = function;
    }

    /** Sets the control inputs of this operation; the list is stored as given. */
    public void setControlInputs(List<IntermediateOperation> list) {
        controlInputs = list;
    }

    /** Returns a read-only view of the control inputs. */
    public List<IntermediateOperation> getControlInputs() {
        return Collections.unmodifiableList(controlInputs);
    }

    /** Returns the Vespa-safe name of this operation; constants are prefixed with the model name. */
    public String vespaName() {
        return isConstant() ? modelName + "_" + vespaName(name) : vespaName(name);
    }

    /**
     * Converts a source-model name to a Vespa-safe name. It takes {@link #namePartOf(String)} and
     * replaces '/' and '.' with '_'.
     *
     * @return the converted name, or null if {@code name} is null
     */
    public static String vespaName(String name) {
        if (name == null) {
            return null;
        }
        return namePartOf(name).replace('/', '_').replace('.', '_');
    }

    /**
     * Returns the name under which this operation is referenced from ranking expressions.
     * This is {@code constant(<vespaName>)} for constants and {@code FUNCTION_PREFIX<modelName>_<vespaName>}
     * otherwise. Returns null if the Vespa name is null.
     */
    public String rankingExpressionFunctionName() {
        String vespaName = vespaName();
        if (vespaName == null) {
            return null;
        }
        return isConstant() ? "constant(" + vespaName + ")" : FUNCTION_PREFIX + modelName + "_" + vespaName;
    }

    /** Returns a read-only view of the warnings collected for this operation. */
    public List<String> warnings() {
        return Collections.unmodifiableList(importWarnings);
    }

    /** Records an import warning for this operation. */
    public void warning(String message) {
        importWarnings.add(message);
    }

    /**
     * Checks that this operation has exactly {@code expected} inputs. Returns whether
     * {@code getter} yields a present value for every input.
     *
     * @throws IllegalArgumentException if the number of inputs differs from {@code expected}
     */
    boolean verifyInputs(int expected, Function<IntermediateOperation, Optional<?>> getter) {
        if (inputs.size() != expected) {
            throw new IllegalArgumentException("Expected " + expected + " inputs for '" + name + "', got " + inputs.size());
        }
        return inputs.stream().map(getter).allMatch(Optional::isPresent);
    }

    /** Returns whether there are exactly {@code expected} inputs and all of them have a known type. */
    public boolean allInputTypesPresent(int expected) {
        return verifyInputs(expected, IntermediateOperation::type);
    }

    /** Returns whether there are exactly {@code expected} inputs and all of them have a function. */
    public boolean allInputFunctionsPresent(int expected) {
        return verifyInputs(expected, IntermediateOperation::function);
    }

    /**
     * Evaluates this constant operation on an evaluable copy of its subtree.
     * <p>
     * If {@code requestedType} is null, the raw result is returned. Otherwise the result is checked to
     * be renamable to that type and returned with that type. A constant value function is also
     * installed, so later reads of the constant value can retype the result without re-evaluating.
     *
     * @throws IllegalArgumentException if this operation is not constant, or the result is not renamable to {@code requestedType}
     */
    public Value evaluateAsConstant(OrderedTensorType requestedType) {
        if (!isConstant()) {
            throw new IllegalArgumentException("Attempted to evaluate non-constant operation as a constant.");
        }
        // NaN is the context default; evaluateAsConstant(Context) treats it as "not evaluated yet".
        Context context = new MapContext(DoubleValue.NaN);
        Value value = evaluableCopy().evaluateAsConstant(context);
        if (requestedType == null) {
            return value;
        }
        Tensor tensor = value.asTensor();
        checkIfRenameableTo(tensor, requestedType);
        setConstantValueFunction(t -> new TensorValue(tensor.withType(t.type())));
        return new TensorValue(tensor.withType(requestedType.type()));
    }

    /** Throws IllegalArgumentException if the tensor's type cannot be renamed to the expected type. */
    private void checkIfRenameableTo(Tensor tensor, OrderedTensorType expected) {
        if (!tensor.type().isRenamableTo(expected.type())) {
            throw new IllegalArgumentException("Constant evaluation in " + name + " resulted in wrong type. Expected: " + expected.type() + " Got: " + tensor.type());
        }
    }

    /**
     * Returns an operation that can be evaluated as a constant. This is the operation itself once its
     * dimensions have been renamed. Otherwise it is a copy of the subtree placed in a fresh graph and
     * optimized there. NOTE(review): presumably so that optimization does not alter the original graph.
     */
    private IntermediateOperation evaluableCopy() {
        if (hasRenamedDimensions) {
            return this;
        }
        IntermediateOperation copy = copyTree();
        IntermediateGraph graph = new IntermediateGraph(modelName);
        graph.put(name, copy);
        graph.outputs(graph.defaultSignature()).put(name, name);
        graph.optimize();
        return copy;
    }

    /**
     * Copies this operation and, recursively, its inputs via {@link #withInputs(List)}.
     * An operation with an explicit constant value is copied as a {@link Constant} leaf instead.
     */
    private IntermediateOperation copyTree() {
        if (constantValue != null) {
            Constant constant = new Constant(modelName, name, type);
            constant.setConstantValueFunction(t -> new TensorValue(constantValue.asTensor().withType(t.type())));
            return constant;
        }
        List<IntermediateOperation> copiedInputs = new ArrayList<>();
        for (IntermediateOperation input : inputs) {
            copiedInputs.add(input.copyTree());
        }
        IntermediateOperation copy = withInputs(copiedInputs);
        if (constantValueFunction != null) {
            copy.constantValueFunction = constantValueFunction;
        }
        return copy;
    }

    /**
     * Evaluates this operation and its inputs recursively. Each result is memoized in the context under
     * {@code constant(<vespaName>)}. When {@link #function()} would reference this operation through a
     * variable, the result is also stored under {@link #rankingExpressionFunctionName()}.
     *
     * @throws IllegalArgumentException if a leaf operation has no constant value
     */
    public Value evaluateAsConstant(Context context) {
        String key = "constant(" + vespaName() + ")";
        Value value = context.get(key);
        // Identity check on purpose: assumes the context returns its DoubleValue.NaN default for absent keys.
        if (value != DoubleValue.NaN) {
            return value;
        }
        if (constantValue != null) {
            value = constantValue;
        } else if (!inputs.isEmpty()) {
            inputs.forEach(input -> input.evaluateAsConstant(context));
            value = new TensorValue(lazyGetFunction().evaluate(context));
        } else {
            value = getConstantValue().orElseThrow(
                    () -> new IllegalArgumentException("Error in evaluating constant for " + name));
        }
        context.put(key, value);
        // Mirrors the condition in function(), which then references this operation by this name.
        if (outputs.size() > 1 || exportAsRankingFunction) {
            context.put(rankingExpressionFunctionName(), value);
        }
        return value;
    }

    /**
     * Splices {@code operation} between this operation and its input number {@code inputNumber}.
     * Afterwards the old input feeds {@code operation}, which feeds this operation.
     *
     * @throws IllegalArgumentException if {@code operation} already has inputs, or the old input does not list this operation as an output
     */
    public void insert(IntermediateOperation operation, int inputNumber) {
        if (!operation.inputs.isEmpty()) {
            throw new IllegalArgumentException("Operation to insert to '" + name + "' has existing inputs which is not supported.");
        }
        IntermediateOperation oldInput = inputs.get(inputNumber);
        int outputNumber = findOutputNumber(oldInput, this);
        if (outputNumber == -1) {
            throw new IllegalArgumentException("Input '" + oldInput.name + "' to '" + name + "' does not have '" + name + "' as output.");
        }
        oldInput.outputs.set(outputNumber, operation);
        operation.inputs.add(oldInput);
        operation.outputs.add(this);
        inputs.set(inputNumber, operation);
    }

    /**
     * Reverses {@link #insert}. The operation at input number {@code inputNumber} is bypassed, and its
     * first input is reconnected directly to this operation. The bypassed operation's own links are
     * left unchanged.
     *
     * @throws IllegalArgumentException if the operation at that position has no input, or is not linked back from it
     */
    public void uninsert(int inputNumber) {
        IntermediateOperation inserted = inputs.get(inputNumber);
        if (inserted.inputs.isEmpty()) {
            throw new IllegalArgumentException("Input '" + inserted.name + "' to '" + name + "' has no input to reconnect.");
        }
        IntermediateOperation original = inserted.inputs.get(0);
        int outputNumber = findOutputNumber(original, inserted);
        if (outputNumber == -1) {
            throw new IllegalArgumentException("Input '" + original.name + "' to '" + inserted.name + "' does not have '" + inserted.name + "' as output.");
        }
        original.outputs.set(outputNumber, this);
        inputs.set(inputNumber, original);
    }

    /** Returns the index of {@code consumer} in the outputs of {@code producer}, or -1 if absent. */
    private int findOutputNumber(IntermediateOperation producer, IntermediateOperation consumer) {
        return producer.outputs.indexOf(consumer);
    }

    /** Removes all but the first occurrence of {@code operation} from this operation's outputs. */
    public void removeDuplicateOutputsTo(IntermediateOperation operation) {
        int first = outputs.indexOf(operation);
        if (first < 0) {
            return;
        }
        for (int last = outputs.lastIndexOf(operation); last > first; last = outputs.lastIndexOf(operation)) {
            outputs.remove(last);
        }
    }

    /**
     * Returns the largest cell value type among the input types.
     *
     * @throws java.util.NoSuchElementException if some input type is unknown
     */
    public TensorType.Value resultValueType() {
        return TensorType.Value.largestOf(inputs.stream()
                                                .map(input -> input.type().orElseThrow().type().valueType())
                                                .toList());
    }

    /** Returns a copy of this operation with the given inputs. */
    public abstract IntermediateOperation withInputs(List<IntermediateOperation> list);

    /** Returns the string form of the type, or "(unknown)" if it is absent. */
    String asString(Optional<OrderedTensorType> maybeType) {
        return maybeType.map(OrderedTensorType::toString).orElse("(unknown)");
    }

    /**
     * Returns the part of a source-model name before the first ':', after dropping one leading '^'
     * (NOTE(review): '^' presumably marks a control-input reference). Names consisting only of ':'
     * characters yield the empty string.
     */
    public static String namePartOf(String name) {
        String stripped = name.startsWith("^") ? name.substring(1) : name;
        int colon = stripped.indexOf(':');
        return colon < 0 ? stripped : stripped.substring(0, colon);
    }

    /**
     * Returns the integer after the first ':' in a source-model name, or 0 if there is no ':'.
     *
     * @throws NumberFormatException if the text after the first ':' is not an integer
     */
    public static int indexPartOf(String name) {
        int colon = name.indexOf(":");
        if (colon < 0) {
            return 0;
        }
        return Integer.parseInt(name.substring(colon + 1));
    }

    /** Returns the name of the kind of operation, used by {@link #toString()}. */
    public abstract String operationName();

    /** Replaces every character other than ASCII letters, digits, '_', '$' and '@' with '_'. */
    private static String ensureValidAsDimensionName(String name) {
        return INVALID_DIMENSION_NAME_CHARS.matcher(name).replaceAll("_");
    }

    /** Returns the operation name followed by the types of its inputs. */
    @Override
    public String toString() {
        String inputTypes = inputs().stream()
                                    .map(input -> asString(input.type()))
                                    .collect(Collectors.joining(", "));
        return operationName() + "(" + inputTypes + ")";
    }

    /** Returns this operation's cached type and name, followed by the full strings of all its inputs. */
    public String toFullString() {
        return "\t" + type + ":\t" + operationName() + "("
               + inputs().stream().map(IntermediateOperation::toFullString).collect(Collectors.joining(", "))
               + ")";
    }
}
