package ai.vespa.llm.clients;

import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModelException;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import ai.vespa.secret.Secrets;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.openai.client.OpenAIClient;
import com.openai.client.OpenAIClientAsync;
import com.openai.client.okhttp.OpenAIOkHttpClient;
import com.openai.client.okhttp.OpenAIOkHttpClientAsync;
import com.openai.core.JsonValue;
import com.openai.models.ChatModel;
import com.openai.models.ResponseFormatJsonSchema;
import com.openai.models.chat.completions.ChatCompletionCreateParams;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

@Beta
/* loaded from: input_file:ai/vespa/llm/clients/OpenAI.class */
public class OpenAI extends ConfigurableLanguageModel {
    private static final String DEFAULT_MODEL = "gpt-4o-mini";
    private static final String DEFAULT_ENDPOINT = "https://api.openai.com/v1/";
    private static final String DEFAULT_API_KEY = "<YOUR_API_KEY>";
    private final Map<String, String> configOptions;
    OpenAIClient defaultSyncClient;
    String cachedSyncApiKey;
    String cachedSyncEndpoint;
    OpenAIClientAsync defaultAsyncClient;
    String cachedAsyncApiKey;
    String cachedAsyncEndpoint;

    @Inject
    public OpenAI(LlmClientConfig llmClientConfig, Secrets secrets) {
        super(llmClientConfig, secrets);
        this.configOptions = new HashMap();
        if (!llmClientConfig.model().isBlank()) {
            this.configOptions.put("model", llmClientConfig.model());
        }
        if (llmClientConfig.temperature() >= 0.0d) {
            this.configOptions.put("temperature", String.valueOf(llmClientConfig.temperature()));
        }
        if (llmClientConfig.maxTokens() >= 0) {
            this.configOptions.put("maxTokens", String.valueOf(llmClientConfig.maxTokens()));
        }
    }

    private InferenceParameters prepareParameters(InferenceParameters inferenceParameters) {
        setApiKey(inferenceParameters);
        setEndpoint(inferenceParameters);
        Map<String, String> map = this.configOptions;
        Objects.requireNonNull(map);
        return inferenceParameters.withDefaultOptions((v1) -> {
            return r1.get(v1);
        });
    }

    OpenAIClient getSyncClient(String str, String str2) {
        if (this.defaultSyncClient != null && str != null && str.equals(this.cachedSyncApiKey) && str2 != null && str2.equals(this.cachedSyncEndpoint)) {
            return this.defaultSyncClient;
        }
        this.defaultSyncClient = OpenAIOkHttpClient.builder().apiKey(str).baseUrl(str2).responseValidation(false).build();
        this.cachedSyncApiKey = str;
        this.cachedSyncEndpoint = str2;
        return this.defaultSyncClient;
    }

    OpenAIClientAsync getAsyncClient(String str, String str2) {
        if (this.defaultAsyncClient != null && str != null && str.equals(this.cachedAsyncApiKey) && str2 != null && str2.equals(this.cachedAsyncEndpoint)) {
            return this.defaultAsyncClient;
        }
        this.defaultAsyncClient = OpenAIOkHttpClientAsync.builder().apiKey(str).baseUrl(str2).responseValidation(false).build();
        this.cachedAsyncApiKey = str;
        this.cachedAsyncEndpoint = str2;
        return this.defaultAsyncClient;
    }

    public List<Completion> complete(Prompt prompt, InferenceParameters inferenceParameters) {
        InferenceParameters prepareParameters = prepareParameters(inferenceParameters);
        OpenAIClient syncClient = getSyncClient((String) prepareParameters.getApiKey().orElse(DEFAULT_API_KEY), (String) prepareParameters.getEndpoint().orElse(DEFAULT_ENDPOINT));
        return syncClient.chat().completions().create(getChatCompletionCreateParams(prepareParameters, prompt)).choices().stream().flatMap(choice -> {
            return choice.message().content().stream().map(str -> {
                return new Completion(str, mapFinishReason(choice.finishReason().toString()));
            });
        }).toList();
    }

