package io.kestra.plugin.openai;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.theokanning.openai.Usage;
import com.theokanning.openai.completion.chat.ChatCompletionChoice;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.service.OpenAiService;
import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.annotations.PluginProperty;
import io.kestra.core.models.executions.metrics.Counter;
import io.kestra.core.models.tasks.RunnableTask;
import io.kestra.core.runners.RunContext;
import io.kestra.plugin.openai.AbstractTask;
import io.swagger.v3.oas.annotations.media.Schema;
import java.beans.ConstructorProperties;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.validation.constraints.NotNull;
import lombok.Generated;

/**
 * Kestra task that sends a conversation and/or a single prompt to the OpenAI Chat Completions API.
 * <p>
 * The request is built from {@link #messages} (each message content rendered through the run context),
 * followed by {@link #prompt} (rendered and sent with the {@code user} role) when it is set. At least one
 * of the two must be provided. The API client comes from the inherited {@code client(RunContext)} and the
 * returned token usage is recorded as three counter metrics.
 * <p>
 * The builder, getters, {@code equals}/{@code hashCode}/{@code toString} and constructors below are
 * Lombok-generated code recovered by decompilation (hence the {@code @Generated} markers and the
 * decompiler-assigned {@code mo329self}/{@code mo328build} builder method names, which must stay in sync
 * with the parent {@code AbstractTask.AbstractTaskBuilder}).
 */
@Plugin(
    examples = {
        @Example(
            title = "Based on a prompt input, generate a completion response and pass it to a downstream task",
            full = true,
            code = {
                "id: openAI",
                "namespace: dev",
                "",
                "inputs:",
                "  - name: prompt",
                "    type: STRING",
                "    defaults: What is data orchestration?",
                "",
                "tasks:",
                "  - id: completion",
                "    type: io.kestra.plugin.openai.ChatCompletion",
                "    apiKey: \"yourOpenAIapiKey\"",
                "    model: gpt-3.5-turbo-0613",
                "    prompt: \"{{inputs.prompt}}\"",
                "",
                "  - id: response",
                "    type: io.kestra.core.tasks.debugs.Return",
                "    format: \"{{outputs.completion.choices[0].message.content}}\""
            }
        )
    }
)
@Schema(title = "Given a prompt, get a response from an LLM using the [OpenAI’s Chat Completions API](https://platform.openai.com/docs/api-reference/chat/create)", description = "For more information, refer to the [Chat Completions API docs](https://platform.openai.com/docs/guides/gpt/chat-completions-api)")
public class ChatCompletion extends AbstractTask implements RunnableTask<ChatCompletion.Output> {

    @Schema(title = "A list of messages comprising the conversation so far.", description = "Required if prompt is not set.")
    @PluginProperty
    private List<ChatMessage> messages;

    @Schema(title = "The prompt(s) to generate completions for. By default, this prompt will be sent as a `user` role.", description = "If not provided, make sure to set the `messages` property.")
    @PluginProperty
    private String prompt;

    @Schema(title = "What sampling temperature to use, between 0 and 2. Defaults to 1.")
    @PluginProperty
    private Double temperature;

    @Schema(title = "An alternative to sampling with temperature, where the model considers the results of the tokens with top_p probability mass. Defaults to 1.")
    @PluginProperty
    private Double topP;

    @Schema(title = "How many chat completion choices to generate for each input message. Defaults to 1.")
    @PluginProperty
    private Integer n;

    @Schema(title = "Up to 4 sequences where the API will stop generating further tokens. Defaults to null.")
    @PluginProperty
    private List<String> stop;

    @Schema(title = "The maximum number of tokens to generate in the chat completion. No limits are set by default.")
    @PluginProperty
    private Integer maxTokens;

    @Schema(title = "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far. Defaults to 0.")
    @PluginProperty
    private Double presencePenalty;

    @Schema(title = "Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far. Defaults to 0.")
    @PluginProperty
    private Double frequencyPenalty;

    @Schema(title = "Modify the likelihood of specified tokens appearing in the completion. Defaults to null.")
    @PluginProperty
    private Map<String, Integer> logitBias;

    @NotNull
    @Schema(title = "ID of the model to use e.g. `'gpt-4'`", description = "See the [OpenAI model's documentation page](https://platform.openai.com/docs/models/overview) for more details.")
    @PluginProperty(dynamic = true)
    private String model;

