package com.yahoo.vespa.indexinglanguage.expressions;

import com.yahoo.document.ArrayDataType;
import com.yahoo.document.DataType;
import com.yahoo.document.DocumentType;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.Array;
import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.language.Linguistics;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.class */
/**
 * Indexing expression which embeds a string, or every string in an array of strings,
 * into a tensor by invoking an {@link Embedder}.
 *
 * <p>The embedder is chosen at construction time from the supplied map. When no usable embedder
 * can be chosen (ambiguous or unknown id), an {@code Embedder.FailingEmbedder} carrying an
 * explanatory message is used instead of throwing from the constructor.</p>
 */
public class EmbedExpression extends Expression {

    private final Linguistics linguistics;
    private final Embedder embedder;
    private final String embedderId;
    private final List<String> embedderArguments;

    /** "documentTypeName.fieldName" of the statement output, as set by {@link #setStatementOutput}. */
    private String destination;

    /**
     * Creates an embed expression.
     *
     * @param linguistics       used to resolve the language passed to the embedder
     * @param embedders         the available embedders by id; must not be empty
     * @param embedderId        the id of the embedder to use; may be null or empty if there is exactly one embedder
     * @param embedderArguments additional arguments given to embed; copied defensively
     * @throws IllegalStateException if {@code embedders} is empty
     */
    public EmbedExpression(Linguistics linguistics, Map<String, Embedder> embedders, String embedderId, List<String> embedderArguments) {
        this.linguistics = linguistics;
        this.embedderId = embedderId;
        this.embedderArguments = List.copyOf(embedderArguments);
        if (embedders.isEmpty()) {
            throw new IllegalStateException("No embedders provided");
        }
        this.embedder = selectEmbedder(embedders, embedderId);
    }

    /** Picks the embedder to use from a non-empty map, or a failing embedder explaining why none could be picked. */
    private static Embedder selectEmbedder(Map<String, Embedder> embedders, String embedderId) {
        boolean idGiven = embedderId != null && !embedderId.isEmpty();
        if (!idGiven) {
            if (embedders.size() == 1) {
                return embedders.values().iterator().next();
            }
            return new Embedder.FailingEmbedder("Multiple embedders are provided but no embedder id is given. Valid embedders are " + validEmbedders(embedders));
        }
        if (embedders.containsKey(embedderId)) {
            return embedders.get(embedderId);
        }
        return new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. Valid embedders are " + validEmbedders(embedders));
    }

    /** Records the input type, validates it against the current output type and returns the output type. */
    @Override // com.yahoo.vespa.indexinglanguage.expressions.Expression
    public DataType setInputType(DataType inputType, VerificationContext context) {
        super.setInputType(inputType, context);
        DataType outputType = getOutputType(context);
        validateInputAndOutput(inputType, outputType);
        return outputType;
    }

    /** Records the output type, validates it against the current input type and returns the input type. */
    @Override // com.yahoo.vespa.indexinglanguage.expressions.Expression
    public DataType setOutputType(DataType outputType, VerificationContext context) {
        super.setOutputType(null, outputType, TensorDataType.any(), context);
        DataType inputType = getInputType(context);
        validateInputAndOutput(inputType, outputType);
        return inputType;
    }

