package dev.langchain4j.model.jlama;

import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.prompt.PromptContext;
import com.github.tjake.jlama.safetensors.prompt.PromptSupport;
import com.github.tjake.jlama.safetensors.prompt.ToolCall;
import com.github.tjake.jlama.safetensors.prompt.ToolResult;
import com.github.tjake.jlama.util.JsonSupport;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.ContentType;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.TextContent;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ChatRequestParameters;
import dev.langchain4j.model.chat.request.ChatRequestValidator;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.ChatResponseMetadata;
import dev.langchain4j.model.jlama.JlamaModel;
import dev.langchain4j.model.jlama.spi.JlamaChatModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.spi.ServiceHelper;
import java.nio.file.Path;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;

/* loaded from: input_file:dev/langchain4j/model/jlama/JlamaChatModel.class */
public class JlamaChatModel implements ChatLanguageModel {
    private final AbstractModel model;
    private final Float temperature;
    private final Integer maxTokens;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: dev.langchain4j.model.jlama.JlamaChatModel$1, reason: invalid class name */
    /* loaded from: input_file:dev/langchain4j/model/jlama/JlamaChatModel$1.class */
    /**
     * Lookup table that translates {@link ChatMessageType} ordinals into small, stable case labels.
     *
     * <p>The array is indexed by {@code ChatMessageType.ordinal()} and holds {@code 1} for SYSTEM,
     * {@code 2} for USER, {@code 3} for AI and {@code 4} for TOOL_EXECUTION_RESULT. Any other
     * constant keeps the array default of {@code 0}.
     *
     * <p>NOTE(review): this class is marked synthetic and carries a decompiler-assigned name; it
     * looks like the helper a Java compiler emits for a {@code switch} over an enum, not
     * hand-written code — confirm before editing it by hand.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // One slot per enum constant known at class-initialization time; slots default to 0.
        static final /* synthetic */ int[] $SwitchMap$dev$langchain4j$data$message$ChatMessageType = new int[ChatMessageType.values().length];

