/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.llm.generation;

import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModel;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import ai.vespa.llm.completion.StringPrompt;
import ai.vespa.llm.generation.FieldGeneratorUtils;
import ai.vespa.llm.generation.LanguageModelFieldGeneratorConfig;
import ai.vespa.llm.generation.LanguageModelUtils;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.document.DataType;
import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.language.process.FieldGenerator;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

public class LanguageModelFieldGenerator
extends AbstractComponent
implements FieldGenerator {
    private static final Logger logger = Logger.getLogger(LanguageModelFieldGenerator.class.getName());
    private final LanguageModel languageModel;
    private final LanguageModelFieldGeneratorConfig config;
    private final String promptTemplate;

    @Inject
    public LanguageModelFieldGenerator(LanguageModelFieldGeneratorConfig config, ComponentRegistry<LanguageModel> languageModels) {
        this.languageModel = LanguageModelUtils.findLanguageModel(config.providerId(), languageModels, logger);
        this.config = config;
        this.promptTemplate = this.loadPromptTemplate(config);
    }

    private String loadPromptTemplate(LanguageModelFieldGeneratorConfig config) {
        if (config.promptTemplate() != null && !config.promptTemplate().isEmpty()) {
            return config.promptTemplate();
        }
        if (config.promptTemplateFile().isPresent()) {
            Path path = config.promptTemplateFile().get();
            try {
                String promptTemplate = new String(Files.readAllBytes(path));
                if (promptTemplate.isEmpty()) {
                    throw new IllegalArgumentException("Prompt template file is empty: " + String.valueOf(path));
                }
                return promptTemplate;
            }
            catch (IOException e) {
                throw new IllegalArgumentException("Could not read prompt template file: " + String.valueOf(path), e);
            }
        }
        return null;
    }

    public FieldValue generate(Prompt prompt, FieldGenerator.Context context) {
        return (FieldValue)context.computeCachedValueIfAbsent((Object)new CacheKey(this, prompt, context.getDestination(), context.getTargetType()), () -> this.computeGeneration(prompt, context.getDestination(), context.getTargetType()));
    }

    private FieldValue computeGeneration(Prompt prompt, String destination, DataType targetType) {
        StringFieldValue generatedFieldValue;
        HashMap<String, String> options = new HashMap<String, String>();
        String jsonSchema = null;
        if (this.config.responseFormatType() == LanguageModelFieldGeneratorConfig.ResponseFormatType.JSON) {
            jsonSchema = FieldGeneratorUtils.generateJsonSchemaForField(destination, targetType);
            options.put("json_schema", jsonSchema);
        }
        String expandedPrompt = LanguageModelUtils.expandPrompt(prompt.asString(), this.promptTemplate, jsonSchema);
        List completions = this.languageModel.complete((Prompt)StringPrompt.from((String)expandedPrompt), new InferenceParameters(options::get));
        Completion firstCompletion = (Completion)completions.get(0);
        String generatedText = firstCompletion.text();
        if (this.config.responseFormatType() == LanguageModelFieldGeneratorConfig.ResponseFormatType.JSON) {
            try {
                generatedFieldValue = FieldGeneratorUtils.parseJsonFieldValue(generatedText, destination, targetType);
            }
            catch (IllegalArgumentException e) {
                generatedFieldValue = switch (this.config.invalidResponseFormatPolicy()) {
                    default -> throw new IncompatibleClassChangeError();
                    case LanguageModelFieldGeneratorConfig.InvalidResponseFormatPolicy.Enum.DISCARD -> null;
                    case LanguageModelFieldGeneratorConfig.InvalidResponseFormatPolicy.Enum.WARN -> {
                        logger.warning(e.getMessage());
                        yield null;
                    }
                    case LanguageModelFieldGeneratorConfig.InvalidResponseFormatPolicy.Enum.FAIL -> throw e;
                };
            }
        } else {
            generatedFieldValue = new StringFieldValue(generatedText);
        }
        return generatedFieldValue;
    }

    private record CacheKey(LanguageModelFieldGenerator generator, Prompt prompt, String destination, DataType targetType) {
    }
}

