/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.embedding;

import ai.vespa.embedding.EmbeddingNormalizer;
import ai.vespa.embedding.config.VoyageAiEmbedderConfig;
import ai.vespa.secret.Secret;
import ai.vespa.secret.Secrets;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.io.IOException;
import java.time.Duration;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import okhttp3.ConnectionPool;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;

/**
 * Embedder that computes dense embeddings by calling the VoyageAI embeddings HTTP API.
 *
 * <p>Each {@code embed} call sends one text to the configured endpoint and model, authenticated with
 * an API key resolved from the secret store. The raw vector is cached in the {@link Embedder.Context}
 * per (embedder id, text, input type), so a text embedded into several tensor fields is only sent once
 * per context. Responses with status 429 or 5xx are retried with a fixed delay until either
 * {@code maxRetries} or the configured timeout is exhausted.</p>
 */
@Beta
public class VoyageAIEmbedder
extends AbstractComponent
implements Embedder {
    private static final Logger log = Logger.getLogger(VoyageAIEmbedder.class.getName());
    private static final MediaType JSON = MediaType.get("application/json; charset=utf-8");
    private static final ObjectMapper objectMapper = new ObjectMapper();
    /** Delay between retries of throttled (429) or server-error (5xx) requests. */
    private static final long RETRY_DELAY_MS = 1000L;

    private final VoyageAiEmbedderConfig config;
    private final Embedder.Runtime runtime;
    private final Secret apiKey;
    private final OkHttpClient httpClient;

    /**
     * Creates the embedder, resolving the API key and building the HTTP client.
     *
     * @param config      embedder configuration (model, endpoint, timeouts, retries, ...)
     * @param runtime     runtime used to report sequence-length and latency metrics
     * @param secretStore store from which {@code config.apiKeySecretRef()} is resolved
     * @throws IllegalArgumentException if the API key secret cannot be resolved
     */
    @Inject
    public VoyageAIEmbedder(VoyageAiEmbedderConfig config, Embedder.Runtime runtime, Secrets secretStore) {
        this.config = config;
        this.runtime = runtime;
        this.apiKey = getApiKey(config, secretStore);
        this.httpClient = createHttpClient(config);
        log.info("VoyageAI embedder initialized with model: " + config.model());
    }

    /**
     * Looks up the API key secret named by {@code config.apiKeySecretRef()}.
     *
     * @throws IllegalArgumentException if no secret name is configured, the secret store is unavailable,
     *                                  the lookup fails, or the secret does not exist
     */
    private Secret getApiKey(VoyageAiEmbedderConfig config, Secrets secretStore) {
        String secretName = config.apiKeySecretRef();
        if (secretName == null || secretName.isBlank()) {
            throw new IllegalArgumentException("api-key-secret-ref must be configured for VoyageAI embedder. Please set it in services.xml and ensure the secret is in the secret store.");
        }
        Secret secret;
        try {
            secret = secretStore.get(secretName);
        }
        catch (UnsupportedOperationException e) {
            throw new IllegalArgumentException("Secret store is not configured. Cannot retrieve API key for VoyageAI embedder.", e);
        }
        catch (Exception e) {
            throw new IllegalArgumentException("Failed to retrieve API key from secret store. Secret name: " + secretName, e);
        }
        // Checked outside the try so this specific message is not re-wrapped by the generic catch above.
        if (secret == null) {
            throw new IllegalArgumentException("Secret not found in secret store: " + secretName + ". Please add it using: vespa secret add " + secretName + " --value YOUR_API_KEY");
        }
        return secret;
    }

    /** Builds an HTTP client whose connect/read/write/call timeouts all equal {@code config.timeout()} ms. */
    private OkHttpClient createHttpClient(VoyageAiEmbedderConfig config) {
        Duration timeout = Duration.ofMillis(config.timeout());
        return new OkHttpClient.Builder()
                .connectTimeout(timeout)
                .readTimeout(timeout)
                .writeTimeout(timeout)
                .callTimeout(timeout)
                .connectionPool(new ConnectionPool(config.maxIdleConnections(), 5L, TimeUnit.MINUTES))
                .build();
    }

    /** Not supported: this embedder only produces tensors. */
    @Override
    public List<Integer> embed(String text, Embedder.Context context) {
        throw new UnsupportedOperationException("VoyageAI embedder only supports embed() with TensorType. Use embed(String text, Context context, TensorType targetType) instead.");
    }

    /**
     * Embeds {@code text} into a tensor of {@code targetType}.
     *
     * @param text       text to embed
     * @param context    embedding context; supplies the destination (for input-type detection) and the cache
     * @param targetType target type; must have exactly one indexed dimension
     * @return a float tensor with the single dimension of {@code targetType}
     * @throws IllegalArgumentException if {@code targetType} is unsupported or the returned vector's length
     *                                  does not match the dimension's declared size
     * @throws RuntimeException         if the API call fails
     */
    @Override
    public Tensor embed(String text, Embedder.Context context, TensorType targetType) {
        validateTensorType(targetType);
        long startTime = System.nanoTime();
        try {
            String inputType = detectInputType(context);
            CacheKey cacheKey = new CacheKey(context.getEmbedderId(), text, inputType);
            // Cache the raw vector rather than the tensor: the tensor's dimension name depends on
            // targetType, which is not part of the cache key.
            float[] embedding = (float[]) context.computeCachedValueIfAbsent(cacheKey, () -> fetchEmbedding(text, inputType));
            Tensor result = createTensor(embedding, targetType);
            runtime.sampleSequenceLength((long) text.length(), context);
            runtime.sampleEmbeddingLatency((double) (System.nanoTime() - startTime) / 1.0E9, context);
            return result;
        }
        catch (RuntimeException e) {
            log.log(Level.WARNING, "VoyageAI embedding failed for model: " + config.model(), e);
            throw e;
        }
    }

    /** Calls the API, converting checked failures to unchecked ones so this can run inside the cache supplier. */
    private float[] fetchEmbedding(String text, String inputType) {
        try {
            return callVoyageAI(text, inputType);
        }
        catch (IOException e) {
            throw new RuntimeException("Failed to call VoyageAI API: " + e.getMessage(), e);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // restore the flag so callers can observe the interruption
            throw new RuntimeException("Failed to call VoyageAI API: " + e.getMessage(), e);
        }
    }

    /** Rejects target types that do not have exactly one indexed dimension. */
    private void validateTensorType(TensorType targetType) {
        if (targetType.dimensions().size() != 1) {
            throw new IllegalArgumentException("Error in embedding to type '" + targetType + "': should only have one indexed dimension.");
        }
        if (!targetType.dimensions().get(0).isIndexed()) {
            throw new IllegalArgumentException("Error in embedding to type '" + targetType + "': dimension should be indexed.");
        }
    }

    /**
     * Chooses the VoyageAI {@code input_type}.
     *
     * <p>When auto-detection is off, returns the configured default, lower-cased. Otherwise returns
     * {@code "query"} if the context destination contains "query" (case-insensitively), else
     * {@code "document"}.</p>
     */
    private String detectInputType(Embedder.Context context) {
        if (!config.autoDetectInputType()) {
            return config.defaultInputType().toString().toLowerCase(Locale.ROOT);
        }
        String destination = context.getDestination();
        boolean isQuery = destination != null && destination.toLowerCase(Locale.ROOT).contains("query");
        return isQuery ? "query" : "document";
    }

    /**
     * Requests an embedding for a single text and returns the first vector in the response.
     *
     * @throws IOException if the request fails or the response contains no embedding
     */
    private float[] callVoyageAI(String text, String inputType) throws IOException, InterruptedException {
        VoyageAIRequest request = new VoyageAIRequest(List.of(text), config.model(), inputType, config.truncate());
        String jsonRequest = objectMapper.writeValueAsString(request);
        log.fine(() -> "VoyageAI request: " + jsonRequest);
        VoyageAIResponse response = callAPIWithRetry(jsonRequest);
        if (response.data == null || response.data.isEmpty()) {
            throw new IOException("VoyageAI API returned empty response");
        }
        EmbeddingData first = response.data.get(0);
        if (first == null || first.embedding == null) {
            throw new IOException("VoyageAI API returned no embedding in response");
        }
        return first.embedding;
    }

    /**
     * POSTs {@code jsonRequest} to the configured endpoint.
     *
     * <p>Responses with status 429 or 5xx are retried after {@link #RETRY_DELAY_MS}. Retrying stops when
     * {@code config.maxRetries()} retries have been made, or when the remaining time budget
     * ({@code config.timeout()} ms since the first attempt) is not larger than the delay. Any other
     * non-success status fails immediately.</p>
     *
     * @throws IOException          on transport errors, non-retryable statuses, exhausted retries, or an
     *                              unparseable success body
     * @throws InterruptedException if interrupted while waiting to retry
     */
    private VoyageAIResponse callAPIWithRetry(String jsonRequest) throws IOException, InterruptedException {
        long startTime = System.currentTimeMillis();
        long timeoutMs = config.timeout();
        long retryDelay = RETRY_DELAY_MS;
        int retries = 0;
        while (true) {
            RequestBody body = RequestBody.create(jsonRequest, JSON);
            // Rebuilt per attempt so the current value of the (possibly rotated) secret is used.
            Request httpRequest = new Request.Builder()
                    .url(config.endpoint())
                    .header("Authorization", "Bearer " + apiKey.current())
                    .header("Content-Type", "application/json")
                    .post(body)
                    .build();
            // try-with-resources closes the response on every path. The retry sleep happens after this
            // block, so the connection is released before waiting.
            try (Response response = httpClient.newCall(httpRequest).execute()) {
                int code = response.code();
                String responseBody = response.body() != null ? response.body().string() : "";
                if (response.isSuccessful()) {
                    return parseResponse(responseBody);
                }
                if (code == 401) {
                    throw new IOException("VoyageAI API authentication failed. Please check your API key. Response: " + responseBody);
                }
                if (code != 429 && code < 500) {
                    throw new IOException("VoyageAI API request failed with status " + code + ": " + responseBody);
                }
                long timeRemaining = timeoutMs - (System.currentTimeMillis() - startTime);
                if (timeRemaining <= retryDelay) {
                    String errorType = code == 429 ? "rate limit" : "server error";
                    throw new IOException("VoyageAI API " + errorType + " (" + code + "). Cannot retry: would exceed timeout of " + timeoutMs + "ms. Response: " + responseBody);
                }
                if (retries >= config.maxRetries()) {
                    String errorType = code == 429 ? "rate limited" : "server error";
                    throw new IOException("VoyageAI API " + errorType + " (" + code + "). Max retries (" + config.maxRetries() + ") exceeded. Response: " + responseBody);
                }
                retries++;
                String errorMsg = code == 429 ? "rate limited" : "server error (" + code + ")";
                log.warning("VoyageAI API " + errorMsg + ". Retry " + retries + " after " + retryDelay + "ms (timeout remaining: " + timeRemaining + "ms)");
            }
            Thread.sleep(retryDelay);
        }
    }

    /** Deserializes a successful response body, reporting JSON errors as {@link IOException}. */
    private static VoyageAIResponse parseResponse(String responseBody) throws IOException {
        try {
            return objectMapper.readValue(responseBody, VoyageAIResponse.class);
        }
        catch (JsonProcessingException e) {
            throw new IOException("Failed to parse VoyageAI API response", e);
        }
    }

    /**
     * Wraps {@code embedding} in a float tensor using the single dimension name of {@code targetType}.
     * If {@code config.normalize()} is set, the tensor is normalized with {@link EmbeddingNormalizer}.
     *
     * @throws IllegalArgumentException if the dimension has a declared size different from the vector length
     */
    private Tensor createTensor(float[] embedding, TensorType targetType) {
        TensorType.Dimension dimension = targetType.dimensions().get(0);
        long expectedDim = dimension.size().orElse(-1L);
        if (expectedDim != -1L && embedding.length != expectedDim) {
            throw new IllegalArgumentException("VoyageAI returned " + embedding.length + " dimensions but target type expects " + expectedDim + ". Please ensure the model '" + config.model() + "' outputs the correct dimensions.");
        }
        TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.Value.FLOAT);
        typeBuilder.indexed(dimension.name(), (long) embedding.length);
        TensorType type = typeBuilder.build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
        for (int i = 0; i < embedding.length; ++i) {
            builder.cell(embedding[i], new long[]{i});
        }
        Tensor result = builder.build();
        if (config.normalize()) {
            result = EmbeddingNormalizer.normalize(result, type);
        }
        return result;
    }

    /** Shuts down the HTTP client's dispatcher threads and evicts pooled connections. */
    @Override
    public void deconstruct() {
        httpClient.dispatcher().executorService().shutdown();
        httpClient.connectionPool().evictAll();
        super.deconstruct();
    }

    /** Cache key for raw embedding vectors stored in the {@link Embedder.Context}. */
    private record CacheKey(String embedderId, String text, String inputType) {
    }

    /** JSON request body sent to the VoyageAI embeddings endpoint. */
    private record VoyageAIRequest(@JsonProperty(value="input") List<String> input, @JsonProperty(value="model") String model, @JsonProperty(value="input_type") String inputType, @JsonProperty(value="truncation") boolean truncation) {
    }

    /** JSON response body; unknown properties are ignored. */
    @JsonIgnoreProperties(ignoreUnknown=true)
    private static class VoyageAIResponse {
        @JsonProperty(value="data")
        public List<EmbeddingData> data;
        @JsonProperty(value="model")
        public String model;
        @JsonProperty(value="usage")
        public Usage usage;

        private VoyageAIResponse() {
        }
    }

    /** One element of the response's {@code data} array. */
    @JsonIgnoreProperties(ignoreUnknown=true)
    private static class EmbeddingData {
        @JsonProperty(value="embedding")
        public float[] embedding;
        @JsonProperty(value="index")
        public int index;

        private EmbeddingData() {
        }
    }

    /** Token usage reported by the API. */
    @JsonIgnoreProperties(ignoreUnknown=true)
    private static class Usage {
        @JsonProperty(value="total_tokens")
        public int totalTokens;

        private Usage() {
        }
    }
}

