package com.agentsflex.llm.spark;

import com.agentsflex.core.document.Document;
import com.agentsflex.core.llm.BaseLlm;
import com.agentsflex.core.llm.ChatContext;
import com.agentsflex.core.llm.ChatOptions;
import com.agentsflex.core.llm.StreamResponseListener;
import com.agentsflex.core.llm.client.BaseLlmClientListener;
import com.agentsflex.core.llm.client.HttpClient;
import com.agentsflex.core.llm.client.impl.WebSocketClient;
import com.agentsflex.core.llm.embedding.EmbeddingOptions;
import com.agentsflex.core.llm.response.AbstractBaseMessageResponse;
import com.agentsflex.core.llm.response.AiMessageResponse;
import com.agentsflex.core.message.AiMessage;
import com.agentsflex.core.parser.AiMessageParser;
import com.agentsflex.core.prompt.Prompt;
import com.agentsflex.core.store.VectorData;
import com.agentsflex.core.util.SleepUtil;
import com.agentsflex.core.util.StringUtil;
import com.alibaba.fastjson.JSONPath;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Base64;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/agentsflex/llm/spark/SparkLlm.class */
/**
 * {@link BaseLlm} implementation configured by {@link SparkLlmConfig}.
 *
 * <p>Embedding requests are sent with {@link HttpClient#post}; chat requests are streamed
 * through a {@link WebSocketClient}. The blocking {@link #chat(Prompt, ChatOptions)} is
 * implemented on top of {@link #chatStream(Prompt, StreamResponseListener, ChatOptions)}.
 */
public class SparkLlm extends BaseLlm<SparkLlmConfig> {
    private static final Logger logger = LoggerFactory.getLogger(SparkLlm.class);

    /**
     * Value of {@code $.header.code} for which an embedding request is retried.
     * NOTE(review): presumably the service's concurrency-limit code, given that the retry
     * sleeps for {@code getConcurrencyLimitSleepMillis()} - confirm against the API docs.
     */
    private static final int RETRYABLE_EMBED_CODE = 11202;

    /** Retry count at which a retryable embedding failure is no longer retried. */
    private static final int MAX_EMBED_RETRIES = 3;

    /** Number of bytes occupied by one encoded float in the embedding payload. */
    private static final int BYTES_PER_FLOAT = 4;

    /** Parser handed to the client listener of every streamed chat. */
    public AiMessageParser aiMessageParser;
    private final HttpClient httpClient;

    /**
     * Creates an instance using the parser supplied by {@link SparkLlmUtil#getAiMessageParser()}
     * and a fresh {@link HttpClient}.
     *
     * @param sparkLlmConfig configuration passed to the base class
     */
    public SparkLlm(SparkLlmConfig sparkLlmConfig) {
        super(sparkLlmConfig);
        this.aiMessageParser = SparkLlmUtil.getAiMessageParser();
        this.httpClient = new HttpClient();
    }

    /**
     * Embeds a document, starting with a retry count of zero.
     *
     * @param document         document to embed
     * @param embeddingOptions forwarded unchanged; not read by this implementation
     * @return the embedding vector, or {@code null} when the request fails
     */
    public VectorData embed(Document document, EmbeddingOptions embeddingOptions) {
        return embed(document, embeddingOptions, 0);
    }

    /**
     * Embeds a document, retrying when the service answers with {@link #RETRYABLE_EMBED_CODE}.
     *
     * @param document         document to embed
     * @param embeddingOptions forwarded unchanged to retries; not read by this implementation
     * @param i                number of retries already performed; a retry is attempted only
     *                         while this is below {@link #MAX_EMBED_RETRIES}
     * @return the embedding vector, or {@code null} when the response is empty, has no header
     *         code, reports a non-retryable (or no longer retried) error, or has no feature text
     */
    public VectorData embed(Document document, EmbeddingOptions embeddingOptions, int i) {
        SparkLlmConfig sparkConfig = (SparkLlmConfig) this.config;
        String post = this.httpClient.post(SparkLlmUtil.createEmbedURL(sparkConfig), (Map) null, SparkLlmUtil.embedPayload(sparkConfig, document));
        if (StringUtil.noText(post)) {
            logger.error("Could not get embed data" + document);
            return null;
        }

        Integer code = (Integer) JSONPath.read(post, "$.header.code", Integer.class);
        if (code == null) {
            logger.error(post);
            return null;
        }
        if (code.intValue() != 0) {
            boolean retryable = code.intValue() == RETRYABLE_EMBED_CODE && i < MAX_EMBED_RETRIES;
            if (!retryable) {
                logger.error(post);
                return null;
            }
            // Back off for the configured interval before asking again.
            SleepUtil.sleep(sparkConfig.getConcurrencyLimitSleepMillis());
            return embed(document, embeddingOptions, i + 1);
        }

        String encoded = (String) JSONPath.read(post, "$.payload.feature.text", String.class);
        if (StringUtil.noText(encoded)) {
            logger.error(post);
            return null;
        }

        VectorData vectorData = new VectorData();
        vectorData.setVector(decodeLittleEndianFloats(encoded));
        return vectorData;
    }

