package dev.langchain4j.store.embedding.pgvector;

import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import com.pgvector.PGvector;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import java.lang.reflect.Type;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:dev/langchain4j/store/embedding/pgvector/PgVectorEmbeddingStore.class */
/**
 * {@link EmbeddingStore} implementation that keeps embeddings, and optionally the text segment
 * each one was computed from, in a PostgreSQL table using the pgvector extension.
 *
 * <p>Every operation opens its own JDBC connection through {@link DriverManager} and closes it
 * before returning; no connection is held between calls. Any {@link SQLException} is rethrown
 * wrapped in a {@link RuntimeException}.
 *
 * <p>The table name is interpolated into the SQL text (identifiers cannot be bound as statement
 * parameters), so it must come from trusted configuration.
 */
public class PgVectorEmbeddingStore implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(PgVectorEmbeddingStore.class);
    private static final Gson GSON = new Gson();

    /** Target type used when reading the JSON metadata column back into a string-to-string map. */
    private static final Type METADATA_TYPE = new TypeToken<Map<String, String>>() {
    }.getType();

    private final String host;
    private final Integer port;
    private final String user;
    private final String password;
    private final String database;
    private final String table;

    /**
     * Validates the connection settings and, depending on the flags, prepares the schema.
     *
     * @param host           database host; must not be blank
     * @param port           database port; must be greater than zero
     * @param user           database user; must not be blank
     * @param password       database password; must not be blank
     * @param database       database name; must not be blank
     * @param table          table that holds the embeddings; must not be blank
     * @param dimension      vector dimension; only validated and used when the table is created
     * @param useIndex       whether to create an ivfflat index on the embedding column;
     *                       {@code null} means {@code false}
     * @param indexListSize  value of the index {@code lists} option; only validated and used
     *                       when the index is created
     * @param createTable    whether to run {@code CREATE TABLE IF NOT EXISTS};
     *                       {@code null} means {@code true}
     * @param dropTableFirst whether to run {@code DROP TABLE IF EXISTS} beforehand;
     *                       {@code null} means {@code false}
     * @throws RuntimeException if connecting or any schema statement fails
     */
    public PgVectorEmbeddingStore(String host, Integer port, String user, String password, String database, String table, Integer dimension, Boolean useIndex, Integer indexListSize, Boolean createTable, Boolean dropTableFirst) {
        this.host = ValidationUtils.ensureNotBlank(host, "host");
        this.port = ValidationUtils.ensureGreaterThanZero(port, "port");
        this.user = ValidationUtils.ensureNotBlank(user, "user");
        this.password = ValidationUtils.ensureNotBlank(password, "password");
        this.database = ValidationUtils.ensureNotBlank(database, "database");
        this.table = ValidationUtils.ensureNotBlank(table, "table");

        boolean indexRequested = Utils.getOrDefault(useIndex, false);
        boolean createRequested = Utils.getOrDefault(createTable, true);
        boolean dropRequested = Utils.getOrDefault(dropTableFirst, false);

        // try-with-resources releases the connection and statement even when a DDL statement or
        // one of the lazy validations below throws.
        try (Connection connection = setupConnection();
             Statement statement = connection.createStatement()) {
            if (dropRequested) {
                statement.executeUpdate(String.format("DROP TABLE IF EXISTS %s", this.table));
            }
            if (createRequested) {
                // The dimension is only required when the table is actually being created.
                int vectorDimension = ValidationUtils.ensureGreaterThanZero(dimension, "dimension");
                statement.executeUpdate(String.format("CREATE TABLE IF NOT EXISTS %s (embedding_id UUID PRIMARY KEY, embedding vector(%s), text TEXT NULL, metadata JSON NULL)", this.table, vectorDimension));
            }
            if (indexRequested) {
                // Likewise, the list size is only required when the index is requested.
                int listSize = ValidationUtils.ensureGreaterThanZero(indexListSize, "indexListSize");
                statement.executeUpdate(String.format("CREATE INDEX IF NOT EXISTS %s ON %s USING ivfflat (embedding vector_cosine_ops) WITH (lists = %s)", this.table + "_ivfflat_index", this.table, listSize));
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates a builder for {@link PgVectorEmbeddingStore}.
     *
     * @return a new, empty builder
     */
    public static PgVectorEmbeddingStoreBuilder builder() {
        return new PgVectorEmbeddingStoreBuilder();
    }

    /**
     * Opens a connection, makes sure the {@code vector} extension exists and registers the
     * pgvector type on the connection.
     *
     * @return an open connection that the caller must close
     * @throws SQLException if connecting or preparing the connection fails; in that case the
     *                      connection has already been closed
     */
    private Connection setupConnection() throws SQLException {
        String url = String.format("jdbc:postgresql://%s:%s/%s", this.host, this.port, this.database);
        Connection connection = DriverManager.getConnection(url, this.user, this.password);
        try {
            try (Statement statement = connection.createStatement()) {
                statement.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
            }
            PGvector.addVectorType(connection);
            return connection;
        } catch (SQLException | RuntimeException e) {
            // The caller never receives the connection on this path, so it must be closed here.
            try {
                connection.close();
            } catch (SQLException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Stores an embedding under a newly generated id.
     *
     * @param embedding the embedding to store
     * @return the generated id
     */
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, null);
        return id;
    }

    /**
     * Stores an embedding under the given id, replacing any row that already uses that id.
     *
     * @param id        id of the row; must be parseable by {@link UUID#fromString(String)}
     * @param embedding the embedding to store
     */
    public void add(String id, Embedding embedding) {
        addInternal(id, embedding, null);
    }

    /**
     * Stores an embedding together with its text segment under a newly generated id.
     *
     * @param embedding   the embedding to store
     * @param textSegment the segment the embedding belongs to; may be {@code null}
     * @return the generated id
     */
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        addInternal(id, embedding, textSegment);
        return id;
    }

    /**
     * Stores several embeddings, each under a newly generated id.
     *
     * @param embeddings the embeddings to store
     * @return the generated ids, in the order of {@code embeddings}
     */
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = embeddings.stream()
                .map(ignored -> Utils.randomUUID())
                .collect(Collectors.toList());
        addAllInternal(ids, embeddings, null);
        return ids;
    }

    /**
     * Stores several embeddings with their text segments, each under a newly generated id.
     *
     * @param embeddings the embeddings to store
     * @param embedded   the matching text segments; {@code null}, or the same size as
     *                   {@code embeddings}
     * @return the generated ids, in the order of {@code embeddings}
     */
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = embeddings.stream()
                .map(ignored -> Utils.randomUUID())
                .collect(Collectors.toList());
        addAllInternal(ids, embeddings, embedded);
        return ids;
    }

    /**
     * Returns the stored embeddings that score highest against a reference embedding.
     *
     * <p>The score is computed in SQL as {@code (2 - (embedding <=> reference)) / 2}, where
     * {@code <=>} is pgvector's cosine-distance operator.
     *
     * @param referenceEmbedding the embedding to compare against
     * @param maxResults         maximum number of rows to return
     * @param minScore           minimum score (inclusive) a row must reach
     * @return the matches ordered by descending score; empty when nothing qualifies
     * @throws RuntimeException if the query fails
     */
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        String query = String.format("WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, metadata FROM %s) SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;", Arrays.toString(referenceEmbedding.vector()), this.table, minScore, maxResults);

        List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>();
        try (Connection connection = setupConnection();
             PreparedStatement statement = connection.prepareStatement(query);
             ResultSet rows = statement.executeQuery()) {
            while (rows.next()) {
                matches.add(toMatch(rows));
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
        return matches;
    }

    /** Converts the current row of a {@link #findRelevant} result set into a match. */
    private static EmbeddingMatch<TextSegment> toMatch(ResultSet row) throws SQLException {
        double score = row.getDouble("score");
        String id = row.getString("embedding_id");
        Embedding stored = new Embedding(((PGvector) row.getObject("embedding")).toArray());

        // A row without text has no segment; its metadata column is not read in that case.
        String text = row.getString("text");
        TextSegment segment = null;
        if (Utils.isNotNullOrBlank(text)) {
            segment = TextSegment.from(text, new Metadata(parseMetadata(row.getString("metadata"))));
        }
        return new EmbeddingMatch<>(score, id, stored, segment);
    }

    /** Parses the JSON metadata column; SQL {@code NULL} and JSON {@code null} give an empty map. */
    private static Map<String, String> parseMetadata(String json) {
        Map<String, String> parsed = GSON.fromJson(Optional.ofNullable(json).orElse("{}"), METADATA_TYPE);
        if (parsed == null) {
            return new HashMap<>();
        }
        return new HashMap<>(parsed);
    }

    /** Stores a single embedding by delegating to {@link #addAllInternal}. */
    private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
        List<TextSegment> segments = textSegment == null ? null : Collections.singletonList(textSegment);
        addAllInternal(Collections.singletonList(id), Collections.singletonList(embedding), segments);
    }

    /**
     * Upserts the given rows in one JDBC batch. Does nothing, apart from logging, when
     * {@code ids} or {@code embeddings} is {@code null} or empty.
     *
     * @param ids        row ids; each must be parseable by {@link UUID#fromString(String)}
     * @param embeddings embeddings, same size as {@code ids}
     * @param embedded   text segments, {@code null} or same size as {@code embeddings};
     *                   individual elements may be {@code null}
     * @throws RuntimeException if the batch fails
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("Empty embeddings - no ops");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");

        String upsert = String.format("INSERT INTO %s (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?)ON CONFLICT (embedding_id) DO UPDATE SET embedding = EXCLUDED.embedding,text = EXCLUDED.text,metadata = EXCLUDED.metadata;", this.table);

        try (Connection connection = setupConnection();
             PreparedStatement statement = connection.prepareStatement(upsert)) {
            for (int i = 0; i < ids.size(); i++) {
                TextSegment segment = embedded == null ? null : embedded.get(i);

                statement.setObject(1, UUID.fromString(ids.get(i)));
                statement.setObject(2, new PGvector(embeddings.get(i).vector()));
                if (segment == null) {
                    statement.setNull(3, Types.VARCHAR);
                    statement.setNull(4, Types.OTHER);
                } else {
                    statement.setObject(3, segment.text());
                    // Types.OTHER hands the JSON text to the driver without declaring it a string.
                    statement.setObject(4, GSON.toJson(segment.metadata().asMap()), Types.OTHER);
                }
                statement.addBatch();
            }
            statement.executeBatch();
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Fluent builder for {@link PgVectorEmbeddingStore}. Each setter stores its argument
     * unchanged; validation and defaulting happen in the store's constructor.
     */
    public static class PgVectorEmbeddingStoreBuilder {
        private String host;
        private Integer port;
        private String user;
        private String password;
        private String database;
        private String table;
        private Integer dimension;
        private Boolean useIndex;
        private Integer indexListSize;
        private Boolean createTable;
        private Boolean dropTableFirst;

        PgVectorEmbeddingStoreBuilder() {
        }

        /** Sets the database host. */
        public PgVectorEmbeddingStoreBuilder host(String host) {
            this.host = host;
            return this;
        }

        /** Sets the database port. */
        public PgVectorEmbeddingStoreBuilder port(Integer port) {
            this.port = port;
            return this;
        }

        /** Sets the database user. */
        public PgVectorEmbeddingStoreBuilder user(String user) {
            this.user = user;
            return this;
        }

        /** Sets the database password. */
        public PgVectorEmbeddingStoreBuilder password(String password) {
            this.password = password;
            return this;
        }

        /** Sets the database name. */
        public PgVectorEmbeddingStoreBuilder database(String database) {
            this.database = database;
            return this;
        }

        /** Sets the name of the table that holds the embeddings. */
        public PgVectorEmbeddingStoreBuilder table(String table) {
            this.table = table;
            return this;
        }

        /** Sets the vector dimension used when the table is created. */
        public PgVectorEmbeddingStoreBuilder dimension(Integer dimension) {
            this.dimension = dimension;
            return this;
        }

        /** Sets whether an ivfflat index is created on the embedding column. */
        public PgVectorEmbeddingStoreBuilder useIndex(Boolean useIndex) {
            this.useIndex = useIndex;
            return this;
        }

        /** Sets the {@code lists} option of the ivfflat index. */
        public PgVectorEmbeddingStoreBuilder indexListSize(Integer indexListSize) {
            this.indexListSize = indexListSize;
            return this;
        }

        /** Sets whether the table is created if it does not exist. */
        public PgVectorEmbeddingStoreBuilder createTable(Boolean createTable) {
            this.createTable = createTable;
            return this;
        }

        /** Sets whether the table is dropped before it is (re)created. */
        public PgVectorEmbeddingStoreBuilder dropTableFirst(Boolean dropTableFirst) {
            this.dropTableFirst = dropTableFirst;
            return this;
        }

        /**
         * Builds the store from the values set so far.
         *
         * @return the configured store
         */
        public PgVectorEmbeddingStore build() {
            return new PgVectorEmbeddingStore(this.host, this.port, this.user, this.password, this.database, this.table, this.dimension, this.useIndex, this.indexListSize, this.createTable, this.dropTableFirst);
        }

        /**
         * Describes the builder state. The password is masked so that the credential does not
         * end up in logs or debugger output.
         */
        @Override
        public String toString() {
            String maskedPassword = this.password == null ? "null" : "****";
            return "PgVectorEmbeddingStore.PgVectorEmbeddingStoreBuilder(host=" + this.host + ", port=" + this.port + ", user=" + this.user + ", password=" + maskedPassword + ", database=" + this.database + ", table=" + this.table + ", dimension=" + this.dimension + ", useIndex=" + this.useIndex + ", indexListSize=" + this.indexListSize + ", createTable=" + this.createTable + ", dropTableFirst=" + this.dropTableFirst + ")";
        }
    }
}
