package com.yahoo.searchlib.rankingexpression.rule;

import com.yahoo.api.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * An {@link ExpressionNode} which wraps a {@link TensorFunction}, so that tensor functions can be part of a
 * ranking expression tree. The arguments of the wrapped function are exposed as the children of this node.
 * Ranking expressions embedded inside a tensor function are represented by {@link ExpressionTensorFunction}
 * (tensor-valued) and {@link ExpressionScalarFunction} (scalar-valued).
 */
@Beta
public class TensorFunctionNode extends CompositeNode {

    private final TensorFunction<Reference> function;

    /**
     * Adapts an {@link EvaluationContext} to the ranking-expression {@link Context} type, so that wrapped
     * expressions can be evaluated with it. All lookups are delegated to the wrapped context.
     */
    public static class ContextWrapper extends Context {

        private final EvaluationContext<Reference> delegate;

        public ContextWrapper(EvaluationContext<Reference> evaluationContext) {
            this.delegate = evaluationContext;
        }

        /** Returns the tensor the delegate holds for the given name, wrapped as a {@link TensorValue}. */
        @Override // com.yahoo.searchlib.rankingexpression.evaluation.Context
        public Value get(String name) {
            return new TensorValue(delegate.getTensor(name));
        }

        /** Returns the type the delegate reports for the given reference. */
        public TensorType getType(Reference reference) {
            return delegate.getType(reference);
        }

        /** Returns the binding the delegate resolves for the given name. */
        public String resolveBinding(String name) {
            return delegate.resolveBinding(name);
        }

    }

    /** A {@link ScalarFunction} whose value is computed by evaluating a wrapped ranking expression. */
    public static class ExpressionScalarFunction implements ScalarFunction<Reference> {

        private final ExpressionNode expression;

        public ExpressionScalarFunction(ExpressionNode expression) {
            this.expression = expression;
        }

        /** Evaluates the wrapped expression in the given context and returns the result as a double. */
        public Double apply(EvaluationContext<Reference> context) {
            return expression.evaluate(TensorFunctionNode.asContext(context)).asDouble();
        }

        /** Returns the wrapped expression as a tensor function. */
        public Optional<TensorFunction<Reference>> asTensorFunction() {
            return Optional.of(new ExpressionTensorFunction(expression));
        }

        @Override
        public String toString() {
            return toString(ExpressionToStringContext.empty);
        }

        /**
         * Serializes the wrapped expression.
         *
         * If the root of the given context chain is not an {@link ExpressionToStringContext}, this is the
         * plain {@code toString()} of the expression. Otherwise the expression is serialized with the root's
         * serialization context, path and parent. In that case composite expressions are first wrapped in an
         * {@link EmbracedNode}, unless they are already embraced or are identifier references.
         */
        public String toString(ToStringContext<Reference> context) {
            return TensorFunctionNode.rootExpressionContext(context)
                    .map(root -> TensorFunctionNode.serializeInRootContext(embraceIfNeeded(expression), context, root))
                    .orElseGet(expression::toString);
        }

        /** Wraps a composite node in an {@link EmbracedNode} unless it is already embraced or an identifier reference. */
        private ExpressionNode embraceIfNeeded(ExpressionNode node) {
            boolean needsEmbracing = node instanceof CompositeNode
                                     && ! (node instanceof EmbracedNode)
                                     && ! isIdentifierReference(node);
            return needsEmbracing ? new EmbracedNode(node) : node;
        }

        private boolean isIdentifierReference(ExpressionNode node) {
            return node instanceof ReferenceNode && ((ReferenceNode) node).reference().isIdentifier();
        }

    }

    /** A {@link PrimitiveTensorFunction} whose tensor is computed by evaluating a wrapped ranking expression. */
    public static class ExpressionTensorFunction extends PrimitiveTensorFunction<Reference> {

        private final ExpressionNode expression;

        public ExpressionTensorFunction(ExpressionNode expression) {
            this.expression = expression;
        }

        /** Returns the ranking expression wrapped by this. */
        public ExpressionNode wrappedExpression() {
            return expression;
        }

