/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.deployment.devservice;

import io.quarkiverse.langchain4j.deployment.config.LangChain4jBuildConfig;
import io.quarkiverse.langchain4j.deployment.devservice.Langchain4jDevServicesEnabled;
import io.quarkiverse.langchain4j.deployment.devservice.OllamaClient;
import io.quarkiverse.langchain4j.deployment.items.DevServicesChatModelRequiredBuildItem;
import io.quarkiverse.langchain4j.deployment.items.DevServicesEmbeddingModelRequiredBuildItem;
import io.quarkiverse.langchain4j.deployment.items.DevServicesModelRequired;
import io.quarkiverse.langchain4j.deployment.items.DevServicesOllamaConfigBuildItem;
import io.quarkus.builder.item.BuildItem;
import io.quarkus.builder.item.MultiBuildItem;
import io.quarkus.deployment.IsNormal;
import io.quarkus.deployment.annotations.BuildProducer;
import io.quarkus.deployment.annotations.BuildStep;
import io.quarkus.deployment.annotations.BuildSteps;
import io.quarkus.deployment.builditem.DevServicesResultBuildItem;
import io.quarkus.deployment.builditem.LaunchModeBuildItem;
import io.quarkus.deployment.console.ConsoleInstalledBuildItem;
import io.quarkus.deployment.console.StartupLogCompressor;
import io.quarkus.deployment.logging.LoggingSetupBuildItem;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Flow;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.jboss.logging.Logger;

@BuildSteps(onlyIfNot={IsNormal.class})
public class DevServicesOllamaProcessor {
    private static final Logger LOGGER = Logger.getLogger(DevServicesOllamaProcessor.class);
    private static final String OLLAMA_PROVIDER = "ollama";
    /** Maximum time a single model pull may take before the build step fails. */
    private static final long PULL_TIMEOUT_MINUTES = 5L;

