package io.moderne.ai;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;
import java.util.stream.Stream;
import kong.unirest.HttpResponse;
import kong.unirest.Unirest;
import kong.unirest.UnirestException;
import org.openrewrite.internal.lang.Nullable;

/**
 * Client for an embedding model server expected at {@code http://127.0.0.1:7860}.
 *
 * <p>{@link #getInstance()} returns a shared client. If the server does not answer a {@code HEAD}
 * request with HTTP 200, the {@code /get_is_related.py} classpath resource is copied into
 * {@code ~/.moderne/models} and launched with {@code /usr/bin/python3}. The client then polls once
 * per second, for up to 60 seconds, until the server answers.
 *
 * <p>Embeddings are cached per input text in a bounded cache of at most 1000 entries. The oldest
 * inserted entry is evicted first.
 */
public class EmbeddingModelClient {
    /** Base URL of the local embedding server. */
    private static final String BASE_URL = "http://127.0.0.1:7860";

    /** Upper bound on cached embeddings. */
    private static final int MAX_CACHED_EMBEDDINGS = 1000;

    /** Number of one-second polls made while waiting for a freshly launched daemon. */
    private static final int STARTUP_POLL_ATTEMPTS = 60;

    /**
     * Cap on daemon output retained for error reporting. A chatty daemon therefore cannot grow
     * memory without bound for the lifetime of the JVM.
     */
    private static final int MAX_CAPTURED_OUTPUT_CHARS = 64 * 1024;

    /**
     * Runs background work such as draining the daemon's output. A draining task lives as long as
     * the launched process. The threads are daemon threads because non-daemon pool threads would
     * otherwise keep the JVM from exiting.
     */
    private static final ExecutorService EXECUTOR_SERVICE = Executors.newFixedThreadPool(3, task -> {
        Thread thread = new Thread(task, "embedding-model-client");
        thread.setDaemon(true);
        return thread;
    });

    /** Directory the daemon script is installed into: {@code ~/.moderne/models}. */
    private static final Path MODELS_DIR = Paths.get(System.getProperty("user.home"), ".moderne", "models");

    static {
        // createDirectories succeeds if another process creates the directory concurrently, unlike
        // a bare mkdirs() whose false return cannot distinguish "lost the race" from "failed".
        if (!Files.exists(MODELS_DIR)) {
            try {
                Files.createDirectories(MODELS_DIR);
            } catch (IOException e) {
                throw new IllegalStateException("Unable to create models directory at " + MODELS_DIR, e);
            }
        }
    }

    @Nullable
    private static EmbeddingModelClient INSTANCE;

    /**
     * Embeddings keyed by the exact text sent to the server.
     *
     * <p>The backing {@link LinkedHashMap} uses insertion order, so eviction is FIFO rather than
     * LRU. {@link Collections#synchronizedMap} runs {@code computeIfAbsent} while holding the
     * map's lock. Concurrent cache misses, including their HTTP calls, are therefore serialized.
     */
    private final Map<String, float[]> embeddingCache = Collections.synchronizedMap(
            new LinkedHashMap<String, float[]>() {
                @Override
                protected boolean removeEldestEntry(Map.Entry<String, float[]> eldest) {
                    return size() > MAX_CACHED_EMBEDDINGS;
                }
            });

    /**
     * Request body posted to the predict endpoint. Its single {@code data} field carries the input
     * text(s). It is sent as JSON; the field layout presumably depends on Unirest's configured
     * object mapper (confirm).
     */
    public static class GradioRequest {
        private final String[] data;

        GradioRequest(String... strArr) {
            this.data = strArr;
        }
    }

    /** Response body of the predict endpoint; the first {@code data} element holds the vector. */
    public static final class GradioResponse {
        private final List<String> data;

        public GradioResponse(List<String> list) {
            this.data = list;
        }

        /**
         * Parses the first {@code data} element as an embedding vector.
         *
         * <p>The element is expected to be a comma-separated list of floats with one delimiter
         * character on each side. The delimiters are presumably {@code [} and {@code ]} (confirm
         * against the server).
         *
         * @return the parsed vector; empty if the delimiters enclose nothing
         * @throws IllegalStateException if the response has no usable data element
         * @throws NumberFormatException if a component is not a valid float
         */
        public float[] getEmbedding() {
            if (data == null || data.isEmpty()) {
                throw new IllegalStateException("Embedding response contained no data");
            }
            String raw = data.get(0);
            if (raw == null || raw.trim().length() < 2) {
                throw new IllegalStateException("Malformed embedding: " + raw);
            }
            String vector = raw.trim();
            String body = vector.substring(1, vector.length() - 1).trim();
            if (body.isEmpty()) {
                return new float[0];
            }
            String[] components = body.split(",");
            float[] embedding = new float[components.length];
            for (int i = 0; i < components.length; i++) {
                embedding[i] = Float.parseFloat(components[i].trim());
            }
            return embedding;
        }