        static {
            // Each constant is registered in its own try block so that a constant missing at run
            // time (NoSuchFieldError) only leaves its own slot at 0 instead of aborting the whole
            // initializer. The empty catch blocks are therefore intentional.
            try {
                $SwitchMap$dev$langchain4j$data$message$ChatMessageType[ChatMessageType.SYSTEM.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // intentionally ignored: constant absent, slot stays 0
            }
            try {
                $SwitchMap$dev$langchain4j$data$message$ChatMessageType[ChatMessageType.USER.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // intentionally ignored: constant absent, slot stays 0
            }
            try {
                $SwitchMap$dev$langchain4j$data$message$ChatMessageType[ChatMessageType.AI.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // intentionally ignored: constant absent, slot stays 0
            }
            try {
                $SwitchMap$dev$langchain4j$data$message$ChatMessageType[ChatMessageType.TOOL_EXECUTION_RESULT.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
                // intentionally ignored: constant absent, slot stays 0
            }
        }
    }

    /* loaded from: input_file:dev/langchain4j/model/jlama/JlamaChatModel$JlamaChatModelBuilder.class */
    /**
     * Fluent builder for {@link JlamaChatModel}.
     *
     * <p>Every setter stores its argument unchanged and returns this builder. All settings are
     * optional from the builder's point of view; {@link #build()} forwards them, including
     * {@code null} values, to the {@code JlamaChatModel} constructor.
     */
    public static class JlamaChatModelBuilder {
        /** Placeholder printed by {@link #toString()} instead of a non-null auth token. */
        private static final String MASKED_TOKEN = "********";

        private Path modelCachePath;
        private String modelName;
        private String authToken;
        private Integer threadCount;
        private Boolean quantizeModelAtRuntime;
        private Path workingDirectory;
        private DType workingQuantizedType;
        private Float temperature;
        private Integer maxTokens;

        /**
         * @param path value passed to the model registry lookup when the model is built
         * @return this builder
         */
        public JlamaChatModelBuilder modelCachePath(Path path) {
            this.modelCachePath = path;
            return this;
        }

        /**
         * @param str name of the model to download and load
         * @return this builder
         */
        public JlamaChatModelBuilder modelName(String str) {
            this.modelName = str;
            return this;
        }

        /**
         * @param str token supplied when downloading the model; may be {@code null}
         * @return this builder
         */
        public JlamaChatModelBuilder authToken(String str) {
            this.authToken = str;
            return this;
        }

        /**
         * @param num thread count handed to the model loader; {@code null} keeps the loader's setting
         * @return this builder
         */
        public JlamaChatModelBuilder threadCount(Integer num) {
            this.threadCount = num;
            return this;
        }

        /**
         * @param bool {@code true} to request a quantized model from the loader; {@code null} or
         *     {@code false} leaves the loader untouched
         * @return this builder
         */
        public JlamaChatModelBuilder quantizeModelAtRuntime(Boolean bool) {
            this.quantizeModelAtRuntime = bool;
            return this;
        }

        /**
         * @param path working directory handed to the model loader; {@code null} keeps the loader's setting
         * @return this builder
         */
        public JlamaChatModelBuilder workingDirectory(Path path) {
            this.workingDirectory = path;
            return this;
        }

        /**
         * @param dType working quantization type handed to the model loader; {@code null} keeps the
         *     loader's setting
         * @return this builder
         */
        public JlamaChatModelBuilder workingQuantizedType(DType dType) {
            this.workingQuantizedType = dType;
            return this;
        }

        /**
         * @param f sampling temperature; {@code null} makes the model fall back to its default
         * @return this builder
         */
        public JlamaChatModelBuilder temperature(Float f) {
            this.temperature = f;
            return this;
        }

        /**
         * @param num maximum number of tokens to generate; {@code null} makes the model fall back
         *     to its default
         * @return this builder
         */
        public JlamaChatModelBuilder maxTokens(Integer num) {
            this.maxTokens = num;
            return this;
        }

        /**
         * Creates the chat model from the values collected so far.
         *
         * @return a new {@link JlamaChatModel}
         */
        public JlamaChatModel build() {
            return new JlamaChatModel(
                    this.modelCachePath,
                    this.modelName,
                    this.authToken,
                    this.threadCount,
                    this.quantizeModelAtRuntime,
                    this.workingDirectory,
                    this.workingQuantizedType,
                    this.temperature,
                    this.maxTokens);
        }

        /**
         * Describes the builder state for diagnostics.
         *
         * <p>The auth token is a credential, so its value is never printed: a configured token is
         * shown as a fixed mask and an absent one as {@code null}.
         *
         * @return a human-readable summary of all builder fields
         */
        @Override
        public String toString() {
            String printableToken = this.authToken == null ? "null" : MASKED_TOKEN;
            return "JlamaChatModel.JlamaChatModelBuilder(modelCachePath=" + this.modelCachePath
                    + ", modelName=" + this.modelName
                    + ", authToken=" + printableToken
                    + ", threadCount=" + this.threadCount
                    + ", quantizeModelAtRuntime=" + this.quantizeModelAtRuntime
                    + ", workingDirectory=" + this.workingDirectory
                    + ", workingQuantizedType=" + this.workingQuantizedType
                    + ", temperature=" + this.temperature
                    + ", maxTokens=" + this.maxTokens + ")";
        }
    }

    /**
     * Downloads (with retries) and loads the requested model, then fixes the generation defaults.
     *
     * @param path value used to look up or create the model registry
     * @param str name of the model to download
     * @param str2 auth token for the download; may be {@code null}
     * @param num thread count for the loader; {@code null} leaves the loader's setting
     * @param bool when {@code true}, a quantized loader is requested
     * @param path2 working directory for the loader; {@code null} leaves the loader's setting
     * @param dType working quantization type for the loader; {@code null} leaves the loader's setting
     * @param f sampling temperature; {@code null} means {@code 0.3}
     * @param num2 maximum tokens to generate; {@code null} means the model's configured context length
     */
    public JlamaChatModel(Path path, String str, String str2, Integer num, Boolean bool, Path path2, DType dType, Float f, Integer num2) {
        JlamaModelRegistry registry = JlamaModelRegistry.getOrCreate(path);

        // The download is attempted up to 3 times; failures are translated by the Jlama exception mapper.
        JlamaModel downloaded = RetryUtils.withRetryMappingExceptions(
                () -> registry.downloadModel(str, Optional.ofNullable(str2)),
                3,
                JlamaExceptionMapper.INSTANCE);

        // Apply only the loader options the caller actually supplied.
        JlamaModel.Loader configured = downloaded.loader();
        if (Boolean.TRUE.equals(bool)) {
            configured = configured.quantized();
        }
        if (dType != null) {
            configured = configured.workingQuantizationType(dType);
        }
        if (num != null) {
            configured = configured.threadCount(num);
        }
        if (path2 != null) {
            configured = configured.workingDirectory(path2);
        }
        this.model = configured.load();

        this.temperature = (f != null) ? f : Float.valueOf(0.3f);
        this.maxTokens = (num2 != null) ? num2 : Integer.valueOf(this.model.getConfig().contextLength);
    }

    /**
     * Creates a builder for this model.
     *
     * <p>If at least one {@link JlamaChatModelBuilderFactory} is found by the service helper, the
     * first one supplies the builder; otherwise a plain {@link JlamaChatModelBuilder} is returned.
     *
     * @return the builder to configure a {@link JlamaChatModel}
     */
    public static JlamaChatModelBuilder builder() {
        Iterator<JlamaChatModelBuilderFactory> factories =
                ServiceHelper.loadFactories(JlamaChatModelBuilderFactory.class).iterator();
        if (factories.hasNext()) {
            return factories.next().get();
        }
        return new JlamaChatModelBuilder();
    }

    /**
     * Validates the request, generates a reply and wraps it in a {@link ChatResponse}.
     *
     * <p>Messages, parameters, tool choice and response format are passed through
     * {@link ChatRequestValidator} before generation. Tool specifications are forwarded to the
     * generator only when the request actually carries some.
     *
     * @param chatRequest the request holding the conversation and its parameters
     * @return the AI message together with token usage and finish reason metadata
     */
    public ChatResponse chat(ChatRequest chatRequest) {
        List<ChatMessage> conversation = chatRequest.messages();
        ChatRequestValidator.validateMessages(conversation);

        ChatRequestParameters requestParameters = chatRequest.parameters();
        ChatRequestValidator.validateParameters(requestParameters);
        ChatRequestValidator.validate(requestParameters.toolChoice());
        ChatRequestValidator.validate(requestParameters.responseFormat());

        List<ToolSpecification> tools = requestParameters.toolSpecifications();
        Response<AiMessage> generated;
        if (Utils.isNullOrEmpty(tools)) {
            generated = generate(conversation);
        } else {
            generated = generate(conversation, tools);
        }

        ChatResponseMetadata metadata = ChatResponseMetadata.builder()
                .tokenUsage(generated.tokenUsage())
                .finishReason(generated.finishReason())
                .build();
        return ChatResponse.builder()
                .aiMessage(generated.content())
                .metadata(metadata)
                .build();
    }

    /**
     * Generates a reply for the conversation without offering any tools to the model.
     *
     * @param messages the conversation so far
     * @return the generated AI message with token usage and finish reason
     */
    private Response<AiMessage> generate(List<ChatMessage> messages) {
        List<ToolSpecification> noTools = List.of();
        return generate(messages, noTools);
    }

    /**
     * Renders the conversation into a Jlama prompt, runs the model once and converts the outcome
     * into a LangChain4j response.
     *
     * @param messages the conversation so far, in order
     * @param toolSpecifications tools the model may call; an empty list means the prompt is built
     *     without tools
     * @return the AI message (tool execution requests when the model finished with a tool call,
     *     plain text otherwise) together with token usage and the mapped finish reason
     * @throws UnsupportedOperationException if the model offers no prompt support, or a user
     *     message contains content that is not text
     * @throws IllegalArgumentException if a message is not a system, user, AI or tool execution
     *     result message
     */
    private Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        PromptSupport promptSupport = this.model.promptSupport()
                .orElseThrow(() -> new UnsupportedOperationException("This model does not support chat generation"));

        PromptSupport.Builder promptBuilder = promptSupport.builder();
        for (ChatMessage message : messages) {
            appendMessage(promptBuilder, message);
        }

        var tools = toolSpecifications.stream().map(JlamaModel::toTool).toList();
        PromptContext promptContext = tools.isEmpty() ? promptBuilder.build() : promptBuilder.build(tools);

        // The per-token callback is a no-op: this model only returns the complete response.
        Generator.Response response = JlamaExceptionMapper.INSTANCE.withExceptionMapper(
                () -> this.model.generate(
                        UUID.randomUUID(),
                        promptContext,
                        this.temperature.floatValue(),
                        this.maxTokens.intValue(),
                        (token, time) -> {
                        }));

        AiMessage aiMessage;
        if (response.finishReason == Generator.FinishReason.TOOL_CALL) {
            List<ToolExecutionRequest> requests = response.toolCalls.stream()
                    .map(toolCall -> ToolExecutionRequest.builder()
                            .name(toolCall.getName())
                            .id(toolCall.getId())
                            .arguments(JsonSupport.toJson(toolCall.getParameters()))
                            .build())
                    .toList();
            aiMessage = AiMessage.from(requests);
        } else {
            aiMessage = AiMessage.from(response.responseText);
        }

        TokenUsage tokenUsage = new TokenUsage(response.promptTokens, response.generatedTokens);
        return Response.from(aiMessage, tokenUsage, JlamaLanguageModel.toFinishReason(response.finishReason));
    }

    /** Adds one chat message to the prompt under construction, dispatching on the message type. */
    @SuppressWarnings({"rawtypes", "unchecked"}) // tool-call arguments are parsed into an untyped map
    private static void appendMessage(PromptSupport.Builder promptBuilder, ChatMessage message) {
        switch (message.type()) {
            case SYSTEM -> promptBuilder.addSystemMessage(((SystemMessage) message).text());
            case USER -> promptBuilder.addUserMessage(joinTextContents((UserMessage) message));
            case AI -> {
                AiMessage aiMessage = (AiMessage) message;
                if (aiMessage.text() != null) {
                    promptBuilder.addAssistantMessage(aiMessage.text());
                }
                if (aiMessage.hasToolExecutionRequests()) {
                    for (ToolExecutionRequest request : aiMessage.toolExecutionRequests()) {
                        Map arguments = (Map) Json.fromJson(request.arguments(), LinkedHashMap.class);
                        promptBuilder.addToolCall(new ToolCall(request.name(), request.id(), arguments));
                    }
                }
            }
            case TOOL_EXECUTION_RESULT -> {
                ToolExecutionResultMessage result = (ToolExecutionResultMessage) message;
                promptBuilder.addToolResult(ToolResult.from(result.toolName(), result.id(), result.text()));
            }
            default -> throw new IllegalArgumentException("Unsupported message type: " + message.type());
        }
    }

    /** Concatenates the text parts of a user message; any non-text part is rejected. */
    private static String joinTextContents(UserMessage userMessage) {
        StringBuilder text = new StringBuilder();
        for (var content : userMessage.contents()) {
            // Check the declared type before casting so unsupported content yields the intended
            // UnsupportedOperationException rather than a ClassCastException.
            if (content.type() != ContentType.TEXT) {
                throw new UnsupportedOperationException("Unsupported content type: " + content.type());
            }
            text.append(((TextContent) content).text());
        }
        return text.toString();
    }
}