    /**
     * Validates the input and output types, each of which may be null when not yet known.
     * Checks that only need one of the types are run even when the other is null.
     *
     * @throws VerificationException if a known type is not supported, or the two are incompatible
     */
    private void validateInputAndOutput(DataType input, DataType output) {
        if (input != null && !isStringOrStringArray(input)) {
            invalid("This requires either a string or array<string> input type, but got " + input.getName());
        }
        if (output == null) {
            return;
        }

        TensorType target = toTargetTensor(output);
        if (!validTarget(target)) {
            invalid("The embedding target field must either be a dense 1d tensor, a mapped 1d tensor, a mapped 2d tensor, an array of dense 1d tensors, or a mixed 2d or 3d tensor");
        }
        if (target.rank() == 2 && target.mappedSubtype().rank() == 2) {
            validateElementDimensionArgument(target, "When the embedding target field is a 2d mapped tensor the name of the tensor dimension that corresponds to the input array elements must be given as a second argument to embed, e.g: ... | embed splade paragraph | ...");
        }
        if (target.rank() == 3) {
            validateElementDimensionArgument(target, "When the embedding target field is a 3d tensor the name of the tensor dimension that corresponds to the input array elements must be given as a second argument to embed, e.g: ... | embed colbert paragraph | ...");
        }

        if (input == null) {
            return;
        }
        if (input.isAssignableTo(DataType.STRING) && target.rank() != 1 && (target.rank() != 2 || target.mappedSubtype().rank() <= 0)) {
            invalid("Input is a string, so output must be a rank 1 tensor, or a rank 2 tensor with one mapped dimension, but got " + target);
        }
        if (input instanceof ArrayDataType && (target.rank() <= 1 || target.mappedSubtype().rank() <= 0)) {
            invalid("Input is an array, so output must be a rank 2 or 3 tensor with at least one mapped dimension, but got " + target);
        }
    }

    /** Returns whether the type is assignable to string, or is an array whose nested type is. */
    private static boolean isStringOrStringArray(DataType type) {
        if (type.isAssignableTo(DataType.STRING)) {
            return true;
        }
        return type instanceof ArrayDataType && ((ArrayDataType) type).getNestedType().isAssignableTo(DataType.STRING);
    }

    /**
     * Verifies that exactly one embedder argument is given and that it names a mapped dimension of the target.
     *
     * @param target                 the target tensor type
     * @param missingArgumentMessage the message to fail with when the number of arguments is not one
     */
    private void validateElementDimensionArgument(TensorType target, String missingArgumentMessage) {
        if (this.embedderArguments.size() != 1) {
            invalid(missingArgumentMessage);
        }
        String elementDimension = this.embedderArguments.get(0);
        if (!target.mappedSubtype().dimensionNames().contains(elementDimension)) {
            invalid("The dimension '" + elementDimension + "' given to embed is not a sparse dimension of the target type " + target);
        }
    }

    /** Always throws a {@link VerificationException} for this expression with the given message. */
    private void invalid(String message) {
        throw new VerificationException(this, message);
    }

    /** Remembers the output field as "documentTypeName.fieldName"; it is passed to the embedder context. */
    @Override // com.yahoo.vespa.indexinglanguage.expressions.Expression
    public void setStatementOutput(DocumentType documentType, Field field) {
        this.destination = documentType.getName() + "." + field.getName();
    }

    /** Sets the current type of the verification context to the target tensor type of the output. */
    @Override // com.yahoo.vespa.indexinglanguage.expressions.Expression
    protected void doVerify(VerificationContext context) {
        context.setCurrentType(new TensorDataType(toTargetTensor(getOutputType(context))));
    }

    /**
     * Replaces the current value by its embedding. Does nothing if there is no current value.
     *
     * @throws IllegalArgumentException if the current value is neither a string nor an array of strings
     */
    @Override // com.yahoo.vespa.indexinglanguage.expressions.Expression
    protected void doExecute(ExecutionContext context) {
        if (context.getCurrentValue() == null) {
            return;
        }
        // Held as the general DataType and cast only once the instanceof check has passed
        DataType valueType = context.getCurrentValue().getDataType();
        Tensor embedding;
        if (valueType == DataType.STRING) {
            embedding = embedSingleValue(context);
        } else if (valueType instanceof ArrayDataType && ((ArrayDataType) valueType).getNestedType() == DataType.STRING) {
            embedding = embedArrayValue(getOutputTensorType(), context);
        } else {
            throw new IllegalArgumentException("Embedding can only be done on string or string array fields, not " + valueType);
        }
        context.setCurrentValue(new TensorFieldValue(embedding));
    }

    /** Embeds the current (string) value into the output tensor type. */
    private Tensor embedSingleValue(ExecutionContext context) {
        return embed(context.getCurrentValue().getString(), getOutputTensorType(), context);
    }