    public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, InferenceParameters inferenceParameters, Consumer<Completion> consumer) {
        InferenceParameters prepareParameters = prepareParameters(inferenceParameters);
        OpenAIClientAsync asyncClient = getAsyncClient((String) prepareParameters.getApiKey().orElse(DEFAULT_API_KEY), (String) prepareParameters.getEndpoint().orElse(DEFAULT_ENDPOINT));
        ChatCompletionCreateParams chatCompletionCreateParams = getChatCompletionCreateParams(prepareParameters, prompt);
        Completion.FinishReason[] finishReasonArr = {Completion.FinishReason.stop};
        CompletableFuture<Completion.FinishReason> completableFuture = new CompletableFuture<>();
        asyncClient.chat().completions().createStreaming(chatCompletionCreateParams).subscribe(chatCompletionChunk -> {
            chatCompletionChunk.choices().stream().flatMap(choice -> {
                choice.finishReason().ifPresent(finishReason -> {
                    finishReasonArr[0] = mapFinishReason(finishReason.toString());
                });
                return choice.delta().content().stream().map(str -> {
                    return new Completion(str, (Completion.FinishReason) choice.finishReason().map(finishReason2 -> {
                        return mapFinishReason(finishReason2.toString());
                    }).orElse(Completion.FinishReason.none));
                });
            }).forEach(consumer);
        }).onCompleteFuture().thenAccept(r6 -> {
            completableFuture.complete(finishReasonArr[0]);
        }).exceptionally(th -> {
            completableFuture.completeExceptionally(th);
            return null;
        });
        return completableFuture;
    }

    private ChatCompletionCreateParams getChatCompletionCreateParams(InferenceParameters inferenceParameters, Prompt prompt) {
        ChatCompletionCreateParams.Builder addUserMessage = ChatCompletionCreateParams.builder().model(ChatModel.of((String) inferenceParameters.get("model").map((v0) -> {
            return v0.toString();
        }).orElse(DEFAULT_MODEL))).addUserMessage(prompt.toString());
        Optional optional = inferenceParameters.getInt("maxTokens");
        Objects.requireNonNull(addUserMessage);
        optional.ifPresent((v1) -> {
            r1.maxCompletionTokens(v1);
        });
        Optional optional2 = inferenceParameters.getDouble("temperature");
        Objects.requireNonNull(addUserMessage);
        optional2.ifPresent(addUserMessage::temperature);
        Optional optional3 = inferenceParameters.getDouble("topp");
        Objects.requireNonNull(addUserMessage);
        optional3.ifPresent(addUserMessage::topP);
        Optional optional4 = inferenceParameters.getLong("seed");
        Objects.requireNonNull(addUserMessage);
        optional4.ifPresent(addUserMessage::seed);
        Optional optional5 = inferenceParameters.getInt("npredict");
        Objects.requireNonNull(addUserMessage);
        optional5.ifPresent((v1) -> {
            r1.n(v1);
        });
        Optional optional6 = inferenceParameters.getDouble("frequencypenalty");
        Objects.requireNonNull(addUserMessage);
        optional6.ifPresent(addUserMessage::frequencyPenalty);
        Optional optional7 = inferenceParameters.getDouble("precencepenalty");
        Objects.requireNonNull(addUserMessage);
        optional7.ifPresent(addUserMessage::presencePenalty);
        addResponseFormat(inferenceParameters, addUserMessage);
        return addUserMessage.build();
    }

    private void addResponseFormat(InferenceParameters inferenceParameters, ChatCompletionCreateParams.Builder builder) {
        inferenceParameters.get("json_schema").ifPresent(str -> {
            try {
                Map map = (Map) new ObjectMapper().readValue(str.toString(), new TypeReference<Map<String, Object>>() { // from class: ai.vespa.llm.clients.OpenAI.1
                });
                HashMap hashMap = new HashMap();
                map.forEach((str, obj) -> {
                    hashMap.put(str, JsonValue.from(obj));
                });
                builder.responseFormat(ResponseFormatJsonSchema.builder().jsonSchema(ResponseFormatJsonSchema.JsonSchema.builder().name("structured-output").schema(ResponseFormatJsonSchema.JsonSchema.Schema.builder().putAllAdditionalProperties(hashMap).build()).build()).build());
            } catch (Exception e) {
                throw new LanguageModelException(400, "Failed to parse JSON schema:\n" + str.toString() + "\n" + e.getMessage(), e);
            }
        });
    }

    private Completion.FinishReason mapFinishReason(String str) {
        if (str == null) {
            return Completion.FinishReason.none;
        }
        boolean z = -1;
        switch (str.hashCode()) {
            case -1106363674:
                if (str.equals("length")) {
                    z = true;
                    break;
                }
                break;
            case -25949074:
                if (str.equals("tool_calls")) {
                    z = 3;
                    break;
                }
                break;
            case 3387192:
                if (str.equals("none")) {
                    z = 5;
                    break;
                }
                break;
            case 3540994:
                if (str.equals("stop")) {
                    z = false;
                    break;
                }
                break;
            case 96784904:
                if (str.equals("error")) {
                    z = 6;
                    break;
                }
                break;
            case 124602878:
                if (str.equals("content_filter")) {
                    z = 2;
                    break;
                }
                break;
            case 2053138021:
                if (str.equals("function_call")) {
                    z = 4;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                return Completion.FinishReason.stop;
            case true:
                return Completion.FinishReason.length;
            case true:
                return Completion.FinishReason.content_filter;
            case true:
                return Completion.FinishReason.tool_calls;
            case true:
                return Completion.FinishReason.function_call;
            case true:
                return Completion.FinishReason.none;
            case true:
                throw new IllegalStateException("OpenAI-client returned finish_reason=error");
            default:
                return Completion.FinishReason.other;
        }
    }
}