    /**
     * Self-typed builder for {@link ChatCompletion}, extending the parent task builder.
     * Each setter stores its value and returns {@code mo329self()} so calls can be chained.
     *
     * @param <C> the task type produced by {@code mo328build()}
     * @param <B> the concrete builder type returned by the chained setters
     */
    @Generated
    public static abstract class ChatCompletionBuilder<C extends ChatCompletion, B extends ChatCompletionBuilder<C, B>> extends AbstractTask.AbstractTaskBuilder<C, B> {

        @Generated
        private List<ChatMessage> messages;

        @Generated
        private String prompt;

        @Generated
        private Double temperature;

        @Generated
        private Double topP;

        @Generated
        private Integer n;

        @Generated
        private List<String> stop;

        @Generated
        private Integer maxTokens;

        @Generated
        private Double presencePenalty;

        @Generated
        private Double frequencyPenalty;

        @Generated
        private Map<String, Integer> logitBias;

        @Generated
        private String model;

        @Generated
        public B messages(List<ChatMessage> list) {
            this.messages = list;
            return mo329self();
        }

        @Generated
        public B prompt(String str) {
            this.prompt = str;
            return mo329self();
        }

        @Generated
        public B temperature(Double d) {
            this.temperature = d;
            return mo329self();
        }

        @Generated
        public B topP(Double d) {
            this.topP = d;
            return mo329self();
        }

        @Generated
        public B n(Integer num) {
            this.n = num;
            return mo329self();
        }

        @Generated
        public B stop(List<String> list) {
            this.stop = list;
            return mo329self();
        }

        @Generated
        public B maxTokens(Integer num) {
            this.maxTokens = num;
            return mo329self();
        }

        @Generated
        public B presencePenalty(Double d) {
            this.presencePenalty = d;
            return mo329self();
        }

        @Generated
        public B frequencyPenalty(Double d) {
            this.frequencyPenalty = d;
            return mo329self();
        }

        @Generated
        public B logitBias(Map<String, Integer> map) {
            this.logitBias = map;
            return mo329self();
        }

        @Generated
        public B model(String str) {
            this.model = str;
            return mo329self();
        }

        // Decompiler-renamed "self()": returns this builder with its concrete type (originally protected).
        @Override // io.kestra.plugin.openai.AbstractTask.AbstractTaskBuilder
        @Generated
        public abstract B mo329self();

        // Decompiler-renamed "build()": constructs the task from this builder's state.
        @Override // io.kestra.plugin.openai.AbstractTask.AbstractTaskBuilder
        @Generated
        public abstract C mo328build();

        @Override // io.kestra.plugin.openai.AbstractTask.AbstractTaskBuilder
        @Generated
        public String toString() {
            return "ChatCompletion.ChatCompletionBuilder(super=" + super.toString() + ", messages=" + this.messages + ", prompt=" + this.prompt + ", temperature=" + this.temperature + ", topP=" + this.topP + ", n=" + this.n + ", stop=" + this.stop + ", maxTokens=" + this.maxTokens + ", presencePenalty=" + this.presencePenalty + ", frequencyPenalty=" + this.frequencyPenalty + ", logitBias=" + this.logitBias + ", model=" + this.model + ")";
        }
    }

    /** Concrete builder returned by {@link ChatCompletion#builder()}. */
    @Generated
    private static final class ChatCompletionBuilderImpl extends ChatCompletionBuilder<ChatCompletion, ChatCompletionBuilderImpl> {
        @Generated
        private ChatCompletionBuilderImpl() {
        }

        // Decompiler-renamed "self()" (originally protected).
        @Override // io.kestra.plugin.openai.ChatCompletion.ChatCompletionBuilder, io.kestra.plugin.openai.AbstractTask.AbstractTaskBuilder
        @Generated
        public ChatCompletionBuilderImpl mo329self() {
            return this;
        }

        // Decompiler-renamed "build()".
        @Override // io.kestra.plugin.openai.ChatCompletion.ChatCompletionBuilder, io.kestra.plugin.openai.AbstractTask.AbstractTaskBuilder
        @Generated
        public ChatCompletion mo328build() {
            return new ChatCompletion(this);
        }
    }

    /**
     * Task output: the id, object type, model, choices and token usage copied from the
     * {@code ChatCompletionResult} returned by the API.
     */
    public static class Output implements io.kestra.core.models.tasks.Output {

        @Schema(title = "Unique ID assigned to this Chat Completion.")
        private String id;

        @Schema(title = "The type of object returned, should be \"chat.completion\".")
        private String object;

        @Schema(title = "The GPT model used.")
        private String model;

        @Schema(title = "A list of all generated completions.")
        private List<ChatCompletionChoice> choices;

        @Schema(title = "The API usage for this request.")
        private Usage usage;