        /**
         * Returns the children of the wrapped expression, each wrapped as an {@link ExpressionTensorFunction}.
         * Returns an empty list if the wrapped expression is not a {@link CompositeNode}.
         */
        public List<TensorFunction<Reference>> arguments() {
            if ( ! (expression instanceof CompositeNode)) {
                return List.of();
            }
            return ((CompositeNode) expression).children().stream()
                    .<TensorFunction<Reference>>map(ExpressionTensorFunction::new)
                    .collect(Collectors.toList());
        }

        /**
         * Returns a function wrapping a copy of the wrapped expression with the given arguments as its children.
         * An empty argument list returns this unchanged.
         *
         * @throws ClassCastException if the wrapped expression is not a {@link CompositeNode}, or if any
         *         argument is not an {@link ExpressionTensorFunction}
         */
        public TensorFunction<Reference> withArguments(List<TensorFunction<Reference>> arguments) {
            if (arguments.isEmpty()) {
                return this;
            }
            CompositeNode composite = (CompositeNode) expression;
            List<ExpressionNode> children = arguments.stream()
                    .map(argument -> ((ExpressionTensorFunction) argument).expression)
                    .collect(Collectors.toList());
            return new ExpressionTensorFunction(composite.setChildren(children));
        }

        public PrimitiveTensorFunction<Reference> toPrimitive() {
            return this;
        }

        /** Returns the type of the wrapped expression in the given context. */
        public TensorType type(TypeContext<Reference> context) {
            return expression.type(context);
        }

        /** Returns the wrapped expression as a scalar function. */
        public Optional<ScalarFunction<Reference>> asScalarFunction() {
            return Optional.of(new ExpressionScalarFunction(expression));
        }

        /** Evaluates the wrapped expression in the given context and returns the result as a tensor. */
        public Tensor evaluate(EvaluationContext<Reference> context) {
            return expression.evaluate(TensorFunctionNode.asContext(context)).asTensor();
        }

        @Override
        public String toString() {
            return toString(ExpressionToStringContext.empty);
        }

        @Override
        public int hashCode() {
            return expression.hashCode();
        }

        /**
         * Serializes the wrapped expression.
         *
         * If the root of the given context chain is an {@link ExpressionToStringContext}, the expression is
         * serialized with that root's serialization context, path and parent. Otherwise the plain
         * {@code toString()} of the expression is returned.
         */
        public String toString(ToStringContext<Reference> context) {
            return TensorFunctionNode.rootExpressionContext(context)
                    .map(root -> TensorFunctionNode.serializeInRootContext(expression, context, root))
                    .orElseGet(expression::toString);
        }

    }

    /**
     * A {@link SerializationContext} that is also a {@link ToStringContext}. It delegates serialization state
     * to a wrapped serialization context and remembers the path and parent node being serialized. It is used
     * to carry expression serialization state through tensor function {@code toString} calls.
     */
    public static class ExpressionToStringContext extends SerializationContext implements ToStringContext<Reference> {

        private final ToStringContext<Reference> wrappedToStringContext;
        private final SerializationContext wrappedSerializationContext;
        private final Deque<String> path;
        private final CompositeNode parent;

        /** A context wrapping an empty serialization context, with no parent, path or parent node. */
        public static final ExpressionToStringContext empty = new ExpressionToStringContext(new SerializationContext(), null, null);

        ExpressionToStringContext(SerializationContext wrappedSerializationContext, Deque<String> path, CompositeNode parent) {
            this(wrappedSerializationContext, null, path, parent);
        }

        ExpressionToStringContext(SerializationContext wrappedSerializationContext,
                                  ToStringContext<Reference> wrappedToStringContext,
                                  Deque<String> path,
                                  CompositeNode parent) {
            this.wrappedSerializationContext = wrappedSerializationContext;
            this.wrappedToStringContext = wrappedToStringContext;
            this.path = path;
            this.parent = parent;
        }

        @Override // com.yahoo.searchlib.rankingexpression.rule.SerializationContext
        public void addFunctionSerialization(String name, String expressionString) {
            wrappedSerializationContext.addFunctionSerialization(name, expressionString);
        }