    /**
     * Decodes a Base64 string into consecutive little-endian 32-bit floats, widened to double.
     * Trailing bytes that do not form a complete float are ignored.
     */
    private static double[] decodeLittleEndianFloats(String base64) {
        // The decoded array must be kept: the floats are read from it below.
        byte[] raw = Base64.getDecoder().decode(base64);
        ByteBuffer buffer = ByteBuffer.wrap(raw).order(ByteOrder.LITTLE_ENDIAN);
        double[] values = new double[raw.length / BYTES_PER_FLOAT];
        for (int index = 0; index < values.length; index++) {
            values[index] = buffer.getFloat(index * BYTES_PER_FLOAT);
        }
        return values;
    }

    /**
     * Sends the prompt and blocks until the streamed conversation stops or fails.
     *
     * @param prompt      prompt to send
     * @param chatOptions options forwarded to the streaming call
     * @return the last response delivered by the stream, flagged as an error when a failure
     *         was reported or the response carries no message; {@code null} when neither a
     *         response nor a failure was observed
     */
    public AiMessageResponse chat(Prompt prompt, ChatOptions chatOptions) {
        // Single-element arrays act as mutable holders written by the listener callbacks.
        Throwable[] failureHolder = new Throwable[1];
        AiMessageResponse[] responseHolder = {null};
        waitResponse(prompt, chatOptions, responseHolder, new CountDownLatch(1), failureHolder);

        AiMessageResponse response = responseHolder[0];
        Throwable failure = failureHolder[0];
        if (response == null) {
            if (failure == null) {
                return null;
            }
            // A failure arrived before any message: report it through an empty response.
            response = new AiMessageResponse(prompt, "", (AiMessage) null);
        }

        if (failure == null && response.getMessage() != null) {
            response.setError(false);
            return response;
        }
        response.setError(true);
        if (failure != null) {
            response.setErrorMessage(failure.getMessage());
        }
        return response;
    }

    /**
     * Starts a streamed chat and blocks the calling thread until the latch is released.
     *
     * @param prompt                         prompt to send
     * @param chatOptions                    options forwarded to the streaming call
     * @param abstractBaseMessageResponseArr holder whose slot 0 receives the latest response
     * @param countDownLatch                 latch released when the stream stops or fails
     * @param thArr                          holder whose slot 0 receives the reported failure
     * @throws RuntimeException wrapping the {@link InterruptedException} if the wait is interrupted
     */
    private void waitResponse(Prompt prompt, ChatOptions chatOptions, final AbstractBaseMessageResponse<?>[] abstractBaseMessageResponseArr, final CountDownLatch countDownLatch, final Throwable[] thArr) {
        chatStream(prompt, new StreamResponseListener() {
            public void onMessage(ChatContext chatContext, AiMessageResponse aiMessageResponse) {
                AiMessage message = aiMessageResponse.getMessage();
                if (message != null) {
                    // Expose the accumulated text rather than only the latest fragment.
                    message.setContent(message.getFullContent());
                }
                abstractBaseMessageResponseArr[0] = aiMessageResponse;
            }

            public void onStop(ChatContext chatContext) {
                // NOTE(review): if StreamResponseListener is an interface with a default onStop,
                // this call must be spelled StreamResponseListener.super.onStop(...) - confirm.
                super.onStop(chatContext);
                countDownLatch.countDown();
            }

            public void onFailure(ChatContext chatContext, Throwable th) {
                logger.error(th.toString(), th);
                thArr[0] = th;
                // Release the waiter here as well: without this, a failure that is never
                // followed by onStop would leave chat() blocked forever. An extra countDown()
                // after the latch reached zero has no effect.
                countDownLatch.countDown();
            }
        }, chatOptions);

        try {
            countDownLatch.await();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Opens a new {@link WebSocketClient} and streams the chat for the given prompt.
     *
     * @param prompt                 prompt converted to the request payload
     * @param streamResponseListener listener wrapped in a {@link BaseLlmClientListener}
     * @param chatOptions            options used when building the payload
     */
    public void chatStream(Prompt prompt, StreamResponseListener streamResponseListener, ChatOptions chatOptions) {
        SparkLlmConfig sparkConfig = (SparkLlmConfig) this.config;
        WebSocketClient webSocketClient = new WebSocketClient();
        String url = SparkLlmUtil.createURL(sparkConfig);
        String payload = SparkLlmUtil.promptToPayload(prompt, sparkConfig, chatOptions);
        BaseLlmClientListener clientListener = new BaseLlmClientListener(this, webSocketClient, streamResponseListener, prompt, this.aiMessageParser);
        webSocketClient.start(url, (Map) null, payload, clientListener, this.config);
    }
}
