/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.bedrock.internal;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.internal.Json;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.model.bedrock.internal.BedrockChatModelResponse;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientBuilder;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

/**
 * Base class for Bedrock chat models that send one text prompt per call to
 * {@code InvokeModel} and deserialize the JSON reply into a model-specific
 * {@link BedrockChatModelResponse}.
 *
 * <p>Subclasses supply the model id ({@link #getModelId()}), the request parameters for a
 * prompt ({@link #getRequestParameters(String)}) and the response class
 * ({@link #getResponseClassType()}). The prompt is built from the chat history:
 * <ul>
 *   <li>system messages are joined as leading context;</li>
 *   <li>user and AI messages are prefixed with the configured human/assistant prompts;</li>
 *   <li>the prompt ends with the configured assistant prompt.</li>
 * </ul>
 * Tool calling is not supported.
 *
 * <p>The {@link BedrockRuntimeClient} is created on first use and then reused.
 *
 * @param <T> the type the model's JSON reply is deserialized into
 */
public abstract class AbstractBedrockChatModel<T extends BedrockChatModelResponse>
        implements ChatLanguageModel {
    private static final String HUMAN_PROMPT = "Human:";
    private static final String ASSISTANT_PROMPT = "Assistant:";

    // Defaults applied when the corresponding builder property was never set.
    private static final int DEFAULT_MAX_RETRIES = 5;
    private static final Region DEFAULT_REGION = Region.US_EAST_1;
    private static final int DEFAULT_MAX_TOKENS = 300;
    private static final float DEFAULT_TEMPERATURE = 1.0f;
    private static final float DEFAULT_TOP_P = 0.999f;

    private final String humanPrompt;
    private final String assistantPrompt;
    private final Integer maxRetries;
    private final Region region;
    private final AwsCredentialsProvider credentialsProvider;
    private final int maxTokens;
    private final float temperature;
    private final float topP;
    private final String[] stopSequences;

    /** Holds the lazily created client; populated by {@link #getClient()}. */
    private final AtomicReference<BedrockRuntimeClient> client = new AtomicReference<>();

    /**
     * Creates the model from the builder, falling back to the class defaults for every
     * property that was not explicitly set on the builder.
     *
     * @param b the builder holding the configured values
     */
    protected AbstractBedrockChatModel(AbstractBedrockChatModelBuilder<T, ?, ?> b) {
        this.humanPrompt = b.humanPrompt$set ? b.humanPrompt$value : HUMAN_PROMPT;
        this.assistantPrompt = b.assistantPrompt$set ? b.assistantPrompt$value : ASSISTANT_PROMPT;
        this.maxRetries = b.maxRetries$set ? b.maxRetries$value : Integer.valueOf(DEFAULT_MAX_RETRIES);
        this.region = b.region$set ? b.region$value : DEFAULT_REGION;
        // Only build the default provider chain when the caller did not supply one.
        this.credentialsProvider = b.credentialsProvider$set
                ? b.credentialsProvider$value
                : DefaultCredentialsProvider.builder().build();
        this.maxTokens = b.maxTokens$set ? b.maxTokens$value : DEFAULT_MAX_TOKENS;
        this.temperature = b.temperature$set ? b.temperature$value : DEFAULT_TEMPERATURE;
        this.topP = b.topP$set ? b.topP$value : DEFAULT_TOP_P;
        this.stopSequences = b.stopSequences$set ? b.stopSequences$value : new String[0];
    }

    /**
     * Sends the conversation to the Bedrock model and returns its reply.
     *
     * <p>System messages are joined with newlines as a leading context block. All other
     * messages are rendered via {@link #chatMessageToString(ChatMessage)}. The prompt ends
     * with the configured assistant prompt so the model continues as the assistant. The
     * request is retried up to {@link #getMaxRetries()} times via {@link RetryUtils}.
     *
     * @param messages the conversation history
     * @return the model's reply together with token usage and finish reason
     * @throws IllegalArgumentException if the history contains a tool execution result
     */
    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages) {
        String context = messages.stream()
                .filter(message -> message.type() == ChatMessageType.SYSTEM)
                .map(ChatMessage::text)
                .collect(Collectors.joining("\n"));
        String userMessages = messages.stream()
                .filter(message -> message.type() != ChatMessageType.SYSTEM)
                .map(this::chatMessageToString)
                .collect(Collectors.joining("\n"));
        // Use the configured marker (not the constant) so it matches the prefix used for AI turns.
        String prompt = String.format("%s\n\n%s\n%s", context, userMessages, this.assistantPrompt);

        String body = Json.toJson(getRequestParameters(prompt));
        InvokeModelResponse invokeModelResponse = RetryUtils.withRetry(() -> invoke(body), this.maxRetries);

        T result = Json.fromJson(invokeModelResponse.body().asUtf8String(), getResponseClassType());
        return new Response<>(new AiMessage(result.getOutputText()), result.getTokenUsage(), result.getFinishReason());
    }

    /**
     * Tool calling is not supported by Bedrock models in this implementation.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        throw new IllegalArgumentException("Tools are currently not supported for Bedrock models");
    }

    /**
     * Tool calling is not supported by Bedrock models in this implementation.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
        throw new IllegalArgumentException("Tools are currently not supported for Bedrock models");
    }

    /**
     * Renders one chat message as a line of the text prompt.
     *
     * @param message the message to render
     * @return the text itself for system messages, otherwise the text prefixed by the
     *         configured human (user) or assistant (AI) prompt and a space
     * @throws IllegalArgumentException for tool execution results or unknown message types
     */
    protected String chatMessageToString(ChatMessage message) {
        switch (message.type()) {
            case SYSTEM:
                return message.text();
            case USER:
                return this.humanPrompt + " " + message.text();
            case AI:
                return this.assistantPrompt + " " + message.text();
            case TOOL_EXECUTION_RESULT:
                throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
            default:
                throw new IllegalArgumentException("Unknown message type: " + message.type());
        }
    }

    /**
     * Builds the model-specific request parameters; the result is serialized to JSON as the
     * {@code InvokeModel} body.
     *
     * @param prompt the fully assembled text prompt
     * @return the request parameters
     */
    protected abstract Map<String, Object> getRequestParameters(String prompt);

    /** @return the Bedrock model id passed to {@code InvokeModel} */
    protected abstract String getModelId();

    /** @return the class the JSON reply is deserialized into */
    protected abstract Class<T> getResponseClassType();

    /**
     * Invokes the model once with the given JSON body.
     *
     * @param body the JSON request body; sent UTF-8 encoded, matching the UTF-8 decoding of
     *             the reply in {@link #generate(List)}
     * @return the raw SDK response
     */
    protected InvokeModelResponse invoke(String body) {
        InvokeModelRequest invokeModelRequest = InvokeModelRequest.builder()
                .modelId(getModelId())
                .body(SdkBytes.fromUtf8String(body))
                .build();
        return getClient().invokeModel(invokeModelRequest);
    }

    /**
     * Creates a new mutable map holding a single entry.
     *
     * @param key   the entry key
     * @param value the entry value
     * @return a new {@link HashMap} containing only {@code key -> value}
     */
    protected static Map<String, Object> of(String key, Object value) {
        Map<String, Object> map = new HashMap<>();
        map.put(key, value);
        return map;
    }

    /** Builds a client for the configured region and credentials provider. */
    private BedrockRuntimeClient initClient() {
        return BedrockRuntimeClient.builder()
                .region(this.region)
                .credentialsProvider(this.credentialsProvider)
                .build();
    }

    public String getHumanPrompt() {
        return this.humanPrompt;
    }

    public String getAssistantPrompt() {
        return this.assistantPrompt;
    }

    public Integer getMaxRetries() {
        return this.maxRetries;
    }

    public Region getRegion() {
        return this.region;
    }

    public AwsCredentialsProvider getCredentialsProvider() {
        return this.credentialsProvider;
    }

    public int getMaxTokens() {
        return this.maxTokens;
    }

    public float getTemperature() {
        return this.temperature;
    }

    public float getTopP() {
        return this.topP;
    }

    public String[] getStopSequences() {
        return this.stopSequences;
    }

    /**
     * Returns the Bedrock client, creating it on first call.
     *
     * <p>Creation happens at most once while holding the lock on the {@code client}
     * reference. Later calls read the cached instance without locking.
     *
     * @return the shared client for this model
     */
    public BedrockRuntimeClient getClient() {
        BedrockRuntimeClient value = this.client.get();
        if (value == null) {
            synchronized (this.client) {
                value = this.client.get();
                if (value == null) {
                    value = initClient();
                    this.client.set(value);
                }
            }
        }
        return value;
    }

    /**
     * Builder for {@link AbstractBedrockChatModel} subclasses.
     *
     * <p>Each property records whether it was explicitly set. Unset properties take the
     * defaults applied in the model's constructor.
     *
     * @param <T> the response type of the model being built
     * @param <C> the concrete model type produced by {@link #build()}
     * @param <B> the concrete builder type, returned by every setter for chaining
     */
    public static abstract class AbstractBedrockChatModelBuilder<T extends BedrockChatModelResponse, C extends AbstractBedrockChatModel<T>, B extends AbstractBedrockChatModelBuilder<T, C, B>> {
        private boolean humanPrompt$set;
        private String humanPrompt$value;
        private boolean assistantPrompt$set;
        private String assistantPrompt$value;
        private boolean maxRetries$set;
        private Integer maxRetries$value;
        private boolean region$set;
        private Region region$value;
        private boolean credentialsProvider$set;
        private AwsCredentialsProvider credentialsProvider$value;
        private boolean maxTokens$set;
        private int maxTokens$value;
        private boolean temperature$set;
        private float temperature$value;
        private boolean topP$set;
        private float topP$value;
        private boolean stopSequences$set;
        private String[] stopSequences$value;

        public B humanPrompt(String humanPrompt) {
            this.humanPrompt$value = humanPrompt;
            this.humanPrompt$set = true;
            return self();
        }

        public B assistantPrompt(String assistantPrompt) {
            this.assistantPrompt$value = assistantPrompt;
            this.assistantPrompt$set = true;
            return self();
        }

        public B maxRetries(Integer maxRetries) {
            this.maxRetries$value = maxRetries;
            this.maxRetries$set = true;
            return self();
        }

        public B region(Region region) {
            this.region$value = region;
            this.region$set = true;
            return self();
        }

        public B credentialsProvider(AwsCredentialsProvider credentialsProvider) {
            this.credentialsProvider$value = credentialsProvider;
            this.credentialsProvider$set = true;
            return self();
        }

        public B maxTokens(int maxTokens) {
            this.maxTokens$value = maxTokens;
            this.maxTokens$set = true;
            return self();
        }

        public B temperature(float temperature) {
            this.temperature$value = temperature;
            this.temperature$set = true;
            return self();
        }

        public B topP(float topP) {
            this.topP$value = topP;
            this.topP$set = true;
            return self();
        }

        public B stopSequences(String[] stopSequences) {
            this.stopSequences$value = stopSequences;
            this.stopSequences$set = true;
            return self();
        }

        /** @return this builder as its concrete type */
        protected abstract B self();

        /** @return a new model configured from this builder */
        public abstract C build();

        @Override
        public String toString() {
            return "AbstractBedrockChatModel.AbstractBedrockChatModelBuilder(humanPrompt$value=" + this.humanPrompt$value + ", assistantPrompt$value=" + this.assistantPrompt$value + ", maxRetries$value=" + this.maxRetries$value + ", region$value=" + this.region$value + ", credentialsProvider$value=" + this.credentialsProvider$value + ", maxTokens$value=" + this.maxTokens$value + ", temperature$value=" + this.temperature$value + ", topP$value=" + this.topP$value + ", stopSequences$value=" + Arrays.deepToString(this.stopSequences$value) + ")";
        }
    }
}