        @Override // com.yahoo.searchlib.rankingexpression.rule.SerializationContext
        public void addArgumentTypeSerialization(String functionName, String argumentName, TensorType type) {
            wrappedSerializationContext.addArgumentTypeSerialization(functionName, argumentName, type);
        }

        @Override // com.yahoo.searchlib.rankingexpression.rule.SerializationContext
        public void addFunctionTypeSerialization(String functionName, TensorType type) {
            wrappedSerializationContext.addFunctionTypeSerialization(functionName, type);
        }

        @Override // com.yahoo.searchlib.rankingexpression.rule.SerializationContext
        public Map<String, String> serializedFunctions() {
            return wrappedSerializationContext.serializedFunctions();
        }

        @Override // com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext
        public ExpressionFunction getFunction(String name) {
            return wrappedSerializationContext.getFunction(name);
        }

        @Override // com.yahoo.searchlib.rankingexpression.rule.SerializationContext
        public Optional<TypeContext<Reference>> typeContext() {
            return wrappedSerializationContext.typeContext();
        }

        @Override // com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext
        public Map<String, ExpressionFunction> getFunctions() {
            return wrappedSerializationContext.getFunctions();
        }

        /** Returns the wrapped to-string context, or null if there is none. */
        public ToStringContext<Reference> parent() {
            return wrappedToStringContext;
        }

        /**
         * Returns the binding from the wrapped to-string context if there is one and it binds the name.
         * Otherwise returns the binding from the wrapped serialization context.
         */
        @Override // com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext
        public String getBinding(String name) {
            String binding = wrappedToStringContext == null ? null : wrappedToStringContext.getBinding(name);
            return binding != null ? binding : wrappedSerializationContext.getBinding(name);
        }

        /**
         * Returns a copy of this with the given bindings. The copy has the same functions, type context and
         * serialized functions, and keeps the wrapped to-string context, path and parent.
         */
        // The compiler generates the erased bridge methods for this covariant override itself.
        @Override // com.yahoo.searchlib.rankingexpression.rule.SerializationContext, com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext
        public ExpressionToStringContext withBindings(Map<String, String> bindings) {
            SerializationContext rebound = new SerializationContext(getFunctions(), bindings, typeContext(), serializedFunctions());
            return new ExpressionToStringContext(rebound, wrappedToStringContext, path, parent);
        }

        /**
         * Returns a copy of this with no bindings and no wrapped to-string context. The path and parent
         * are kept.
         */
        @Override // com.yahoo.searchlib.rankingexpression.rule.SerializationContext, com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext
        public SerializationContext withoutBindings() {
            SerializationContext unbound = new SerializationContext(getFunctions(), (Map<String, String>) null, typeContext(), serializedFunctions());
            return new ExpressionToStringContext(unbound, null, path, parent);
        }

        @Override
        public String toString() {
            return "TensorFunctionNode.ExpressionToStringContext with wrapped serialization context: " + String.valueOf(wrappedSerializationContext);
        }

    }

    public TensorFunctionNode(TensorFunction<Reference> tensorFunction) {
        this.function = tensorFunction;
    }

    /** Returns the tensor function wrapped by this. */
    public TensorFunction<Reference> function() {
        return function;
    }

    /**
     * Returns the arguments of the wrapped function as expression nodes. An argument that wraps an
     * expression is unwrapped; any other argument is wrapped in a new {@link TensorFunctionNode}.
     */
    @Override // com.yahoo.searchlib.rankingexpression.rule.CompositeNode
    public List<ExpressionNode> children() {
        return function.arguments().stream()
                .map(this::toExpressionNode)
                .collect(Collectors.toList());
    }

    private ExpressionNode toExpressionNode(TensorFunction<Reference> tensorFunction) {
        if (tensorFunction instanceof ExpressionTensorFunction) {
            return ((ExpressionTensorFunction) tensorFunction).expression;
        }
        return new TensorFunctionNode(tensorFunction);
    }

