package com.github.tjake.jlama.safetensors.prompt;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.github.tjake.jlama.safetensors.tokenizer.TokenizerModel;
import com.github.tjake.jlama.util.JsonSupport;
import com.hubspot.jinjava.Jinjava;
import com.hubspot.jinjava.JinjavaConfig;
import com.hubspot.jinjava.LegacyOverrides;
import com.hubspot.jinjava.interpret.RenderResult;
import com.hubspot.jinjava.lib.fn.ELFunctionDefinition;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Renders a prompt for a model by passing a list of chat messages (and, optionally, tool
 * definitions) through one of the model's Jinja prompt templates.
 *
 * <p>Templates come from {@link TokenizerModel#promptTemplates()} and are selected by
 * {@link PromptType}. Rendering uses a single, statically configured {@link Jinjava} engine.
 */
public class PromptSupport {
    private static final Logger logger = LoggerFactory.getLogger(PromptSupport.class);

    /**
     * Shared template engine with block trimming, lstrip and the legacy overrides below enabled.
     * NOTE(review): these settings are presumably chosen to mirror the Python Jinja2 environment
     * that model chat templates are written for — confirm before changing any of them.
     */
    private static final Jinjava jinjava = new Jinjava(
            JinjavaConfig.newBuilder()
                    .withTrimBlocks(true)
                    .withLstripBlocks(true)
                    .withLegacyOverrides(LegacyOverrides.newBuilder()
                            .withParseWhitespaceControlStrictly(true)
                            .withUseTrimmingForNotesAndExpressions(true)
                            .withUseSnakeCasePropertyNaming(true)
                            .withKeepNullableLoopValues(true)
                            .build())
                    .withObjectMapper(new ObjectMapper()
                            .enable(SerializationFeature.INDENT_OUTPUT)
                            .setDefaultPrettyPrinter(JsonSupport.JlamaPrettyPrinter.INSTANCE))
                    .build());

    static {
        // Make raise_exception(message) callable from templates; it is bound by reflection to the
        // public static raiseException(String) method below.
        jinjava.getGlobalContext()
                .registerFunction(new ELFunctionDefinition(
                        "", "raise_exception", PromptSupport.class, "raiseException", new Class<?>[] {String.class}));
    }

    private final TokenizerModel m;

    /**
     * Creates prompt support for a model.
     *
     * @param tokenizerModel model that supplies the prompt templates, the EOS token and the
     *     tool-support flag
     */
    public PromptSupport(TokenizerModel tokenizerModel) {
        this.m = tokenizerModel;
    }

    /**
     * Starts a new, empty prompt builder bound to this model.
     *
     * @return a fresh {@link Builder}
     */
    public Builder builder() {
        return new Builder(this.m);
    }

    /**
     * Target of the template function {@code raise_exception(message)}. The message is only logged
     * at debug level. This method throws nothing, so it does not itself abort rendering.
     *
     * <p>Must stay {@code public static} with a single {@code String} parameter, because the static
     * initializer binds it reflectively by that name and signature.
     *
     * @param message the message the template passed to {@code raise_exception}
     */
    public static void raiseException(String message) {
        logger.debug("Prompt template error: {}", message);
    }

    /** Selects which of the model's prompt templates to render; the lookup key is the lower-cased name. */
    public enum PromptType {
        DEFAULT,
        TOOL,
        RAG
    }

    /** Role of a message; {@link #key()} is the value exposed to templates as {@code role}. */
    private enum PromptRole {
        USER,
        SYSTEM,
        ASSISTANT,
        TOOL,
        TOOL_CALL;

        /**
         * Locale-independent lower-case name. Locale.ROOT matters here: under a Turkish default
         * locale a plain toLowerCase() would turn ASSISTANT into "assıstant" (dotless i).
         */
        String key() {
            return name().toLowerCase(Locale.ROOT);
        }
    }

    /**
     * Collects messages and rendering options, then renders them into a {@link PromptContext}.
     * Instances hold mutable state without any synchronization.
     */
    public static class Builder {
        private final TokenizerModel m;
        private PromptType type = PromptType.DEFAULT;
        private boolean addGenerationPrompt = true;
        private final List<Message> messages = new ArrayList<>(2);

        private Builder(TokenizerModel tokenizerModel) {
            this.m = tokenizerModel;
        }

        /**
         * Selects the template variant to render.
         *
         * @param promptType template type; its lower-cased name is used as the template key
         * @return this builder
         */
        public Builder usePromptType(PromptType promptType) {
            this.type = promptType;
            return this;
        }

        /**
         * Sets the value of the {@code add_generation_prompt} template variable (default {@code true}).
         *
         * @param z value to expose to the template
         * @return this builder
         */
        public Builder addGenerationPrompt(boolean z) {
            this.addGenerationPrompt = z;
            return this;
        }

        /** Appends a message with role {@code user}. */
        public Builder addUserMessage(String str) {
            this.messages.add(new Message(str, PromptRole.USER));
            return this;
        }

        /** Appends a message with role {@code tool} whose content is {@code toolResult.toJson()}. */
        public Builder addToolResult(ToolResult toolResult) {
            this.messages.add(new Message(toolResult));
            return this;
        }

        /** Appends a message with role {@code tool_call} that carries the given call. */
        public Builder addToolCall(ToolCall toolCall) {
            this.messages.add(new Message(toolCall));
            return this;
        }

        /** Appends a message with role {@code system}. */
        public Builder addSystemMessage(String str) {
            this.messages.add(new Message(str, PromptRole.SYSTEM));
            return this;
        }

        /** Appends a message with role {@code assistant}. */
        public Builder addAssistantMessage(String str) {
            this.messages.add(new Message(str, PromptRole.ASSISTANT));
            return this;
        }

        /**
         * Renders the prompt without tools.
         *
         * @throws IllegalArgumentException if no messages were added
         * @throws UnsupportedOperationException if the model has no template for the selected type
         */
        public PromptContext build() {
            return build(Optional.empty());
        }

        /**
         * Renders the prompt with the given tools. A {@code null} list is treated as "no tools".
         *
         * @throws IllegalArgumentException if no messages were added
         * @throws UnsupportedOperationException if the model has no template for the selected type
         */
        public PromptContext build(List<Tool> list) {
            return build(Optional.ofNullable(list));
        }

        /**
         * Renders the prompt with the given tools.
         *
         * @throws IllegalArgumentException if no messages were added
         * @throws UnsupportedOperationException if the model has no template for the selected type
         */
        public PromptContext build(Tool... toolArr) {
            return build(Optional.of(List.of(toolArr)));
        }

        /** Shared implementation of the public build overloads. */
        private PromptContext build(Optional<List<Tool>> tools) {
            if (this.messages.isEmpty()) {
                throw new IllegalArgumentException("No messages to generate prompt");
            }
            String template = resolveTemplate();

            if (tools.isPresent() && !tools.get().isEmpty() && !this.m.hasToolSupport()) {
                logger.warn("This model does not support tools, but tools are specified");
            }

            RenderResult result = jinjava.renderForResult(template, templateBindings(tools));
            if (result.hasErrors()) {
                // Template errors are not fatal here: they are logged and the output is still returned.
                logger.debug("Prompt template errors: {}", result.getErrors());
            }
            return new PromptContext(result.getOutput(), tools);
        }

        /** Returns the model's template for the selected type, or throws UnsupportedOperationException. */
        private String resolveTemplate() {
            if (this.m.promptTemplates().isEmpty()) {
                throw new UnsupportedOperationException("Prompt templates are not available for this model");
            }
            String key = this.type.name().toLowerCase(Locale.ROOT);
            return this.m.promptTemplates()
                    .map(templates -> templates.get(key))
                    .orElseThrow(() -> new UnsupportedOperationException(
                            "Prompt template not available for type: " + this.type));
        }

        /** Builds the variables visible to the template. */
        private Map<String, Object> templateBindings(Optional<List<Tool>> tools) {
            Map<String, Object> bindings = new HashMap<>();
            bindings.put("messages", this.messages.stream().map(Message::toMap).toList());
            bindings.put("add_generation_prompt", this.addGenerationPrompt);
            // put() rather than Map.of(): Map.of throws NullPointerException if the EOS token is null.
            bindings.put("eos_token", this.m.eosToken());
            // NOTE(review): BOS is always rendered empty here — presumably added elsewhere; confirm.
            bindings.put("bos_token", "");
            tools.ifPresent(list -> bindings.put("tools", list));
            return bindings;
        }
    }

    /** Adapter exposing a {@link ToolCall} as {@code name} / {@code arguments}. */
    static class InnerToolCall {
        private final ToolCall call;

        private InnerToolCall(ToolCall toolCall) {
            this.call = toolCall;
        }

        /** Returns {@code call.getParameters()}. */
        public Map<String, Object> arguments() {
            return this.call.getParameters();
        }

        /** Returns {@code call.getName()}. */
        public String name() {
            return this.call.getName();
        }
    }

    /** A single chat message: plain text, a tool call, or a tool result. */
    public static class Message {
        private final Object content;
        private final PromptRole role;
        private final ToolCallFunction toolCalls;
        private final String toolCallId;

        private Message(Object obj, PromptRole promptRole) {
            this.content = obj;
            this.role = promptRole;
            this.toolCalls = null;
            this.toolCallId = null;
        }

        private Message(ToolCall toolCall) {
            this.content = null;
            this.role = PromptRole.TOOL_CALL;
            this.toolCalls = new ToolCallFunction(toolCall);
            this.toolCallId = toolCall.getId();
        }

        private Message(ToolResult toolResult) {
            this.content = toolResult.toJson();
            this.toolCalls = null;
            this.role = PromptRole.TOOL;
            this.toolCallId = toolResult.getToolCallId();
        }

        /** Returns the raw content; {@code null} for tool-call messages. */
        public Object getContent() {
            return this.content;
        }

        /**
         * Converts this message to the map shape handed to the template.
         *
         * <p>Always contains {@code role} and {@code content} ({@code ""} when content is null).
         * Contains {@code tool_calls} for tool-call messages, and the tool-id key only when an id is
         * present.
         */
        public Map<String, Object> toMap() {
            Map<String, Object> map = new HashMap<>();
            map.put("role", this.role.key());
            map.put("content", this.content == null ? "" : this.content);
            if (this.toolCalls != null) {
                map.put("tool_calls", List.of(this.toolCalls.toMap()));
            }
            if (this.toolCallId != null) {
                map.put(ToolResult.JSON_PROPERTY_TOOL_ID, this.toolCallId);
            }
            return map;
        }

        /** Returns the locale-independent lower-case role name. */
        public String getRole() {
            return this.role.key();
        }

        /** Returns a single-element list for tool-call messages, otherwise {@code null}. */
        public List<ToolCallFunction> toolCalls() {
            if (this.toolCalls == null) {
                return null;
            }
            return List.of(this.toolCalls);
        }
    }

    /** Wraps a {@link ToolCall} in the {@code {function: {...}, id}} shape used by templates. */
    static class ToolCallFunction {
        private final ToolCall call;

        private ToolCallFunction(ToolCall toolCall) {
            this.call = toolCall;
        }

        /** Returns the call viewed as name/arguments. */
        public InnerToolCall function() {
            return new InnerToolCall(this.call);
        }

        /**
         * Returns {@code {function: {name, arguments}, id}}. The {@code id} key is omitted when the
         * call has no id, just as Message#toMap omits a null tool-call id.
         */
        public Map<String, Object> toMap() {
            Map<String, Object> function = new LinkedHashMap<>();
            function.put("name", this.call.getName());
            function.put("arguments", this.call.getParameters());

            Map<String, Object> map = new LinkedHashMap<>();
            map.put(Tool.JSON_PROPERTY_FUNCTION, function);
            // Map.of() would throw NullPointerException for a call without an id.
            String id = this.call.getId();
            if (id != null) {
                map.put("id", id);
            }
            return map;
        }
    }
}