    /**
     * Makes sure every model requested for the {@code ollama} provider is available on the Ollama server,
     * then publishes the server URL as the base URL of each of those models.
     * <p>
     * The server host/port come from {@link DevServicesOllamaConfigBuildItem} when present, otherwise
     * {@code localhost} and {@code config.devservices().port()}. Models missing locally are pulled one at a
     * time (each bounded by {@value #PULL_TIMEOUT_MINUTES} minutes). If exactly one chat model is requested
     * and preloading is enabled, it is preloaded. When the Ollama server is unavailable a warning is logged
     * and no {@link DevServicesResultBuildItem} is produced.
     *
     * @throws RuntimeException if a model pull fails, times out or is interrupted
     */
    @BuildStep(onlyIfNot={IsNormal.class}, onlyIf={Langchain4jDevServicesEnabled.class})
    private void handleModels(List<DevServicesChatModelRequiredBuildItem> devServicesChatModels, List<DevServicesEmbeddingModelRequiredBuildItem> devServicesEmbeddingModels, LoggingSetupBuildItem loggingSetupBuildItem, Optional<ConsoleInstalledBuildItem> consoleInstalledBuildItem, Optional<DevServicesOllamaConfigBuildItem> ollamaDevServicesConfig, LaunchModeBuildItem launchMode, LangChain4jBuildConfig config, BuildProducer<DevServicesResultBuildItem> producer) {
        if (devServicesChatModels.isEmpty() && devServicesEmbeddingModels.isEmpty()) {
            return;
        }
        List<DevServicesChatModelRequiredBuildItem> ollamaChatModels = devServicesChatModels.stream()
                .filter(bi -> OLLAMA_PROVIDER.equals(bi.getProvider()))
                .toList();
        List<DevServicesEmbeddingModelRequiredBuildItem> ollamaEmbeddingModels = devServicesEmbeddingModels.stream()
                .filter(bi -> OLLAMA_PROVIDER.equals(bi.getProvider()))
                .toList();
        // Insertion order is kept so models are pulled in a predictable order (chat models first).
        Set<DevServicesModelRequired> allOllamaModels = new LinkedHashSet<>();
        allOllamaModels.addAll(ollamaChatModels);
        allOllamaModels.addAll(ollamaEmbeddingModels);
        if (allOllamaModels.isEmpty()) {
            return;
        }

        String devServiceHost = ollamaDevServicesConfig
                .map(c -> c.getConfig().get("langchain4j-ollama-dev-service.ollama.host"))
                .orElse("localhost");
        Integer devServicePort = ollamaDevServicesConfig
                .map(c -> c.getConfig().get("langchain4j-ollama-dev-service.ollama.port"))
                .map(Integer::parseInt)
                .orElseGet(() -> config.devservices().port());
        OllamaClient client = OllamaClient.create(new OllamaClient.Options(devServiceHost, devServicePort));
        try {
            Set<ModelName> localModels = client.localModels().stream()
                    .map(mi -> ModelName.of(mi.name()))
                    .collect(Collectors.toSet());
            List<String> modelsToPull = new ArrayList<>(allOllamaModels.size());
            for (DevServicesModelRequired required : allOllamaModels) {
                String modelName = required.getModelName();
                if (localModels.contains(ModelName.of(modelName))) {
                    LOGGER.debugf("Ollama already has model %s pulled locally", modelName);
                } else {
                    modelsToPull.add(modelName);
                }
            }
            LOGGER.debugf("Need to pull the following models into Ollama server: %s", String.join(", ", modelsToPull));

            // Name of the thread delivering pull progress; the compressor only captures log output from it.
            AtomicReference<String> clientThreadName = new AtomicReference<>();
            StartupLogCompressor startupLogCompressor = new StartupLogCompressor(
                    (launchMode.isTest() ? "(test) " : "") + "Ollama model pull:",
                    consoleInstalledBuildItem,
                    loggingSetupBuildItem,
                    thread -> {
                        String t = clientThreadName.get();
                        return t != null && thread.getName().equals(t);
                    });
            boolean succeeded = false;
            try {
                for (String model : modelsToPull) {
                    LOGGER.infof("Pulling model %s", model);
                    PullProgressSubscriber subscriber = new PullProgressSubscriber(model, clientThreadName);
                    client.pullAsync(model).subscribe(subscriber);
                    subscriber.awaitCompletion();
                }
                if (ollamaChatModels.size() == 1 && config.devservices().preload()) {
                    String modelName = ollamaChatModels.get(0).getModelName();
                    LOGGER.infof("Preloading model %s", modelName);
                    client.preloadChatModel(modelName);
                }
                succeeded = true;
            } finally {
                // Always release the compressor; on failure, dump what was captured so the user sees why.
                if (succeeded) {
                    startupLogCompressor.close();
                } else {
                    startupLogCompressor.closeAndDumpCaptured();
                }
            }

            String ollamaBaseUrl = String.format("http://%s:%d", devServiceHost, devServicePort);
            HashMap<String, String> modelBaseUrls = new HashMap<>();
            for (DevServicesModelRequired required : allOllamaModels) {
                modelBaseUrls.put(required.getBaseUrlProperty(), ollamaBaseUrl);
            }
            producer.produce(new DevServicesResultBuildItem(OLLAMA_PROVIDER, null, modelBaseUrls));
        } catch (OllamaClient.ServerUnavailableException e) {
            LOGGER.warn(e.getMessage() + " therefore no dev service will be started. Ollama can be installed via https://ollama.com/download");
        }
    }

    /**
     * Consumes the progress stream of a single {@code pullAsync} call. It logs throttled download progress
     * and signals completion or failure through an internal future.
     * <p>
     * Static, so it does not hold a reference to the enclosing processor.
     */
    private static final class PullProgressSubscriber implements Flow.Subscriber<OllamaClient.PullAsyncLine> {
        private static final BigDecimal ONE_HUNDRED = new BigDecimal("100");
        /** Minimum delay between two progress log lines. */
        private static final long MIN_LOG_INTERVAL_MILLIS = 1000L;

        private final String model;
        private final AtomicReference<String> clientThreadName;
        private final CompletableFuture<Void> done = new CompletableFuture<>();
        private volatile Flow.Subscription subscription;
        // Only touched from onNext; Flow signals to one subscriber are delivered serially.
        private Long lastUpdateNanos;