    /**
     * Applies the transformer to the expression wrapped by the given scalar function if it is an
     * {@link ExpressionScalarFunction}. Otherwise the scalar function is returned unchanged.
     */
    public static ScalarFunction<Reference> transform(ScalarFunction<Reference> scalarFunction,
                                                      java.util.function.Function<ExpressionNode, ExpressionNode> function) {
        if (scalarFunction instanceof ExpressionScalarFunction) {
            return new ExpressionScalarFunction(function.apply(((ExpressionScalarFunction) scalarFunction).expression));
        }
        return scalarFunction;
    }

    /**
     * Returns the result of applying the transformer to the expressions embedded in this.
     *
     * If the wrapped function is itself a wrapped expression, the transformer is applied to that expression
     * directly. Otherwise a new node is returned whose function has every {@link ExpressionScalarFunction}
     * transformed.
     */
    public ExpressionNode withTransformedExpressions(java.util.function.Function<ExpressionNode, ExpressionNode> transformer) {
        if (function instanceof ExpressionTensorFunction) {
            return transformer.apply(((ExpressionTensorFunction) function).expression);
        }
        return new TensorFunctionNode(function.withTransformedFunctions(scalar -> transform(scalar, transformer)));
    }

    /** Returns a new node whose function has the given children, each wrapped in an {@link ExpressionTensorFunction}, as arguments. */
    @Override // com.yahoo.searchlib.rankingexpression.rule.CompositeNode
    public CompositeNode setChildren(List<ExpressionNode> children) {
        List<TensorFunction<Reference>> arguments = children.stream()
                .<TensorFunction<Reference>>map(ExpressionTensorFunction::new)
                .collect(Collectors.toList());
        return new TensorFunctionNode(function.withArguments(arguments));
    }

    /** Appends the serialized primitive form of the wrapped function, with this node as the serialization parent. */
    @Override // com.yahoo.searchlib.rankingexpression.rule.ExpressionNode
    public StringBuilder toString(StringBuilder sb, SerializationContext context, Deque<String> path, CompositeNode parent) {
        return sb.append(function.toPrimitive().toString(new ExpressionToStringContext(context, path, this)));
    }

    @Override // com.yahoo.searchlib.rankingexpression.rule.ExpressionNode
    public TensorType type(TypeContext<Reference> context) {
        return function.type(context);
    }

    @Override // com.yahoo.searchlib.rankingexpression.rule.ExpressionNode
    public Value evaluate(Context context) {
        return new TensorValue(function.evaluate(context));
    }

    /** Wraps an expression node as a tensor function. */
    public static ExpressionTensorFunction wrap(ExpressionNode expressionNode) {
        return new ExpressionTensorFunction(expressionNode);
    }

    /** Wraps each expression node in the map as a scalar function, preserving the map's iteration order. */
    public static Map<TensorAddress, ScalarFunction<Reference>> wrapScalars(Map<TensorAddress, ExpressionNode> nodes) {
        Map<TensorAddress, ScalarFunction<Reference>> wrapped = new LinkedHashMap<>();
        for (Map.Entry<TensorAddress, ExpressionNode> entry : nodes.entrySet()) {
            wrapped.put(entry.getKey(), wrapScalar(entry.getValue()));
        }
        return wrapped;
    }