        /** Builder collecting the {@link Output} fields before {@link #build()}. */
        @Generated
        public static class OutputBuilder {

            @Generated
            private String id;

            @Generated
            private String object;

            @Generated
            private String model;

            @Generated
            private List<ChatCompletionChoice> choices;

            @Generated
            private Usage usage;

            @Generated
            OutputBuilder() {
            }

            @Generated
            public OutputBuilder id(String str) {
                this.id = str;
                return this;
            }

            @Generated
            public OutputBuilder object(String str) {
                this.object = str;
                return this;
            }

            @Generated
            public OutputBuilder model(String str) {
                this.model = str;
                return this;
            }

            @Generated
            public OutputBuilder choices(List<ChatCompletionChoice> list) {
                this.choices = list;
                return this;
            }

            @Generated
            public OutputBuilder usage(Usage usage) {
                this.usage = usage;
                return this;
            }

            @Generated
            public Output build() {
                return new Output(this.id, this.object, this.model, this.choices, this.usage);
            }

            @Generated
            public String toString() {
                return "ChatCompletion.Output.OutputBuilder(id=" + this.id + ", object=" + this.object + ", model=" + this.model + ", choices=" + this.choices + ", usage=" + this.usage + ")";
            }
        }

        @Generated
        @ConstructorProperties({"id", "object", "model", "choices", "usage"})
        Output(String str, String str2, String str3, List<ChatCompletionChoice> list, Usage usage) {
            this.id = str;
            this.object = str2;
            this.model = str3;
            this.choices = list;
            this.usage = usage;
        }

        @Generated
        public static OutputBuilder builder() {
            return new OutputBuilder();
        }

        @Generated
        public String getId() {
            return this.id;
        }

        @Generated
        public String getObject() {
            return this.object;
        }

        @Generated
        public String getModel() {
            return this.model;
        }

        @Generated
        public List<ChatCompletionChoice> getChoices() {
            return this.choices;
        }

        @Generated
        public Usage getUsage() {
            return this.usage;
        }
    }

    /**
     * Renders the dynamic properties, calls the Chat Completions API and records token-usage metrics.
     * <p>
     * The request's message list is every entry of {@code messages} (content rendered), followed by the
     * rendered {@code prompt} as a {@code user} message when {@code prompt} is set. {@code stop},
     * {@code user} and {@code model} are rendered; the numeric tuning options are passed through as-is.
     *
     * @param runContext the Kestra run context used for rendering, the API client and metrics
     * @return the completion id, object type, model, choices and usage returned by the API
     * @throws IllegalArgumentException if neither {@code messages} nor {@code prompt} is set
     * @throws Exception if rendering, client creation or the API call fails
     */
    @Override
    public Output run(RunContext runContext) throws Exception {
        // Validate before building a client so an invalid task fails fast without side effects.
        if (this.messages == null && this.prompt == null) {
            throw new IllegalArgumentException("Either `messages` or `prompt` must be set");
        }

        OpenAiService client = client(runContext);

        List<String> renderedStop = this.stop != null ? runContext.render(this.stop) : null;
        String renderedUser = runContext.render(this.user);
        String renderedModel = runContext.render(this.model);

        List<ChatMessage> chatMessages = new ArrayList<>();
        if (this.messages != null) {
            for (ChatMessage chatMessage : this.messages) {
                // NOTE(review): rendering happens in place, mutating the configured ChatMessage objects, so a
                // second run on the same task instance re-renders already-rendered content. Copying would fix
                // this but must preserve every ChatMessage field of the library version in use — confirm first.
                chatMessage.setContent(runContext.render(chatMessage.getContent()));
                chatMessages.add(chatMessage);
            }
        }
        if (this.prompt != null) {
            // The prompt is appended after any explicit messages, always with the "user" role.
            chatMessages.add(buildMessage("user", runContext.render(this.prompt)));
        }

        ChatCompletionRequest request = ChatCompletionRequest.builder()
            .messages(chatMessages)
            .model(renderedModel)
            .temperature(this.temperature)
            .topP(this.topP)
            .n(this.n)
            .stop(renderedStop)
            .maxTokens(this.maxTokens)
            .presencePenalty(this.presencePenalty)
            .frequencyPenalty(this.frequencyPenalty)
            .logitBias(this.logitBias)
            .user(renderedUser)
            .build();

        ChatCompletionResult result = client.createChatCompletion(request);

        // Guard against a response without usage data instead of failing with a NullPointerException.
        Usage usage = result.getUsage();
        if (usage != null) {
            runContext.metric(Counter.of("usage.prompt_tokens", Long.valueOf(usage.getPromptTokens())));
            runContext.metric(Counter.of("usage.completion_tokens", Long.valueOf(usage.getCompletionTokens())));
            runContext.metric(Counter.of("usage.total_tokens", Long.valueOf(usage.getTotalTokens())));
        }

        return Output.builder()
            .id(result.getId())
            .object(result.getObject())
            .model(result.getModel())
            .choices(result.getChoices())
            .usage(usage)
            .build();
    }

