package ai.vespa.models.evaluation;

import ai.vespa.models.evaluation.BindingExtractor;
import com.yahoo.lang.MutableInteger;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.stream.CustomCollectors;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/* loaded from: input_file:ai/vespa/models/evaluation/LazyArrayContext.class */
/**
 * An evaluation {@link Context} that keeps its bound values in an array and resolves binding names
 * to array indices once, when the context is created.
 *
 * <p>At construction, constants referenced by the function are bound as {@link TensorValue}s and
 * function references are bound to {@link LazyValue}s tied to this context. Every other slot starts
 * out unbound and reads as the context's missing value. The missing value is NaN unless it is
 * replaced with {@link #setMissingValue(Tensor)}.
 */
public final class LazyArrayContext extends Context implements ContextIndex {

    private final ExpressionFunction function;
    private final IndexedBindings indexedBindings;

    /**
     * The array-backed binding state of a context: the name-to-index map, the current value per
     * index, the names reported as arguments, and the ONNX models in use.
     */
    public static class IndexedBindings {

        /** Maps each bind target name to its slot in {@link #values}. */
        private final Map<String, Integer> nameToIndex;

        /** Names reported as arguments by the binding extractor. */
        private final Set<String> arguments;

        /** Current value of each slot; unbound slots hold the {@link #missing} sentinel. */
        private final Value[] values;

        /** ONNX models in use, as reported by the binding extractor. */
        private final Map<String, OnnxModel> onnxModels;

        /**
         * Sentinel marking an unbound slot. It is compared by identity in {@link #get(int)} and is
         * never handed out, so it must stay a distinct instance from {@link #missingValue}.
         */
        private static final Value missing = new DoubleValue(Double.NaN).freeze();

        /** The value returned by {@link #get(int)} for unbound slots. */
        private Value missingValue = new DoubleValue(Double.NaN).freeze();

        private IndexedBindings(Map<String, Integer> nameToIndex, Value[] values, Set<String> arguments,
                                Map<String, OnnxModel> onnxModels) {
            this.nameToIndex = Map.copyOf(nameToIndex);
            this.values = values;
            this.arguments = arguments;
            this.onnxModels = Map.copyOf(onnxModels);
        }

        /**
         * Creates the bindings for a function.
         *
         * @param function the function whose bind targets, arguments and ONNX models are extracted
         * @param bindingExtractor extracts the binding information from {@code function}
         * @param functions functions that may be referenced; each one that is a bind target gets a
         *                  {@link LazyValue} slot
         * @param constants constants; each one bound as {@code constant(name)} gets its tensor value
         * @param owner the context that the lazy function values are created for
         * @param model the model passed on to each {@link LazyValue}
         */
        IndexedBindings(ExpressionFunction function, BindingExtractor bindingExtractor,
                        Map<FunctionReference, ExpressionFunction> functions, List<Constant> constants,
                        LazyArrayContext owner, Model model) {
            BindingExtractor.FunctionInfo info = bindingExtractor.extractFrom(function);
            Set<String> bindTargets = info.bindTargets;
            this.onnxModels = Map.copyOf(info.onnxModelsInUse);
            this.arguments = Set.copyOf(info.arguments);
            this.values = new Value[bindTargets.size()];
            Arrays.fill(this.values, missing);
            this.nameToIndex = indexNames(bindTargets);

            for (Constant constant : constants) {
                Integer index = nameToIndex.get("constant(" + constant.name() + ")");
                if (index != null) {
                    values[index] = new TensorValue(constant.value());
                }
            }
            for (FunctionReference reference : functions.keySet()) {
                Integer index = nameToIndex.get(reference.serialForm());
                if (index != null) {
                    values[index] = new LazyValue(reference, owner, model);
                }
            }
        }

        /** Assigns consecutive indices 0..n-1 to the names in their iteration order. */
        private static Map<String, Integer> indexNames(Set<String> names) {
            Map<String, Integer> indexes = new HashMap<>();
            int next = 0;
            for (String name : names) {
                indexes.put(name, next++);
            }
            return Map.copyOf(indexes);
        }

        private void setMissingValue(Tensor tensor) {
            this.missingValue = new TensorValue(tensor).freeze();
        }

        /** Returns the value at the index, or the current missing value if the slot is unbound. */
        Value get(int index) {
            Value value = values[index];
            return value == missing ? missingValue : value;
        }

        void set(int index, Value value) {
            values[index] = value;
        }

        Set<String> names() {
            return nameToIndex.keySet();
        }

        Set<String> arguments() {
            return arguments;
        }

        /** Returns the index of the name, or null if it is not a bind target. */
        Integer indexOf(String name) {
            return nameToIndex.get(name);
        }

        Map<String, OnnxModel> onnxModels() {
            return onnxModels;
        }

