package dev.langchain4j.community.model.dashscope;

import com.alibaba.dashscope.embeddings.TextEmbedding;
import com.alibaba.dashscope.embeddings.TextEmbeddingParam;
import com.alibaba.dashscope.embeddings.TextEmbeddingResult;
import com.alibaba.dashscope.exception.NoApiKeyException;
import dev.langchain4j.community.model.dashscope.spi.QwenEmbeddingModelBuilderFactory;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.spi.ServiceHelper;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:dev/langchain4j/community/model/dashscope/QwenEmbeddingModel.class */
/**
 * {@link DimensionAwareEmbeddingModel} that embeds text through the DashScope
 * {@link TextEmbedding} API.
 *
 * <p>A {@link TextSegment} may carry a {@value #TYPE_KEY} metadata entry whose value is
 * {@value #TYPE_QUERY} or {@value #TYPE_DOCUMENT} (compared case-insensitively). This entry
 * selects the {@link TextEmbeddingParam.TextType} sent with the request. Texts are sent to
 * DashScope in batches of at most 6.
 */
public class QwenEmbeddingModel extends DimensionAwareEmbeddingModel {
    /** Metadata key read from each segment to decide between query and document embedding. */
    public static final String TYPE_KEY = "type";
    /** {@link #TYPE_KEY} value marking a segment as a query. */
    public static final String TYPE_QUERY = "query";
    /** {@link #TYPE_KEY} value marking a segment as a document. */
    public static final String TYPE_DOCUMENT = "document";
    // Maximum number of texts sent in a single DashScope request.
    private static final int BATCH_SIZE = 6;
    private final String apiKey;
    private final String modelName;
    private final TextEmbedding embedding;
    // Hook applied to every request builder just before it is built; a no-op by default.
    private Consumer<TextEmbeddingParam.TextEmbeddingParamBuilder<?, ?>> textEmbeddingParamCustomizer =
            paramBuilder -> {};

    /**
     * Builder for {@link QwenEmbeddingModel}. Values left unset (or blank) fall back to the
     * defaults applied by the {@link QwenEmbeddingModel} constructor.
     */
    public static class QwenEmbeddingModelBuilder {
        private String baseUrl;
        private String apiKey;
        private String modelName;

        /** Sets the DashScope base URL; blank means the SDK's default endpoint. */
        public QwenEmbeddingModelBuilder baseUrl(String baseUrl) {
            this.baseUrl = baseUrl;
            return this;
        }

        /** Sets the DashScope API key (required). */
        public QwenEmbeddingModelBuilder apiKey(String apiKey) {
            this.apiKey = apiKey;
            return this;
        }

        /** Sets the embedding model name; blank means {@link QwenModelName#TEXT_EMBEDDING_V2}. */
        public QwenEmbeddingModelBuilder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        /** Creates the model from the configured values. */
        public QwenEmbeddingModel build() {
            return new QwenEmbeddingModel(this.baseUrl, this.apiKey, this.modelName);
        }
    }

    /**
     * Creates a DashScope-backed embedding model.
     *
     * @param baseUrl   DashScope endpoint; if null or blank, the SDK's default {@link TextEmbedding} is used
     * @param apiKey    DashScope API key; must not be null or blank
     * @param modelName model to call; if null or blank, {@link QwenModelName#TEXT_EMBEDDING_V2} is used
     * @throws IllegalArgumentException if {@code apiKey} is null or blank
     */
    public QwenEmbeddingModel(String baseUrl, String apiKey, String modelName) {
        if (Utils.isNullOrBlank(apiKey)) {
            throw new IllegalArgumentException("DashScope api key must be defined. It can be generated here: https://dashscope.console.aliyun.com/apiKey");
        }
        this.modelName = Utils.isNullOrBlank(modelName) ? QwenModelName.TEXT_EMBEDDING_V2 : modelName;
        this.apiKey = apiKey;
        this.embedding = Utils.isNullOrBlank(baseUrl) ? new TextEmbedding() : new TextEmbedding(baseUrl);
    }

    /** Returns whether the segment's {@link #TYPE_KEY} metadata equals {@code type}, ignoring case. */
    private static boolean hasType(TextSegment segment, String type) {
        // equalsIgnoreCase(null) is false, so segments without the key never match.
        return type.equalsIgnoreCase(segment.metadata().getString(TYPE_KEY));
    }

    private boolean containsDocuments(List<TextSegment> textSegments) {
        return textSegments.stream().anyMatch(segment -> hasType(segment, TYPE_DOCUMENT));
    }

    private boolean containsQueries(List<TextSegment> textSegments) {
        return textSegments.stream().anyMatch(segment -> hasType(segment, TYPE_QUERY));
    }