    /**
     * Decompiler-renamed alias of {@link #run(RunContext)}, kept only so existing callers keep compiling.
     *
     * @param runContext the Kestra run context
     * @return the same result as {@link #run(RunContext)}
     * @throws Exception if {@link #run(RunContext)} fails
     * @deprecated call {@link #run(RunContext)} instead.
     */
    @Deprecated
    public Output m330run(RunContext runContext) throws Exception {
        return run(runContext);
    }

    /**
     * Creates a chat message with the given role and content.
     *
     * @param role the message role, e.g. {@code "user"}
     * @param content the message text
     * @return a new {@link ChatMessage}
     */
    private static ChatMessage buildMessage(String role, String content) {
        ChatMessage chatMessage = new ChatMessage();
        chatMessage.setRole(role);
        chatMessage.setContent(content);
        return chatMessage;
    }

    /**
     * Builder constructor: copies every builder field into the task.
     *
     * @param chatCompletionBuilder the populated builder
     */
    @Generated
    protected ChatCompletion(ChatCompletionBuilder<?, ?> chatCompletionBuilder) {
        super(chatCompletionBuilder);
        this.messages = chatCompletionBuilder.messages;
        this.prompt = chatCompletionBuilder.prompt;
        this.temperature = chatCompletionBuilder.temperature;
        this.topP = chatCompletionBuilder.topP;
        this.n = chatCompletionBuilder.n;
        this.stop = chatCompletionBuilder.stop;
        this.maxTokens = chatCompletionBuilder.maxTokens;
        this.presencePenalty = chatCompletionBuilder.presencePenalty;
        this.frequencyPenalty = chatCompletionBuilder.frequencyPenalty;
        this.logitBias = chatCompletionBuilder.logitBias;
        this.model = chatCompletionBuilder.model;
    }

    /** @return a new builder for this task */
    @Generated
    public static ChatCompletionBuilder<?, ?> builder() {
        return new ChatCompletionBuilderImpl();
    }

    @Override // io.kestra.plugin.openai.AbstractTask
    @Generated
    public String toString() {
        return "ChatCompletion(super=" + super.toString() + ", messages=" + getMessages() + ", prompt=" + getPrompt() + ", temperature=" + getTemperature() + ", topP=" + getTopP() + ", n=" + getN() + ", stop=" + getStop() + ", maxTokens=" + getMaxTokens() + ", presencePenalty=" + getPresencePenalty() + ", frequencyPenalty=" + getFrequencyPenalty() + ", logitBias=" + getLogitBias() + ", model=" + getModel() + ")";
    }