        /**
         * Returns a copy of these bindings for another context. The copy shares the name-to-index
         * map, arguments and ONNX models, and gets its own value array. {@link LazyValue} entries are
         * re-created for {@code context} via {@code copyFor}, and all other values are shared. The
         * missing value is carried over so unbound slots resolve the same way in the copy.
         */
        IndexedBindings copy(Context context) {
            Value[] valueCopy = new Value[values.length];
            for (int i = 0; i < values.length; i++) {
                Value value = values[i];
                valueCopy[i] = value instanceof LazyValue ? ((LazyValue) value).copyFor(context) : value;
            }
            IndexedBindings copy = new IndexedBindings(nameToIndex, valueCopy, arguments, onnxModels);
            copy.missingValue = this.missingValue;
            return copy;
        }
    }

    /** Creates a context whose bindings are a copy of the given ones, re-bound to this context. */
    private LazyArrayContext(ExpressionFunction function, IndexedBindings indexedBindings) {
        this.function = function;
        this.indexedBindings = indexedBindings.copy(this);
    }

    /**
     * Creates a context for evaluating the given function.
     *
     * @param function the function this context is for
     * @param bindingExtractor extracts the bind targets, arguments and ONNX models of the function
     * @param functions referenced functions, bound lazily to this context
     * @param constants constants, bound by their {@code constant(name)} form
     * @param model the model passed on to the lazy function values
     */
    public LazyArrayContext(ExpressionFunction function, BindingExtractor bindingExtractor,
                            Map<FunctionReference, ExpressionFunction> functions, List<Constant> constants,
                            Model model) {
        this.function = function;
        this.indexedBindings = new IndexedBindings(function, bindingExtractor, functions, constants, this, model);
    }

    /**
     * Sets the value that unbound slots resolve to.
     *
     * @throws NullPointerException if {@code tensor} is null
     */
    public void setMissingValue(Tensor tensor) {
        indexedBindings.setMissingValue(Objects.requireNonNull(tensor, "tensor"));
    }

    /**
     * Binds a value by name.
     *
     * @throws IllegalArgumentException if the name is not a bind target of this context
     */
    public void put(String str, Value value) {
        put(requireIndexOf(str), value);
    }

    /** Binds a double at the given index, stored as a frozen {@link DoubleValue}. */
    public final void put(int i, double d) {
        put(i, (Value) DoubleValue.frozen(d));
    }

    /** Binds a frozen version of the value at the given index. */
    public void put(int i, Value value) {
        indexedBindings.set(i, value.freeze());
    }

    /**
     * Returns the type of the value currently bound to the reference, looked up by its string form.
     *
     * @throws IllegalArgumentException if the reference is not a bind target of this context
     */
    public TensorType getType(Reference reference) {
        return get(requireIndexOf(reference.toString())).type();
    }

    /**
     * Returns the value bound to the name, or the missing value if it is unbound.
     *
     * @throws IllegalArgumentException if the name is not a bind target of this context
     */
    public Value get(String str) {
        return get(requireIndexOf(str));
    }

    /** Returns the value at the index, or the missing value if the slot is unbound. */
    public Value get(int i) {
        return indexedBindings.get(i);
    }

    /** Returns the value at the index as a double. */
    public double getDouble(int i) {
        return get(i).asDouble();
    }

    /**
     * Returns the index of the name.
     *
     * @throws IllegalArgumentException if the name is not a bind target of this context
     */
    public int getIndex(String str) {
        return requireIndexOf(str);
    }

    /** This context does no binding resolution; always returns null. */
    public String resolveBinding(String str) {
        return null;
    }

    /** Returns the number of bind targets in this context. */
    public int size() {
        return indexedBindings.names().size();
    }

    /** Returns the names of all bind targets in this context. */
    public Set<String> names() {
        return indexedBindings.names();
    }

    /** Returns the names reported as arguments of the function. */
    public Set<String> arguments() {
        return indexedBindings.arguments();
    }

    /** Returns the ONNX models in use by the function. */
    public Map<String, OnnxModel> onnxModels() {
        return indexedBindings.onnxModels();
    }

    private int requireIndexOf(String name) {
        Integer index = indexedBindings.indexOf(name);
        if (index == null) {
            throw new IllegalArgumentException("Value '" + name + "' can not be bound in " + this);
        }
        return index;
    }

    /**
     * Returns whether the name is absent from this context's bind targets. This does not report
     * whether a known bind target currently holds a value.
     */
    public boolean isMissing(String str) {
        return indexedBindings.indexOf(str) == null;
    }

    /** Returns the value that unbound slots currently resolve to. */
    public Value defaultValue() {
        return indexedBindings.missingValue;
    }

    /**
     * Returns a new context for the same function. The new context has copied bindings, including
     * the current values and missing value, with lazy function values re-bound to the new context.
     */
    public LazyArrayContext copy() {
        return new LazyArrayContext(function, indexedBindings);
    }
}