        PullProgressSubscriber(String model, AtomicReference<String> clientThreadName) {
            this.model = model;
            this.clientThreadName = clientThreadName;
        }

        @Override
        public void onSubscribe(Flow.Subscription subscription) {
            this.subscription = subscription;
            subscription.request(Long.MAX_VALUE);
        }

        @Override
        public void onNext(OllamaClient.PullAsyncLine line) {
            clientThreadName.compareAndSet(null, Thread.currentThread().getName());
            if (line.total() == null || line.completed() == null || line.status() == null
                    || !line.status().contains("pulling")) {
                return;
            }
            // Guard against a zero total, which would make the percentage division throw.
            if (line.total() <= 0) {
                return;
            }
            if (!shouldLogUpdate()) {
                return;
            }
            lastUpdateNanos = System.nanoTime();
            BigDecimal percentage = new BigDecimal(line.completed())
                    .divide(new BigDecimal(line.total()), 4, RoundingMode.HALF_DOWN)
                    .multiply(ONE_HUNDRED);
            BigDecimal progress = percentage.setScale(2, RoundingMode.HALF_DOWN);
            if (progress.compareTo(ONE_HUNDRED) >= 0) {
                LOGGER.info("Verifying and cleaning up\n");
            } else {
                LOGGER.infof("Downloading %s - Progress: %s%%\n", model, progress);
            }
        }

        /** Returns true when no progress has been logged yet or the last log is older than the interval. */
        private boolean shouldLogUpdate() {
            Long last = lastUpdateNanos;
            // Subtract raw nanoTime values first; only differences of nanoTime are meaningful.
            return last == null
                    || TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - last) > MIN_LOG_INTERVAL_MILLIS;
        }

        @Override
        public void onError(Throwable throwable) {
            done.completeExceptionally(throwable);
        }

        @Override
        public void onComplete() {
            done.complete(null);
        }

        /**
         * Blocks until the pull stream completes.
         *
         * @throws RuntimeException wrapping the stream error, the timeout, or the interruption (the thread's
         *                          interrupt flag is restored in the latter case)
         */
        void awaitCompletion() {
            try {
                done.get(PULL_TIMEOUT_MINUTES, TimeUnit.MINUTES);
            } catch (InterruptedException e) {
                cancel();
                Thread.currentThread().interrupt();
                throw new RuntimeException("Interrupted while pulling Ollama model " + model, e);
            } catch (ExecutionException e) {
                throw new RuntimeException("Failed to pull Ollama model " + model, e.getCause());
            } catch (TimeoutException e) {
                cancel();
                throw new RuntimeException(
                        "Timed out after " + PULL_TIMEOUT_MINUTES + " minutes pulling Ollama model " + model, e);
            }
        }

        /** Stops an in-flight pull stream we are no longer waiting for. */
        private void cancel() {
            Flow.Subscription s = subscription;
            if (s != null) {
                s.cancel();
            }
        }
    }

    /**
     * Normalized Ollama model reference, used to compare requested models against locally installed ones.
     * A name without an explicit tag (or with an empty tag) gets the tag {@code latest}, so {@code llama3} and
     * {@code llama3:latest} compare equal.
     */
    private record ModelName(String model, String tag) {
        private static final String DEFAULT_TAG = "latest";

        /**
         * Splits {@code modelName} into model and tag at the last colon. A colon that is followed by a
         * {@code /} belongs to a registry {@code host:port} prefix, not a tag. Such names and names without
         * a colon get the default tag. Never throws for non-null input. This matters because it is also
         * applied to every model reported by the local Ollama server.
         *
         * @throws NullPointerException if {@code modelName} is null
         */
        public static ModelName of(String modelName) {
            Objects.requireNonNull(modelName, "modelName cannot be null");
            int colon = modelName.lastIndexOf(':');
            if (colon < 0 || modelName.indexOf('/', colon) >= 0) {
                return new ModelName(modelName, DEFAULT_TAG);
            }
            String tag = modelName.substring(colon + 1);
            return new ModelName(modelName.substring(0, colon), tag.isEmpty() ? DEFAULT_TAG : tag);
        }
    }
}