    // Field-by-field equality on top of the parent's equals; canEqual keeps the relation symmetric for subclasses.
    @Override // io.kestra.plugin.openai.AbstractTask
    @Generated
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ChatCompletion)) {
            return false;
        }
        ChatCompletion chatCompletion = (ChatCompletion) obj;
        if (!chatCompletion.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        Double temperature = getTemperature();
        Double temperature2 = chatCompletion.getTemperature();
        if (temperature == null) {
            if (temperature2 != null) {
                return false;
            }
        } else if (!temperature.equals(temperature2)) {
            return false;
        }
        Double topP = getTopP();
        Double topP2 = chatCompletion.getTopP();
        if (topP == null) {
            if (topP2 != null) {
                return false;
            }
        } else if (!topP.equals(topP2)) {
            return false;
        }
        Integer n = getN();
        Integer n2 = chatCompletion.getN();
        if (n == null) {
            if (n2 != null) {
                return false;
            }
        } else if (!n.equals(n2)) {
            return false;
        }
        Integer maxTokens = getMaxTokens();
        Integer maxTokens2 = chatCompletion.getMaxTokens();
        if (maxTokens == null) {
            if (maxTokens2 != null) {
                return false;
            }
        } else if (!maxTokens.equals(maxTokens2)) {
            return false;
        }
        Double presencePenalty = getPresencePenalty();
        Double presencePenalty2 = chatCompletion.getPresencePenalty();
        if (presencePenalty == null) {
            if (presencePenalty2 != null) {
                return false;
            }
        } else if (!presencePenalty.equals(presencePenalty2)) {
            return false;
        }
        Double frequencyPenalty = getFrequencyPenalty();
        Double frequencyPenalty2 = chatCompletion.getFrequencyPenalty();
        if (frequencyPenalty == null) {
            if (frequencyPenalty2 != null) {
                return false;
            }
        } else if (!frequencyPenalty.equals(frequencyPenalty2)) {
            return false;
        }
        List<ChatMessage> messages = getMessages();
        List<ChatMessage> messages2 = chatCompletion.getMessages();
        if (messages == null) {
            if (messages2 != null) {
                return false;
            }
        } else if (!messages.equals(messages2)) {
            return false;
        }
        String prompt = getPrompt();
        String prompt2 = chatCompletion.getPrompt();
        if (prompt == null) {
            if (prompt2 != null) {
                return false;
            }
        } else if (!prompt.equals(prompt2)) {
            return false;
        }
        List<String> stop = getStop();
        List<String> stop2 = chatCompletion.getStop();
        if (stop == null) {
            if (stop2 != null) {
                return false;
            }
        } else if (!stop.equals(stop2)) {
            return false;
        }
        Map<String, Integer> logitBias = getLogitBias();
        Map<String, Integer> logitBias2 = chatCompletion.getLogitBias();
        if (logitBias == null) {
            if (logitBias2 != null) {
                return false;
            }
        } else if (!logitBias.equals(logitBias2)) {
            return false;
        }
        String model = getModel();
        String model2 = chatCompletion.getModel();
        return model == null ? model2 == null : model.equals(model2);
    }

    @Override // io.kestra.plugin.openai.AbstractTask
    @Generated
    protected boolean canEqual(Object obj) {
        return obj instanceof ChatCompletion;
    }

    // Combines the parent hash with every field, in the same order equals compares them (43 stands in for null).
    @Override // io.kestra.plugin.openai.AbstractTask
    @Generated
    public int hashCode() {
        int hashCode = super.hashCode();
        Double temperature = getTemperature();
        int hashCode2 = (hashCode * 59) + (temperature == null ? 43 : temperature.hashCode());
        Double topP = getTopP();
        int hashCode3 = (hashCode2 * 59) + (topP == null ? 43 : topP.hashCode());
        Integer n = getN();
        int hashCode4 = (hashCode3 * 59) + (n == null ? 43 : n.hashCode());
        Integer maxTokens = getMaxTokens();
        int hashCode5 = (hashCode4 * 59) + (maxTokens == null ? 43 : maxTokens.hashCode());
        Double presencePenalty = getPresencePenalty();
        int hashCode6 = (hashCode5 * 59) + (presencePenalty == null ? 43 : presencePenalty.hashCode());
        Double frequencyPenalty = getFrequencyPenalty();
        int hashCode7 = (hashCode6 * 59) + (frequencyPenalty == null ? 43 : frequencyPenalty.hashCode());
        List<ChatMessage> messages = getMessages();
        int hashCode8 = (hashCode7 * 59) + (messages == null ? 43 : messages.hashCode());
        String prompt = getPrompt();
        int hashCode9 = (hashCode8 * 59) + (prompt == null ? 43 : prompt.hashCode());
        List<String> stop = getStop();
        int hashCode10 = (hashCode9 * 59) + (stop == null ? 43 : stop.hashCode());
        Map<String, Integer> logitBias = getLogitBias();
        int hashCode11 = (hashCode10 * 59) + (logitBias == null ? 43 : logitBias.hashCode());
        String model = getModel();
        return (hashCode11 * 59) + (model == null ? 43 : model.hashCode());
    }

    @Generated
    public List<ChatMessage> getMessages() {
        return this.messages;
    }

    @Generated
    public String getPrompt() {
        return this.prompt;
    }

    @Generated
    public Double getTemperature() {
        return this.temperature;
    }

    @Generated
    public Double getTopP() {
        return this.topP;
    }

    @Generated
    public Integer getN() {
        return this.n;
    }

    @Generated
    public List<String> getStop() {
        return this.stop;
    }

    @Generated
    public Integer getMaxTokens() {
        return this.maxTokens;
    }

    @Generated
    public Double getPresencePenalty() {
        return this.presencePenalty;
    }

    @Generated
    public Double getFrequencyPenalty() {
        return this.frequencyPenalty;
    }

    @Generated
    public Map<String, Integer> getLogitBias() {
        return this.logitBias;
    }

    @Generated
    public String getModel() {
        return this.model;
    }

    /** No-arg constructor, e.g. for deserialization of the task definition. */
    @Generated
    public ChatCompletion() {
    }
}