    /**
     * Embeds the current (string array) value into a tensor of the given type, dispatching on its shape:
     * rank 2 with one indexed dimension, rank 2 with two mapped dimensions, and otherwise the rank 3 path.
     *
     * @throws IllegalArgumentException if the type is rank 2 but of neither supported shape
     */
    private Tensor embedArrayValue(TensorType targetType, ExecutionContext context) {
        @SuppressWarnings("unchecked") // doExecute has checked that this is an array with string as nested type
        Array<StringFieldValue> input = (Array<StringFieldValue>) context.getCurrentValue();
        Tensor.Builder builder = Tensor.Builder.of(targetType);
        if (targetType.rank() != 2) {
            embedArrayValueToRank3Tensor(input, builder, context);
        } else if (targetType.indexedSubtype().rank() == 1) {
            embedArrayValueToRank2Tensor(input, builder, context);
        } else if (targetType.mappedSubtype().rank() == 2) {
            embedArrayValueToRank2MappedTensor(input, builder, context);
        } else {
            throw new IllegalArgumentException("Embedding an array into " + targetType + " is not supported");
        }
        return builder.build();
    }

    /**
     * Embeds each element into the indexed subtype of the builder's type, and adds the cells to the builder
     * using the element's array index as the label in the first mapped dimension.
     */
    private void embedArrayValueToRank2Tensor(Array<StringFieldValue> input, Tensor.Builder builder, ExecutionContext context) {
        String mappedDimension = builder.type().mappedSubtype().dimensions().get(0).name();
        String indexedDimension = builder.type().indexedSubtype().dimensions().get(0).name();
        for (int i = 0; i < input.size(); i++) {
            Tensor elementEmbedding = embed(input.get(i).getString(), builder.type().indexedSubtype(), context);
            Iterator<Tensor.Cell> cells = elementEmbedding.cellIterator();
            while (cells.hasNext()) {
                Tensor.Cell cell = cells.next();
                builder.cell()
                       .label(mappedDimension, i)
                       .label(indexedDimension, cell.getKey().numericLabel(0))
                       .value(cell.getValue().doubleValue());
            }
        }
    }

    /**
     * Embeds each element into a tensor having the other mapped dimension and the indexed dimension of the
     * builder's type, and adds the cells to the builder using the element's array index as the label in the
     * dimension named by the first embedder argument.
     */
    private void embedArrayValueToRank3Tensor(Array<StringFieldValue> input, Tensor.Builder builder, ExecutionContext context) {
        String elementDimension = this.embedderArguments.get(0);
        String otherMappedDimension = otherMappedDimension(builder.type(), elementDimension);
        TensorType.Dimension indexed = builder.type().indexedSubtype().dimensions().get(0);
        String indexedDimension = indexed.name();
        TensorType elementType = new TensorType.Builder(builder.type().valueType())
                .mapped(otherMappedDimension)
                .indexed(indexedDimension, indexed.size().get())
                .build();
        int mappedIndex = elementType.indexOfDimensionAsInt(otherMappedDimension);
        int indexedIndex = elementType.indexOfDimensionAsInt(indexedDimension);
        for (int i = 0; i < input.size(); i++) {
            Iterator<Tensor.Cell> cells = embed(input.get(i).getString(), elementType, context).cellIterator();
            while (cells.hasNext()) {
                Tensor.Cell cell = cells.next();
                builder.cell()
                       .label(elementDimension, i)
                       .label(otherMappedDimension, cell.getKey().label(mappedIndex))
                       .label(indexedDimension, cell.getKey().numericLabel(indexedIndex))
                       .value(cell.getValue().doubleValue());
            }
        }
    }

