package com.yahoo.schema.expressiontransforms;

import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Slice;
import java.io.StringReader;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;

/**
 * Rewrites the token-sequence features {@code tokenInputIds}, {@code customTokenInputIds},
 * {@code tokenTypeIds} and {@code tokenAttentionMask} into plain ranking expressions.
 * <p>
 * Each supported feature becomes a {@code tensor<float>(d0[1],d1[length])} built with
 * {@link Generate}, whose cell value at {@code d1} is selected by a chain of {@code if} expressions over
 * the token sequence {@code [start, input_1, separator, input_2, separator, ...]}. The number of tokens
 * contributed by each input reference is computed by a rank profile function instantiated from
 * {@code __token_length} ({@code sum(map(input, f(x)(x > 0)))}), which this transformer adds to the
 * rank profile when no function with that name exists yet.
 * <p>
 * A reference is only rewritten when the rank profile does not define a function with the same name
 * and the reference has at least two arguments; every other node is returned unchanged (composite
 * nodes have their children passed through {@code transformChildren}).
 */
public class TokenTransformer extends ExpressionTransformer<RankProfileTransformContext> {

    private static final ConstantNode ZERO = new ConstantNode(new DoubleValue(0.0d));
    private static final ConstantNode ONE = new ConstantNode(new DoubleValue(1.0d));
    /** Default start token placed first in the generated sequence. */
    private static final ConstantNode CLS = new ConstantNode(new DoubleValue(101.0d));
    /** Default separator token placed after every input. */
    private static final ConstantNode SEP = new ConstantNode(new DoubleValue(102.0d));
    /** Counts the cells of its {@code input} argument that are greater than zero. */
    private static final ExpressionFunction LENGTH_FUNCTION = makeLengthFunction();

