package com.yahoo.search.schema.internal;

import com.yahoo.language.Language;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/* loaded from: input_file:com/yahoo/search/schema/internal/TensorConverter.class */
/**
 * Converts raw values into {@link Tensor} instances of a requested {@link TensorType}.
 *
 * <p>A value may be:
 * <ul>
 *   <li>an existing {@link Tensor}, returned as-is (then type-checked by {@link #convertTo});</li>
 *   <li>an embed expression, {@code embed(text)} or {@code embed(embedderId, 'text')}, resolved
 *       with one of the configured {@link Embedder}s;</li>
 *   <li>any other string, parsed with {@link Tensor#from(TensorType, String)}.</li>
 * </ul>
 */
public class TensorConverter {

    /** Prefix that marks a string value as an embed expression. */
    private static final String EMBED_PREFIX = "embed(";

    /**
     * Matches the inside of {@code embed(...)} when an embedder id is given explicitly:
     * group 1 is the embedder id, group 2 is the quoted text (quotes included).
     */
    private static final Pattern embedderArgumentRegexp =
            Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])");

    /** Embedders available to embed expressions, keyed by embedder id. */
    private final Map<String, Embedder> embedders;

    /**
     * Creates a converter.
     *
     * @param map the embedders available to embed expressions, keyed by embedder id
     */
    public TensorConverter(Map<String, Embedder> map) {
        this.embedders = map;
    }

    /**
     * Converts a value to a tensor and verifies that it is assignable to the required type.
     *
     * @param tensorType the required tensor type
     * @param str        passed to the {@link Embedder.Context} constructor; presumably the name of
     *                   the destination the tensor is produced for (TODO confirm against callers)
     * @param obj        a {@link Tensor}, a tensor string, or an {@code embed(...)} expression
     * @param language   the language set on the embedder context
     * @return the converted tensor, or {@code null} if {@code obj} is neither a {@link Tensor}
     *         nor a {@link String} (including when it is {@code null})
     * @throws IllegalArgumentException if the result is not assignable to {@code tensorType}, if an
     *         embed expression does not end with ')', if it names an unknown embedder, or if it
     *         names no embedder while several are configured
     * @throws IllegalStateException if an embed expression names no embedder and none are configured
     */
    public Tensor convertTo(TensorType tensorType, String str, Object obj, Language language) {
        Tensor tensor = toTensor(tensorType, obj, new Embedder.Context(str).setLanguage(language));
        if (tensor == null) {
            return null;
        }
        if (tensor.type().isAssignableTo(tensorType)) {
            return tensor;
        }
        throw new IllegalArgumentException("Require a tensor of type " + tensorType);
    }

    /**
     * Turns {@code obj} into a tensor without checking its type.
     *
     * @return the tensor, or {@code null} if {@code obj} is neither a Tensor nor a String
     */
    private Tensor toTensor(TensorType tensorType, Object obj, Embedder.Context context) {
        if (obj instanceof Tensor) {
            return (Tensor) obj;
        }
        if (obj instanceof String) {
            String value = (String) obj;
            return isEmbed(value) ? embed(value, tensorType, context) : Tensor.from(tensorType, value);
        }
        return null;
    }

    /** Returns whether the string is an embed expression, that is, whether it starts with {@code embed(}. */
    static boolean isEmbed(String str) {
        return str.startsWith(EMBED_PREFIX);
    }

    /**
     * Evaluates an embed expression of the form {@code embed(text)} or
     * {@code embed(embedderId, 'text')} / {@code embed(embedderId, "text")}.
     *
     * @param str the full expression; the caller guarantees it starts with {@code embed(}
     */
    private Tensor embed(String str, TensorType tensorType, Embedder.Context context) {
        if (!str.endsWith(")")) {
            throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'");
        }
        String argument = str.substring(EMBED_PREFIX.length(), str.length() - 1);
        Matcher matcher = embedderArgumentRegexp.matcher(argument);
        Embedder embedder;
        if (matcher.matches()) {
            String embedderId = matcher.group(1);
            argument = matcher.group(2);
            if (!embedders.containsKey(embedderId)) {
                throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. Valid embedders are " + validEmbedders(embedders));
            }
            embedder = embedders.get(embedderId);
        } else {
            embedder = soleEmbedder();
        }
        return embedder.embed(removeQuotes(argument), context, tensorType);
    }

    /**
     * Returns the single configured embedder, used when an embed expression names no embedder id.
     *
     * @throws IllegalStateException if no embedders are configured
     * @throws IllegalArgumentException if more than one embedder is configured
     */
    private Embedder soleEmbedder() {
        if (embedders.isEmpty()) {
            throw new IllegalStateException("No embedders provided");
        }
        if (embedders.size() > 1) {
            throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. Valid embedders are " + validEmbedders(embedders));
        }
        return embedders.values().iterator().next();
    }

    /**
     * Strips one pair of matching surrounding quotes (single or double), if present.
     * A string shorter than two characters is returned unchanged. This avoids an invalid
     * {@code substring(1, 0)} for a lone quote character such as {@code embed(')}.
     */
    private static String removeQuotes(String str) {
        if (str.length() < 2) {
            return str;
        }
        char first = str.charAt(0);
        char last = str.charAt(str.length() - 1);
        boolean quoted = (first == '\'' || first == '"') && first == last;
        return quoted ? str.substring(1, str.length() - 1) : str;
    }

    /** Returns the embedder ids in natural sort order, joined by commas, for error messages. */
    private static String validEmbedders(Map<String, Embedder> map) {
        ArrayList<String> ids = new ArrayList<>(map.keySet());
        ids.sort(null); // null comparator = natural (String) ordering
        return String.join(",", ids);
    }
}