    /**
     * Wraps the nodes of one dense subspace of a mixed tensor and puts them into the receiving map.
     *
     * The nodes are matched, in order, to the positions produced by iterating the indexed dimensions of
     * {@code type} in {@code dimensionOrder} (restricted to the indexed dimensions). In each resulting
     * address, every mapped dimension gets {@code mappedDimensionLabel}.
     *
     * @throws IllegalArgumentException if the number of nodes differs from the size of the dense subspace
     */
    public static void wrapScalarBlock(TensorType type,
                                       List<String> dimensionOrder,
                                       String mappedDimensionLabel,
                                       List<ExpressionNode> nodes,
                                       Map<TensorAddress, ScalarFunction<Reference>> receivingMap) {
        TensorType denseSubtype = new TensorType(type.valueType(),
                                                 type.dimensions().stream()
                                                         .filter(TensorType.Dimension::isIndexed)
                                                         .collect(Collectors.toList()));
        List<String> denseDimensionOrder = new ArrayList<>(dimensionOrder);
        denseDimensionOrder.retainAll(denseSubtype.dimensionNames());
        IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseSubtype, denseDimensionOrder);
        if (indexes.size() != nodes.size()) {
            throw new IllegalArgumentException("At '" + mappedDimensionLabel + "': Need " + indexes.size() +
                                               " values to fill a dense subspace of " + type +
                                               " but got " + nodes.size());
        }
        for (ExpressionNode node : nodes) {
            indexes.next();
            // Build the full address: indexed dimensions take their label from the current dense position,
            // mapped dimensions take the given mapped label.
            String[] labels = new String[type.rank()];
            int denseIndex = 0;
            int dimensionIndex = 0;
            for (TensorType.Dimension dimension : type.dimensions()) {
                labels[dimensionIndex++] = dimension.isIndexed()
                        ? String.valueOf(indexes.indexesForReading()[denseIndex++])
                        : mappedDimensionLabel;
            }
            receivingMap.put(TensorAddress.of(labels), wrapScalar(node));
        }
    }

    /**
     * Wraps the given nodes as scalar functions.
     *
     * The nodes are given in the order produced by iterating {@code type} in {@code dimensionOrder}. Each
     * node is placed at the list position reported by {@code Indexes.toSourceValueIndex()} for its
     * iteration step.
     *
     * @throws IllegalArgumentException if the number of nodes differs from the number of values in the type
     */
    public static List<ScalarFunction<Reference>> wrapScalars(TensorType type, List<String> dimensionOrder, List<ExpressionNode> nodes) {
        IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(type, dimensionOrder);
        if (indexes.size() != nodes.size()) {
            throw new IllegalArgumentException("Need " + indexes.size() + " values to fill " + type +
                                               " but got " + nodes.size());
        }
        // Pre-fill so that each wrapped node can be set at an arbitrary position below.
        List<ScalarFunction<Reference>> wrapped = new ArrayList<>(nodes.size());
        for (int i = 0; i < nodes.size(); i++) {
            wrapped.add(null);
        }
        int nodeIndex = 0;
        while (indexes.hasNext()) {
            indexes.next();
            wrapped.set((int) indexes.toSourceValueIndex(), wrapScalar(nodes.get(nodeIndex++)));
        }
        return wrapped;
    }

    /** Wraps an expression node as a scalar function. */
    public static ScalarFunction<Reference> wrapScalar(ExpressionNode expressionNode) {
        return new ExpressionScalarFunction(expressionNode);
    }

    @Override // com.yahoo.searchlib.rankingexpression.rule.ExpressionNode
    public int hashCode() {
        return function.hashCode();
    }

    /** Returns the context itself if it already is a {@link Context}, otherwise a {@link ContextWrapper} around it. */
    private static Context asContext(EvaluationContext<Reference> evaluationContext) {
        if (evaluationContext instanceof Context) {
            return (Context) evaluationContext;
        }
        return new ContextWrapper(evaluationContext);
    }

    /**
     * Follows the parent chain of the given context to its outermost context. Returns that context if it is
     * an {@link ExpressionToStringContext}, and empty otherwise.
     */
    private static Optional<ExpressionToStringContext> rootExpressionContext(ToStringContext<Reference> context) {
        ToStringContext<Reference> root = context;
        while (root.parent() != null) {
            root = root.parent();
        }
        return root instanceof ExpressionToStringContext
               ? Optional.of((ExpressionToStringContext) root)
               : Optional.empty();
    }

    /**
     * Serializes a node with a new {@link ExpressionToStringContext}. The new context reuses the root's
     * serialization context, path and parent node, and has {@code context} as its parent.
     */
    private static String serializeInRootContext(ExpressionNode node,
                                                 ToStringContext<Reference> context,
                                                 ExpressionToStringContext root) {
        ExpressionToStringContext nodeContext =
                new ExpressionToStringContext(root.wrappedSerializationContext, context, root.path, root.parent);
        return node.toString(new StringBuilder(), nodeContext, root.path, root.parent).toString();
    }

}
