/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.watsonx;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.exception.UnsupportedFeatureException;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.StreamingChatModel;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ChatRequestParameters;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.ChatResponseMetadata;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
import io.quarkiverse.langchain4j.watsonx.Watsonx;
import io.quarkiverse.langchain4j.watsonx.WatsonxGenerationRequestParameters;
import io.quarkiverse.langchain4j.watsonx.WatsonxUtils;
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationParameters;
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest;
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationResponse;
import io.smallrye.mutiny.Context;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;

public class WatsonxGenerationStreamingModel
extends Watsonx
implements StreamingChatModel {
    private static final String INPUT_TOKEN_COUNT_CONTEXT = "INPUT_TOKEN_COUNT";
    private static final String GENERATED_TOKEN_COUNT_CONTEXT = "GENERATED_TOKEN_COUNT";
    private static final String COMPLETE_MESSAGE_CONTEXT = "COMPLETE_MESSAGE";
    private static final String FINISH_REASON_CONTEXT = "FINISH_REASON";
    private static final String MODEL_ID_CONTEXT = "MODEL_ID";
    private final WatsonxGenerationRequestParameters defaultRequestParameters;
    private final String promptJoiner;

    public WatsonxGenerationStreamingModel(Builder builder) {
        super(builder);
        TextGenerationParameters.LengthPenalty lengthPenalty = null;
        if (Objects.nonNull(builder.decayFactor) || Objects.nonNull(builder.startIndex)) {
            lengthPenalty = new TextGenerationParameters.LengthPenalty(builder.decayFactor, builder.startIndex);
        }
        this.promptJoiner = builder.promptJoiner;
        this.defaultRequestParameters = ((WatsonxGenerationRequestParameters.Builder)((WatsonxGenerationRequestParameters.Builder)((WatsonxGenerationRequestParameters.Builder)((WatsonxGenerationRequestParameters.Builder)((WatsonxGenerationRequestParameters.Builder)((WatsonxGenerationRequestParameters.Builder)WatsonxGenerationRequestParameters.builder().modelName(builder.modelId)).decodingMethod(builder.decodingMethod).lengthPenalty(lengthPenalty).minNewTokens(builder.minNewTokens).maxOutputTokens(builder.maxNewTokens)).randomSeed(builder.randomSeed).stopSequences(builder.stopSequences)).temperature(builder.temperature)).timeLimit(builder.timeout).topP(builder.topP)).topK(builder.topK)).repetitionPenalty(builder.repetitionPenalty).truncateInputTokens(builder.truncateInputTokens).includeStopSequence(builder.includeStopSequence).build();
    }

    public void doChat(ChatRequest chatRequest, final StreamingChatResponseHandler handler) {
        String modelId = chatRequest.parameters().modelName();
        ChatRequestParameters parameters = chatRequest.parameters();
        this.validate(parameters);
        TextGenerationRequest request = new TextGenerationRequest(modelId, this.spaceId, this.projectId, this.toInput(chatRequest.messages()), TextGenerationParameters.convert(parameters));
        final Context context = Context.empty();
        context.put(COMPLETE_MESSAGE_CONTEXT, (Object)new StringBuilder());
        context.put(INPUT_TOKEN_COUNT_CONTEXT, (Object)0);
        context.put(GENERATED_TOKEN_COUNT_CONTEXT, (Object)0);
        this.client.generationStreaming(request, this.version).onFailure(WatsonxUtils::isTokenExpired).retry().atMost(1L).subscribe().with(context, (Consumer)new Consumer<TextGenerationResponse>(){
            final /* synthetic */ WatsonxGenerationStreamingModel this$0;
            {
                this.this$0 = this$0;
            }

            @Override
            public void accept(TextGenerationResponse response) {
                try {
                    if (response == null || response.results() == null || response.results().isEmpty()) {
                        return;
                    }
                    StringBuilder stringBuilder = (StringBuilder)context.get(WatsonxGenerationStreamingModel.COMPLETE_MESSAGE_CONTEXT);
                    TextGenerationResponse.Result chunk = response.results().get(0);
                    if (!context.contains(WatsonxGenerationStreamingModel.MODEL_ID_CONTEXT) && response.modelId() != null) {
                        context.put(WatsonxGenerationStreamingModel.MODEL_ID_CONTEXT, (Object)response.modelId());
                    }
                    if (!chunk.stopReason().equals("not_finished")) {
                        context.put(WatsonxGenerationStreamingModel.FINISH_REASON_CONTEXT, (Object)chunk.stopReason());
                    }
                    int inputTokenCount = (Integer)context.get(WatsonxGenerationStreamingModel.INPUT_TOKEN_COUNT_CONTEXT);
                    context.put(WatsonxGenerationStreamingModel.INPUT_TOKEN_COUNT_CONTEXT, (Object)(inputTokenCount + chunk.inputTokenCount()));
                    int generatedTokenCount = (Integer)context.get(WatsonxGenerationStreamingModel.GENERATED_TOKEN_COUNT_CONTEXT);
                    context.put(WatsonxGenerationStreamingModel.GENERATED_TOKEN_COUNT_CONTEXT, (Object)(generatedTokenCount + chunk.generatedTokenCount()));
                    stringBuilder.append(chunk.generatedText());
                    handler.onPartialResponse(chunk.generatedText());
                }
                catch (Exception e) {
                    handler.onError((Throwable)e);
                }
            }
        }, (Consumer)new Consumer<Throwable>(){
            final /* synthetic */ WatsonxGenerationStreamingModel this$0;
            {
                this.this$0 = this$0;
            }

            @Override
            public void accept(Throwable error) {
                handler.onError(error);
            }
        }, new Runnable(){
            final /* synthetic */ WatsonxGenerationStreamingModel this$0;
            {
                this.this$0 = this$0;
            }

            @Override
            public void run() {
                StringBuilder response = (StringBuilder)context.get(WatsonxGenerationStreamingModel.COMPLETE_MESSAGE_CONTEXT);
                FinishReason finishReason = context.contains(WatsonxGenerationStreamingModel.FINISH_REASON_CONTEXT) ? this.this$0.toFinishReason((String)context.get(WatsonxGenerationStreamingModel.FINISH_REASON_CONTEXT)) : null;
                int inputTokenCount = context.contains(WatsonxGenerationStreamingModel.INPUT_TOKEN_COUNT_CONTEXT) ? (Integer)context.get(WatsonxGenerationStreamingModel.INPUT_TOKEN_COUNT_CONTEXT) : 0;
                int outputTokenCount = context.contains(WatsonxGenerationStreamingModel.GENERATED_TOKEN_COUNT_CONTEXT) ? (Integer)context.get(WatsonxGenerationStreamingModel.GENERATED_TOKEN_COUNT_CONTEXT) : 0;
                String modelId = context.contains(WatsonxGenerationStreamingModel.MODEL_ID_CONTEXT) ? (String)context.get(WatsonxGenerationStreamingModel.MODEL_ID_CONTEXT) : null;
                AiMessage aiMessage = AiMessage.from((String)response.toString());
                TokenUsage tokenUsage = new TokenUsage(Integer.valueOf(inputTokenCount), Integer.valueOf(outputTokenCount));
                ChatResponse chatResponse = ChatResponse.builder().aiMessage(aiMessage).metadata(ChatResponseMetadata.builder().modelName(modelId).tokenUsage(tokenUsage).finishReason(finishReason).build()).build();
                handler.onCompleteResponse(chatResponse);
            }
        });
    }

    public static Builder builder() {
        return new Builder();
    }

    public List<ChatModelListener> listeners() {
        return this.listeners;
    }

    public ChatRequestParameters defaultRequestParameters() {
        return this.defaultRequestParameters;
    }

    public Set<Capability> supportedCapabilities() {
        return Set.of();
    }

    private void validate(ChatRequestParameters parameters) throws UnsupportedFeatureException {
        if (parameters.frequencyPenalty() != null) {
            throw new UnsupportedFeatureException("'frequencyPenalty' parameter is not supported.");
        }
        if (parameters.presencePenalty() != null) {
            throw new UnsupportedFeatureException("'presencePenalty' parameter is not supported.");
        }
        if (parameters.toolChoice() != null) {
            throw new UnsupportedFeatureException("'toolChoice' parameter is not supported.");
        }
        if (parameters.responseFormat() != null) {
            throw new UnsupportedFeatureException("'responseFormat' parameter is not supported.");
        }
    }

    private String toInput(List<ChatMessage> messages) {
        return messages.stream().map(new Function<ChatMessage, String>(){

            @Override
            public String apply(ChatMessage chatMessage) {
                return switch (chatMessage.type()) {
                    case ChatMessageType.AI -> {
                        AiMessage aiMessage = (AiMessage)chatMessage;
                        yield aiMessage.text();
                    }
                    case ChatMessageType.SYSTEM -> {
                        SystemMessage systemMessage = (SystemMessage)chatMessage;
                        yield systemMessage.text();
                    }
                    case ChatMessageType.USER -> {
                        UserMessage userMessage = (UserMessage)chatMessage;
                        if (userMessage.hasSingleText()) {
                            yield userMessage.singleText();
                        }
                        throw new RuntimeException("For the generation model, the UserMessage can contain only a single text");
                    }
                    case ChatMessageType.TOOL_EXECUTION_RESULT -> throw new RuntimeException("The generation model doesn't allow the use of tools");
                    default -> throw new RuntimeException("Unsupported chat message type: " + String.valueOf(chatMessage.type()));
                };
            }
        }).collect(Collectors.joining(this.promptJoiner));
    }

    private FinishReason toFinishReason(String reason) {
        return switch (reason) {
            case "max_tokens", "token_limit" -> FinishReason.LENGTH;
            case "eos_token", "stop_sequence" -> FinishReason.STOP;
            case "not_finished", "cancelled", "time_limit", "error" -> FinishReason.OTHER;
            default -> throw new IllegalArgumentException("%s not supported".formatted(reason));
        };
    }

    public static final class Builder
    extends Watsonx.Builder<Builder> {
        private String decodingMethod;
        private Double decayFactor;
        private Integer startIndex;
        private Integer maxNewTokens;
        private Integer minNewTokens;
        private Integer randomSeed;
        private List<String> stopSequences;
        private Double temperature;
        private Integer topK;
        private Double topP;
        private Double repetitionPenalty;
        private Integer truncateInputTokens;
        private Boolean includeStopSequence;
        private String promptJoiner;

        public Builder decodingMethod(String decodingMethod) {
            this.decodingMethod = decodingMethod;
            return this;
        }

        public Builder decayFactor(Double decayFactor) {
            this.decayFactor = decayFactor;
            return this;
        }

        public Builder startIndex(Integer startIndex) {
            this.startIndex = startIndex;
            return this;
        }

        public Builder minNewTokens(Integer minNewTokens) {
            this.minNewTokens = minNewTokens;
            return this;
        }

        public Builder maxNewTokens(Integer maxNewTokens) {
            this.maxNewTokens = maxNewTokens;
            return this;
        }

        public Builder temperature(Double temperature) {
            this.temperature = temperature;
            return this;
        }

        public Builder topK(Integer topK) {
            this.topK = topK;
            return this;
        }

        public Builder topP(Double topP) {
            this.topP = topP;
            return this;
        }

        public Builder randomSeed(Integer randomSeed) {
            this.randomSeed = randomSeed;
            return this;
        }

        public Builder repetitionPenalty(Double repetitionPenalty) {
            this.repetitionPenalty = repetitionPenalty;
            return this;
        }

        public Builder stopSequences(List<String> stopSequences) {
            this.stopSequences = stopSequences;
            return this;
        }

        public Builder truncateInputTokens(Integer truncateInputTokens) {
            this.truncateInputTokens = truncateInputTokens;
            return this;
        }

        public Builder includeStopSequence(Boolean includeStopSequence) {
            this.includeStopSequence = includeStopSequence;
            return this;
        }

        public Builder promptJoiner(String promptJoiner) {
            this.promptJoiner = promptJoiner;
            return this;
        }

        public WatsonxGenerationStreamingModel build() {
            return new WatsonxGenerationStreamingModel(this);
        }
    }
}