    @Override
    public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
        if (node instanceof ReferenceNode) {
            return transformFeature((ReferenceNode) node, context);
        }
        if (node instanceof CompositeNode) {
            return super.transformChildren((CompositeNode) node, context);
        }
        return node;
    }

    /** Rewrites the reference if it names a supported token feature that should be transformed. */
    private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
        // The name is checked before shouldTransform so unrelated references never touch the rank profile.
        switch (feature.getName()) {
            case "customTokenInputIds":
                return shouldTransform(feature, context) ? transformCustomTokenInputIds(feature, context) : feature;
            case "tokenInputIds":
                return shouldTransform(feature, context) ? transformTokenInputIds(feature, context) : feature;
            case "tokenTypeIds":
                return shouldTransform(feature, context) ? transformTokenTypeIds(feature, context) : feature;
            case "tokenAttentionMask":
                return shouldTransform(feature, context) ? transformTokenAttentionMask(feature, context) : feature;
            default:
                return feature;
        }
    }

    /** {@code tokenInputIds(length, input_1, input_2, ...)}: uses the default start and separator tokens. */
    private ExpressionNode transformTokenInputIds(ReferenceNode feature, RankProfileTransformContext context) {
        return transformTokenInputIds(feature, context, CLS, SEP, 1);
    }

    /**
     * {@code customTokenInputIds(start, separator, length, input_1, input_2, ...)}: the start and
     * separator tokens are taken from the first two arguments.
     */
    private ExpressionNode transformCustomTokenInputIds(ReferenceNode feature, RankProfileTransformContext context) {
        // shouldTransform guarantees at least two arguments, so indexes 0 and 1 exist.
        return transformTokenInputIds(feature, context, argument(feature, 0), argument(feature, 1), 3);
    }

    /**
     * Builds the token id tensor for the sequence {@code [start, input_1, separator, input_2, separator, ...]}.
     *
     * @param firstInputIndex index of the first input argument; the argument just before it is the length
     * @throws IllegalArgumentException if the length argument is missing or not an integer,
     *                                  or an input argument is not a reference
     */
    private ExpressionNode transformTokenInputIds(ReferenceNode feature, RankProfileTransformContext context,
                                                  ExpressionNode startToken, ExpressionNode separatorToken,
                                                  int firstInputIndex) {
        if (feature.getArguments().size() < firstInputIndex) {
            throw new IllegalArgumentException("Invalid arguments to " + feature.getName() + ": expected the length " +
                                               "of the token sequence to generate as argument " + (firstInputIndex - 1) +
                                               ", but got only " + feature.getArguments().size() + " arguments");
        }
        checkReferenceArguments(feature, firstInputIndex);
        TensorType type = createTensorType(feature.getName(), argument(feature, firstInputIndex - 1));
        createTokenLengthFunctions(feature, context, firstInputIndex);
        List<ExpressionNode> sequence = createTokenSequence(feature, startToken, separatorToken, firstInputIndex);
        return new TensorFunctionNode(Generate.bound(type, TensorFunctionNode.wrapScalar(createTokenSequenceExpr(0, sequence))));
    }

    /**
     * {@code tokenTypeIds(length, input_1, input_2, ...)}: 0 for the start token, the first input and its
     * separator; 1 for the rest of the sequence; 0 for the padding after the sequence.
     */
    private ExpressionNode transformTokenTypeIds(ReferenceNode feature, RankProfileTransformContext context) {
        checkReferenceArguments(feature, 1);
        TensorType type = createTensorType(feature.getName(), argument(feature, 0));
        createTokenLengthFunctions(feature, context, 1);
        List<ExpressionNode> sequence = createTokenSequence(feature, CLS, SEP, 1);
        ExpressionNode firstSegmentEnd = createLengthExpr(2, sequence);
        ExpressionNode sequenceEnd = createLengthExpr(sequence.size() - 1, sequence);
        ExpressionNode value = new IfNode(new OperationNode(new ReferenceNode("d1"), Operator.smaller, firstSegmentEnd),
                                          ZERO,
                                          new IfNode(new OperationNode(new ReferenceNode("d1"), Operator.smaller, sequenceEnd),
                                                     ONE,
                                                     ZERO));
        return new TensorFunctionNode(Generate.bound(type, TensorFunctionNode.wrapScalar(value)));
    }

    /** {@code tokenAttentionMask(length, input_1, ...)}: 1 inside the token sequence, 0 for the padding after it. */
    private ExpressionNode transformTokenAttentionMask(ReferenceNode feature, RankProfileTransformContext context) {
        checkReferenceArguments(feature, 1);
        TensorType type = createTensorType(feature.getName(), argument(feature, 0));
        createTokenLengthFunctions(feature, context, 1);
        List<ExpressionNode> sequence = createTokenSequence(feature, CLS, SEP, 1);
        ExpressionNode sequenceEnd = createLengthExpr(sequence.size() - 1, sequence);
        ExpressionNode value = new IfNode(new OperationNode(new ReferenceNode("d1"), Operator.smaller, sequenceEnd), ONE, ZERO);
        return new TensorFunctionNode(Generate.bound(type, TensorFunctionNode.wrapScalar(value)));
    }

    /** Returns whether the reference is not shadowed by a rank profile function and has at least two arguments. */
    private boolean shouldTransform(ReferenceNode feature, RankProfileTransformContext context) {
        return ! context.rankProfile().getFunctions().containsKey(feature.getName())
               && feature.getArguments().size() >= 2;
    }

    /** Throws if any argument from {@code fromIndex} onwards is not a reference. */
    private void checkReferenceArguments(ReferenceNode feature, int fromIndex) {
        for (int i = fromIndex; i < feature.getArguments().size(); i++) {
            ExpressionNode arg = argument(feature, i);
            if ( ! (arg instanceof ReferenceNode)) {
                throw new IllegalArgumentException("Invalid argument " + i + " to " + feature.getName() +
                                                   ": the argument must be a reference. Got " + arg.toString());
            }
        }
    }

    /** Returns argument number {@code index} of the given reference. */
    private static ExpressionNode argument(ReferenceNode feature, int index) {
        return (ExpressionNode) feature.getArguments().expressions().get(index);
    }

    /**
     * Creates the {@code tensor<float>(d0[1],d1[length])} type of a generated token tensor.
     *
     * @param featureName name of the feature, used in the error message
     * @param length      expression whose string form must be an integer length
     * @throws IllegalArgumentException if {@code length} is not an integer
     */
    public static TensorType createTensorType(String featureName, ExpressionNode length) {
        try {
            return new TensorType.Builder(TensorType.Value.FLOAT)
                    .indexed("d0", 1L)
                    .indexed("d1", Integer.parseInt(length.toString()))
                    .build();
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Invalid argument to " + featureName + ": the first argument must be " +
                                               "the length to the token sequence to generate. Got " + length, e);
        }
    }

    /** Parses the shared {@code __token_length(input)} function. */
    private static ExpressionFunction makeLengthFunction() {
        try (StringReader reader = new StringReader("sum(map(input, f(x)(x > 0)))")) {
            return new ExpressionFunction("__token_length", List.of("input"),
                                          new RankingExpression("__token_length", reader));
        } catch (ParseException e) {
            throw new IllegalStateException("unexpected", e);
        }
    }

    /** Instantiates the length function for one input reference. */
    private ExpressionFunction.Instance lengthFunctionFor(ReferenceNode input) {
        return LENGTH_FUNCTION.expand(new SerializationContext(), List.of(input), new ArrayDeque<>());
    }

    /** Returns {@code [start, input_1, separator, input_2, separator, ...]} for the inputs from {@code firstInputIndex}. */
    private List<ExpressionNode> createTokenSequence(ReferenceNode feature, ExpressionNode startToken,
                                                     ExpressionNode separatorToken, int firstInputIndex) {
        List<ExpressionNode> sequence = new ArrayList<>();
        sequence.add(startToken);
        for (int i = firstInputIndex; i < feature.getArguments().size(); i++) {
            sequence.add(argument(feature, i));
            sequence.add(separatorToken);
        }
        return sequence;
    }

    /** Adds a length function to the rank profile for every input reference that does not have one yet. */
    private void createTokenLengthFunctions(ReferenceNode feature, RankProfileTransformContext context, int firstInputIndex) {
        checkReferenceArguments(feature, firstInputIndex);
        for (int i = firstInputIndex; i < feature.getArguments().size(); i++) {
            ExpressionFunction.Instance length = lengthFunctionFor((ReferenceNode) argument(feature, i));
            if ( ! context.rankProfile().getFunctions().containsKey(length.getName())) {
                context.rankProfile().addFunction(length.getName(), List.of(), length.getExpressionString(), false);
            }
        }
    }

    /**
     * Returns {@code if (d1 < end_of(index), token_value, <expression for index + 1>)}, ending in 0
     * after the last token of the sequence.
     */
    private ExpressionNode createTokenSequenceExpr(int index, List<ExpressionNode> sequence) {
        ExpressionNode condition = new OperationNode(new ReferenceNode("d1"), Operator.smaller, createLengthExpr(index, sequence));
        ExpressionNode token = sequence.get(index);
        // Input references are sliced cell by cell; any other token is emitted as a single scalar.
        ExpressionNode value = (token instanceof ReferenceNode) ? createTokenExtractExpr(index, sequence) : token;
        ExpressionNode otherwise = index < sequence.size() - 1 ? createTokenSequenceExpr(index + 1, sequence) : ZERO;
        return new IfNode(condition, value, otherwise);
    }

    /** Returns an expression for the number of positions taken by tokens {@code 0..lastIndex} (inclusive). */
    private ExpressionNode createLengthExpr(int lastIndex, List<ExpressionNode> sequence) {
        List<ExpressionNode> terms = new ArrayList<>();
        List<Operator> operators = new ArrayList<>();
        for (int i = 0; i <= lastIndex; i++) {
            ExpressionNode token = sequence.get(i);
            if (token instanceof ReferenceNode) {
                terms.add(new ReferenceNode(lengthFunctionFor((ReferenceNode) token).getName()));
            } else {
                // Matches createTokenSequenceExpr: a non-reference token occupies exactly one position.
                // Adding a term here for every token also keeps terms and operators the same length.
                terms.add(ONE);
            }
            if (i >= 1) {
                operators.add(Operator.plus);
            }
        }
        return (operators.isEmpty() && terms.size() == 1) ? terms.get(0) : new OperationNode(terms, operators);
    }

    /** Returns {@code input[d0: d1 - start_of(index)]}, i.e. the input cell at the current position. */
    private ExpressionNode createTokenExtractExpr(int index, List<ExpressionNode> sequence) {
        ExpressionNode position;
        if (index >= 1) {
            // Shift d1 so that it counts from the first position of this input.
            position = new EmbracedNode(new OperationNode(new ReferenceNode("d1"), Operator.minus,
                                                          new EmbracedNode(createLengthExpr(index - 1, sequence))));
        } else {
            position = new ReferenceNode("d1");
        }
        return new TensorFunctionNode(new Slice(new TensorFunctionNode.ExpressionTensorFunction(sequence.get(index)),
                                                List.of(new Slice.DimensionValue("d0", TensorFunctionNode.wrapScalar(position)))));
    }

}