        public List<String> getData() {
            return this.data;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof GradioResponse)) {
                return false;
            }
            return Objects.equals(getData(), ((GradioResponse) obj).getData());
        }

        @Override
        public int hashCode() {
            List<String> data = getData();
            return 59 + (data == null ? 43 : data.hashCode());
        }

        @Override
        public String toString() {
            return "EmbeddingModelClient.GradioResponse(data=" + getData() + ")";
        }
    }

    /** Result of {@link #getRelatedness}: the related flag plus the embedding timings recorded. */
    public static final class Relatedness {
        private final boolean isRelated;
        private final List<Duration> embeddingTimings;

        public Relatedness(boolean z, List<Duration> list) {
            this.isRelated = z;
            this.embeddingTimings = list;
        }

        public boolean isRelated() {
            return this.isRelated;
        }

        public List<Duration> getEmbeddingTimings() {
            return this.embeddingTimings;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Relatedness)) {
                return false;
            }
            Relatedness other = (Relatedness) obj;
            return isRelated() == other.isRelated()
                    && Objects.equals(getEmbeddingTimings(), other.getEmbeddingTimings());
        }

        @Override
        public int hashCode() {
            int result = 59 + (isRelated() ? 79 : 97);
            List<Duration> timings = getEmbeddingTimings();
            return result * 59 + (timings == null ? 43 : timings.hashCode());
        }

        @Override
        public String toString() {
            return "EmbeddingModelClient.Relatedness(isRelated=" + isRelated() + ", embeddingTimings=" + getEmbeddingTimings() + ")";
        }
    }

    /**
     * Returns the shared client, launching the local model daemon if no server is reachable.
     *
     * <p>The instance is published only after the server is confirmed up. A failed start therefore
     * does not leave a half-initialized client behind, and the next call retries.
     *
     * @return the shared client
     * @throws IllegalStateException if the daemon does not come up within the polling window
     * @throws UncheckedIOException if the daemon script cannot be installed or launched
     */
    public static synchronized EmbeddingModelClient getInstance() {
        if (INSTANCE == null) {
            EmbeddingModelClient client = new EmbeddingModelClient();
            if (client.checkForUpRequest() != 200) {
                client.start();
            }
            INSTANCE = client;
        }
        return INSTANCE;
    }

    /**
     * Launches the daemon and waits for it to answer. The daemon is destroyed if it does not
     * come up.
     */
    private void start() {
        Process process = launchDaemon();
        StringWriter output = new StringWriter();
        EXECUTOR_SERVICE.execute(() -> drainOutput(process, output));
        boolean up = false;
        try {
            up = checkForUp(process);
        } finally {
            if (!up) {
                // Do not leave an untracked daemon running after giving up on it.
                process.destroy();
            }
        }
        if (!up) {
            throw new IllegalStateException("Unable to start model daemon. Output of process is:\n" + output);
        }
    }

    /**
     * Copies the bundled script into {@link #MODELS_DIR} and starts it with stderr merged into
     * stdout.
     */
    private static Process launchDaemon() {
        Path script = MODELS_DIR.resolve("get_is_related.py");
        try {
            try (InputStream resource = EmbeddingModelClient.class.getResourceAsStream("/get_is_related.py")) {
                Files.copy(Objects.requireNonNull(resource, "Missing classpath resource /get_is_related.py"),
                        script, StandardCopyOption.REPLACE_EXISTING);
            }
            // Pass the path as an argument rather than through `sh -c` so that spaces or shell
            // metacharacters in user.home cannot break the command.
            return new ProcessBuilder("/usr/bin/python3", script.toString())
                    .redirectErrorStream(true)
                    .start();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Reads the process's merged output until EOF. It keeps at most
     * {@link #MAX_CAPTURED_OUTPUT_CHARS} characters in {@code sink}.
     */
    private static void drainOutput(Process process, StringWriter sink) {
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                // Keep reading past the cap so the pipe never fills and blocks the daemon.
                if (sink.getBuffer().length() < MAX_CAPTURED_OUTPUT_CHARS) {
                    sink.append(line).append(System.lineSeparator());
                }
            }
        } catch (IOException ignored) {
            // The stream closes when the process exits or is destroyed; nothing more to capture.
        }
    }

    /**
     * Polls the server once per second for up to {@link #STARTUP_POLL_ATTEMPTS} attempts.
     *
     * @return {@code true} once the server answers with HTTP 200. Returns {@code false} if the
     *     process exits with a non-zero status or the attempts run out.
     * @throws IllegalStateException if interrupted while waiting; the interrupt flag is restored
     */
    private boolean checkForUp(Process process) {
        for (int attempt = 0; attempt < STARTUP_POLL_ATTEMPTS; attempt++) {
            if (!process.isAlive() && process.exitValue() != 0) {
                return false;
            }
            if (checkForUpRequest() == 200) {
                return true;
            }
            try {
                Thread.sleep(1000L);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting for the model daemon to start", e);
            }
        }
        return false;
    }

    /**
     * Sends a {@code HEAD} request to the server.
     *
     * @return the HTTP status, or 523 (a non-200 sentinel) if the request itself fails
     */
    private int checkForUpRequest() {
        try {
            return Unirest.head(BASE_URL).asString().getStatus();
        } catch (UnirestException e) {
            return 523;
        }
    }

    /**
     * Decides whether two texts are related by comparing their embeddings.
     *
     * <p>NOTE(review): only {@code str2} has its newlines removed before embedding. This asymmetry
     * is preserved as-is; confirm it is intended.
     *
     * @param str first text, embedded verbatim
     * @param str2 second text, embedded with {@code '\n'} characters removed
     * @param d threshold compared against {@link #dist}
     * @return {@code dist(...) <= d}
     */
    public boolean isRelated(String str, String str2, double d) {
        float[] first = embeddingCache.computeIfAbsent(str, this::getEmbedding);
        float[] second = embeddingCache.computeIfAbsent(str2.replace("\n", ""), this::getEmbedding);
        return dist(first, second) <= d;
    }

    /**
     * Same decision as {@link #isRelated}, additionally reporting how long embedding took.
     *
     * @return the related flag plus at most one timing; see {@link #timeEmbedding}
     */
    public Relatedness getRelatedness(String str, String str2, double d) {
        List<Duration> timings = new ArrayList<>(2);
        float[] first = embeddingCache.computeIfAbsent(str, timeEmbedding(timings));
        float[] second = embeddingCache.computeIfAbsent(str2.replace("\n", ""), timeEmbedding(timings));
        return new Relatedness(dist(first, second) <= d, timings);
    }

    /**
     * Wraps {@link #getEmbedding} so its wall-clock duration is appended to {@code list}.
     *
     * <p>NOTE(review): a timing is added only while {@code list} is empty. As a result, only the
     * first embedding actually computed per call is recorded (cache hits are never timed). The
     * list is presized for two, which hints that two timings may have been intended. Confirm
     * before changing, because callers may depend on the current contents.
     */
    private Function<String, float[]> timeEmbedding(List<Duration> list) {
        return text -> {
            long startNanos = System.nanoTime();
            float[] embedding = getEmbedding(text);
            if (list.isEmpty()) {
                list.add(Duration.ofNanos(System.nanoTime() - startNanos));
            }
            return embedding;
        };
    }

    /**
     * Returns {@code 1 - euclideanDistance(fArr, fArr2)}. The sum of squares is accumulated in
     * {@code float}.
     *
     * <p>NOTE(review): despite the name, larger values mean closer vectors. Combined with the
     * {@code <= threshold} checks in the callers, this marks far-apart texts as related. Verify
     * the intended semantics before changing, since thresholds may be tuned against it.
     *
     * @throws IllegalArgumentException if the vectors differ in length
     */
    private static double dist(float[] fArr, float[] fArr2) {
        if (fArr.length != fArr2.length) {
            throw new IllegalArgumentException("Vectors must have the same dimension");
        }
        float sumOfSquares = 0.0f;
        for (int i = 0; i < fArr.length; i++) {
            float diff = fArr[i] - fArr2[i];
            sumOfSquares += diff * diff;
        }
        return 1.0d - Math.sqrt(sumOfSquares);
    }

    /**
     * Requests the embedding of {@code str} from the server (uncached).
     *
     * @throws IllegalStateException on a non-2xx status or an unmappable response body
     */
    public float[] getEmbedding(String str) {
        HttpResponse<GradioResponse> response = Unirest.post(BASE_URL + "/run/predict")
                .header("Content-Type", "application/json")
                .body(new GradioRequest(str))
                .asObject(GradioResponse.class);
        if (!response.isSuccess()) {
            throw new IllegalStateException("Unable to get embedding. HTTP " + response.getStatus());
        }
        GradioResponse body = response.getBody();
        if (body == null) {
            // The status was 2xx, but the body could not be mapped to GradioResponse.
            throw new IllegalStateException("Unable to parse embedding response. HTTP " + response.getStatus());
        }
        return body.getEmbedding();
    }
}