    /**
     * Embeds each element into a tensor having only the other mapped dimension of the output type,
     * and adds the cells to the builder using the element's array index as the label in the
     * dimension named by the first embedder argument.
     */
    private void embedArrayValueToRank2MappedTensor(Array<StringFieldValue> input, Tensor.Builder builder, ExecutionContext context) {
        String elementDimension = this.embedderArguments.get(0);
        String otherMappedDimension = otherMappedDimension(getOutputTensorType(), elementDimension);
        TensorType elementType = new TensorType.Builder(getOutputTensorType().valueType()).mapped(otherMappedDimension).build();
        int mappedIndex = elementType.indexOfDimensionAsInt(otherMappedDimension);
        for (int i = 0; i < input.size(); i++) {
            Iterator<Tensor.Cell> cells = embed(input.get(i).getString(), elementType, context).cellIterator();
            while (cells.hasNext()) {
                Tensor.Cell cell = cells.next();
                builder.cell()
                       .label(elementDimension, i)
                       .label(otherMappedDimension, cell.getKey().label(mappedIndex))
                       .value(cell.getValue().doubleValue());
            }
        }
    }

    /**
     * Returns the name of the first mapped dimension of the given type which is not the given dimension.
     *
     * @throws java.util.NoSuchElementException if there is no such dimension
     */
    private static String otherMappedDimension(TensorType type, String elementDimension) {
        return type.mappedSubtype().dimensionNames().stream()
                   .filter(name -> !name.equals(elementDimension))
                   .findFirst()
                   .get();
    }

    /**
     * Invokes the embedder on the given text, with a context carrying the destination, the cache of the
     * execution context, the resolved language and the embedder id.
     */
    private Tensor embed(String text, TensorType targetType, ExecutionContext context) {
        Embedder.Context embedderContext = new Embedder.Context(this.destination, context.getCache())
                .setLanguage(context.resolveLanguage(this.linguistics))
                .setEmbedderId(this.embedderId);
        return this.embedder.embed(text, embedderContext, targetType);
    }

    private TensorType getOutputTensorType() {
        return getOutputType().getTensorType();
    }

    /** Returns the output type of this. */
    @Override // com.yahoo.vespa.indexinglanguage.expressions.Expression
    public DataType createdOutputType() {
        return getOutputType();
    }

    /**
     * Returns the tensor type of the given type, unwrapping any number of enclosing array types.
     *
     * @throws IllegalArgumentException if the (innermost) type is not a tensor type
     */
    private static TensorType toTargetTensor(DataType dataType) {
        if (dataType instanceof ArrayDataType) {
            return toTargetTensor(((ArrayDataType) dataType).getNestedType());
        }
        if (dataType instanceof TensorDataType) {
            return ((TensorDataType) dataType).getTensorType();
        }
        throw new IllegalArgumentException("Expected a tensor data type but got " + dataType);
    }

    /**
     * Returns whether the given type is a supported target: any rank 1 tensor, a rank 2 tensor with either
     * one indexed dimension or two mapped dimensions, or a rank 3 tensor with one indexed dimension.
     */
    private boolean validTarget(TensorType target) {
        switch (target.rank()) {
            case 1:
                return true;
            case 2:
                return target.indexedSubtype().rank() == 1 || target.mappedSubtype().rank() == 2;
            case 3:
                return target.indexedSubtype().rank() == 1;
            default:
                return false;
        }
    }

    /** Returns "embed", followed by the embedder id if non-empty, followed by each argument, space separated. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("embed");
        if (this.embedderId != null && !this.embedderId.isEmpty()) {
            sb.append(" ").append(this.embedderId);
        }
        for (String argument : this.embedderArguments) {
            sb.append(" ").append(argument);
        }
        return sb.toString();
    }

    // NOTE(review): all EmbedExpression instances hash and compare as equal, regardless of embedder id
    // and arguments. Kept as-is since callers may rely on it - confirm whether this is intended.
    @Override
    public int hashCode() {
        return EmbedExpression.class.hashCode();
    }

    @Override
    public boolean equals(Object obj) {
        return obj instanceof EmbedExpression;
    }

    /** Returns the ids of the given embedders, sorted in natural order and joined by ", ". */
    private static String validEmbedders(Map<String, Embedder> embedders) {
        List<String> ids = new ArrayList<>(embedders.keySet());
        ids.sort(null);
        return String.join(", ", ids);
    }
}