    /**
     * Embeds all segments with the same text type, splitting them into requests of at most
     * {@code BATCH_SIZE} texts and summing the token usage of every request.
     *
     * @return embeddings in input order; an empty list (no token usage) when {@code textSegments}
     *     is empty, without contacting DashScope
     */
    private Response<List<Embedding>> embedTexts(
            List<TextSegment> textSegments, TextEmbeddingParam.TextType textType) {
        if (textSegments.isEmpty()) {
            // Nothing to embed: avoid sending a request with an empty "texts" array.
            return Response.from(new ArrayList<>());
        }
        int size = textSegments.size();
        List<Embedding> embeddings = new ArrayList<>(size);
        TokenUsage tokenUsage = null;
        for (int from = 0; from < size; from += BATCH_SIZE) {
            int to = Math.min(size, from + BATCH_SIZE);
            Response<List<Embedding>> batch = batchEmbedTexts(textSegments.subList(from, to), textType);
            embeddings.addAll(batch.content());
            tokenUsage = TokenUsage.sum(tokenUsage, batch.tokenUsage());
        }
        return Response.from(embeddings, tokenUsage);
    }

    /**
     * Sends one DashScope request for the given segments.
     *
     * <p>Returned embeddings are ordered by the API's {@code textIndex}. DashScope's total token
     * count is reported through {@link TokenUsage#TokenUsage(Integer)}. A missing output or missing
     * usage yields an empty embedding list or a null token count, respectively, rather than an NPE.
     *
     * @throws IllegalArgumentException wrapping {@link NoApiKeyException} if the SDK rejects the key
     */
    private Response<List<Embedding>> batchEmbedTexts(
            List<TextSegment> textSegments, TextEmbeddingParam.TextType textType) {
        List<String> texts = textSegments.stream().map(TextSegment::text).collect(Collectors.toList());
        TextEmbeddingParam.TextEmbeddingParamBuilder<?, ?> paramBuilder = TextEmbeddingParam.builder()
                .apiKey(this.apiKey)
                .model(this.modelName)
                .textType(textType)
                .texts(texts);
        try {
            this.textEmbeddingParamCustomizer.accept(paramBuilder);
            TextEmbeddingResult result = this.embedding.call(paramBuilder.build());
            List<Embedding> embeddings = Optional.ofNullable(result)
                    .map(TextEmbeddingResult::getOutput)
                    .map(output -> output.getEmbeddings())
                    .orElse(Collections.emptyList())
                    .stream()
                    // The API tags each vector with the index of its input text; restore input order.
                    .sorted(Comparator.comparing(item -> item.getTextIndex()))
                    .map(item -> item.getEmbedding())
                    .map(vector -> vector.stream().map(value -> value.floatValue()).collect(Collectors.toList()))
                    .map(Embedding::from)
                    .collect(Collectors.toList());
            Integer totalTokens = Optional.ofNullable(result)
                    .map(TextEmbeddingResult::getUsage)
                    .map(usage -> usage.getTotalTokens())
                    .orElse(null);
            return Response.from(embeddings, new TokenUsage(totalTokens));
        } catch (NoApiKeyException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Embeds the segments, choosing the DashScope text type from each segment's {@link #TYPE_KEY}
     * metadata.
     *
     * <ul>
     *   <li>No segment is a query: every segment is embedded as a document, in batches.
     *   <li>Some segments are queries and none is explicitly a document: every segment, including
     *       untyped ones, is embedded as a query, in batches.
     *   <li>Both queries and documents are present: each segment is embedded in its own request,
     *       as a query if marked {@value #TYPE_QUERY} and as a document otherwise.
     * </ul>
     *
     * @return embeddings in input order, with token usage summed over all requests
     */
    @Override
    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
        if (!containsQueries(textSegments)) {
            return embedTexts(textSegments, TextEmbeddingParam.TextType.DOCUMENT);
        }
        if (!containsDocuments(textSegments)) {
            return embedTexts(textSegments, TextEmbeddingParam.TextType.QUERY);
        }
        // Mixed queries and documents: one request per segment so each gets its own text type.
        List<Embedding> embeddings = new ArrayList<>(textSegments.size());
        TokenUsage tokenUsage = null;
        for (TextSegment segment : textSegments) {
            TextEmbeddingParam.TextType textType = hasType(segment, TYPE_QUERY)
                    ? TextEmbeddingParam.TextType.QUERY
                    : TextEmbeddingParam.TextType.DOCUMENT;
            Response<List<Embedding>> single = embedTexts(Collections.singletonList(segment), textType);
            embeddings.addAll(single.content());
            tokenUsage = TokenUsage.sum(tokenUsage, single.tokenUsage());
        }
        return Response.from(embeddings, tokenUsage);
    }

    /**
     * Registers a hook that can modify every request builder before it is built and sent.
     *
     * @param consumer customizer to apply; must not be null
     */
    public void setTextEmbeddingParamCustomizer(Consumer<TextEmbeddingParam.TextEmbeddingParamBuilder<?, ?>> consumer) {
        this.textEmbeddingParamCustomizer = ValidationUtils.ensureNotNull(consumer, "textEmbeddingParamCustomizer");
    }

    /**
     * Returns a builder from the first {@link QwenEmbeddingModelBuilderFactory} found by
     * {@link ServiceHelper#loadFactories}, or a plain {@link QwenEmbeddingModelBuilder} if none is registered.
     */
    public static QwenEmbeddingModelBuilder builder() {
        Iterator<QwenEmbeddingModelBuilderFactory> factories =
                ServiceHelper.loadFactories(QwenEmbeddingModelBuilderFactory.class).iterator();
        return factories.hasNext() ? factories.next().get() : new QwenEmbeddingModelBuilder();
    }
}
