package com.alibaba.cloud.ai.analyticdb;

import com.aliyun.gpdb20160503.Client;
import com.aliyun.gpdb20160503.models.CreateCollectionRequest;
import com.aliyun.gpdb20160503.models.CreateNamespaceRequest;
import com.aliyun.gpdb20160503.models.DeleteCollectionDataRequest;
import com.aliyun.gpdb20160503.models.DescribeCollectionRequest;
import com.aliyun.gpdb20160503.models.DescribeNamespaceRequest;
import com.aliyun.gpdb20160503.models.InitVectorDatabaseRequest;
import com.aliyun.gpdb20160503.models.QueryCollectionDataRequest;
import com.aliyun.gpdb20160503.models.QueryCollectionDataResponse;
import com.aliyun.gpdb20160503.models.QueryCollectionDataResponseBody;
import com.aliyun.gpdb20160503.models.UpsertCollectionDataRequest;
import com.aliyun.tea.TeaException;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.util.JacksonUtils;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/* loaded from: input_file:com/alibaba/cloud/ai/analyticdb/AnalyticDbVectorStore.class */
public class AnalyticDbVectorStore extends AbstractObservationVectorStore implements InitializingBean {
    private static final String DATA_BASE_SYSTEM = "analytic_db";
    private static final String REF_DOC_NAME = "refDocId";
    private static final String METADATA_FIELD_NAME = "metadata";
    private static final String CONTENT_FIELD_NAME = "content";
    private static final String DOC_NAME = "docId";
    private static final int DEFAULT_TOP_K = 4;
    public final FilterExpressionConverter filterExpressionConverter;
    private final String collectionName;
    private final AnalyticDbConfig config;
    private final Client client;
    private final ObjectMapper objectMapper;
    private final Integer defaultTopK;
    private final Double defaultSimilarityThreshold;
    private static final Logger logger = LoggerFactory.getLogger(AnalyticDbVectorStore.class);
    private static final Double DEFAULT_SIMILARITY_THRESHOLD = Double.valueOf(0.0d);

    /* loaded from: input_file:com/alibaba/cloud/ai/analyticdb/AnalyticDbVectorStore$Builder.class */
    public static class Builder extends AbstractVectorStoreBuilder<Builder> {
        private final String collectionName;
        private final AnalyticDbConfig config;
        private final Client client;
        private int defaultTopK;
        private Double defaultSimilarityThreshold;

        private Builder(String str, AnalyticDbConfig analyticDbConfig, Client client, EmbeddingModel embeddingModel) {
            super(embeddingModel);
            this.defaultTopK = AnalyticDbVectorStore.DEFAULT_TOP_K;
            this.defaultSimilarityThreshold = AnalyticDbVectorStore.DEFAULT_SIMILARITY_THRESHOLD;
            Assert.notNull(client, "Client must not be null");
            this.client = client;
            Assert.notNull(str, "Collection name must not be null");
            this.collectionName = str.toLowerCase();
            this.config = analyticDbConfig;
        }

        public Builder defaultTopK(int i) {
            Assert.isTrue(i >= 0, "The topK should be positive value.");
            this.defaultTopK = i;
            return this;
        }

        public Builder defaultSimilarityThreshold(Double d) {
            Assert.isTrue(d.doubleValue() >= 0.0d && d.doubleValue() <= 1.0d, "The similarity threshold must be in range [0.0:1.00].");
            this.defaultSimilarityThreshold = d;
            return this;
        }

        /* renamed from: build, reason: merged with bridge method [inline-methods] */
        public AnalyticDbVectorStore m2build() {
            try {
                return new AnalyticDbVectorStore(this);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    protected AnalyticDbVectorStore(Builder builder) throws Exception {
        super(builder);
        this.filterExpressionConverter = new AdVectorFilterExpressionConverter();
        this.collectionName = builder.collectionName;
        this.config = builder.config;
        this.client = builder.client;
        this.objectMapper = JsonMapper.builder().addModules(JacksonUtils.instantiateAvailableModules()).build();
        this.defaultSimilarityThreshold = builder.defaultSimilarityThreshold;
        this.defaultTopK = Integer.valueOf(builder.defaultTopK);
    }

    public static Builder builder(String str, AnalyticDbConfig analyticDbConfig, Client client, EmbeddingModel embeddingModel) {
        return new Builder(str, analyticDbConfig, client, embeddingModel);
    }

    private void initialize() throws Exception {
        initializeVectorDataBase();
        createNameSpaceIfNotExists();
        createCollectionIfNotExists(Long.valueOf(this.embeddingModel.dimensions()));
    }

    private void initializeVectorDataBase() throws Exception {
        logger.debug("successfully initialize vector database, response body:{}", this.client.initVectorDatabase(new InitVectorDatabaseRequest().setDBInstanceId(this.config.getDbInstanceId()).setRegionId(this.config.getRegionId()).setManagerAccount(this.config.getManagerAccount()).setManagerAccountPassword(this.config.getManagerAccountPassword())).getBody());
    }

    private void createNameSpaceIfNotExists() throws Exception {
        try {
            this.client.describeNamespace(new DescribeNamespaceRequest().setDBInstanceId(this.config.getDbInstanceId()).setRegionId(this.config.getRegionId()).setNamespace(this.config.getNamespace()).setManagerAccount(this.config.getManagerAccount()).setManagerAccountPassword(this.config.getManagerAccountPassword()));
        } catch (TeaException e) {
            if (!Objects.equals(e.getStatusCode(), 404)) {
                throw new Exception("failed to create namespace:{}", e);
            }
            this.client.createNamespace(new CreateNamespaceRequest().setDBInstanceId(this.config.getDbInstanceId()).setRegionId(this.config.getRegionId()).setNamespace(this.config.getNamespace()).setManagerAccount(this.config.getManagerAccount()).setManagerAccountPassword(this.config.getManagerAccountPassword()).setNamespacePassword(this.config.getNamespacePassword()));
        }
    }

    private void createCollectionIfNotExists(Long l) throws Exception {
        try {
            this.client.describeCollection(new DescribeCollectionRequest().setDBInstanceId(this.config.getDbInstanceId()).setRegionId(this.config.getRegionId()).setNamespace(this.config.getNamespace()).setNamespacePassword(this.config.getNamespacePassword()).setCollection(this.collectionName));
            logger.debug("collection" + this.collectionName + "already exists");
        } catch (TeaException e) {
            if (!Objects.equals(e.getStatusCode(), 404)) {
                throw new RuntimeException("Failed to create collection " + this.collectionName + ": " + e.getMessage());
            }
            ObjectNode createObjectNode = this.objectMapper.createObjectNode();
            createObjectNode.put(REF_DOC_NAME, "text");
            createObjectNode.put(CONTENT_FIELD_NAME, "text");
            createObjectNode.put(METADATA_FIELD_NAME, "jsonb");
            this.client.createCollection(new CreateCollectionRequest().setDBInstanceId(this.config.getDbInstanceId()).setRegionId(this.config.getRegionId()).setManagerAccount(this.config.getManagerAccount()).setManagerAccountPassword(this.config.getManagerAccountPassword()).setNamespace(this.config.getNamespace()).setCollection(this.collectionName).setDimension(l).setMetrics(this.config.getMetrics()).setMetadata(this.objectMapper.writeValueAsString(createObjectNode)).setFullTextRetrievalFields(CONTENT_FIELD_NAME));
            logger.debug("collection" + this.collectionName + "created");
        }
    }

    public void doAdd(List<Document> list) {
        Assert.notNull(list, "The document list should not be null.");
        if (CollectionUtils.isEmpty(list)) {
            return;
        }
        List embed = this.embeddingModel.embed(list, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
        ArrayList arrayList = new ArrayList(10);
        for (int i = 0; i < list.size(); i++) {
            Document document = list.get(i);
            logger.info("Processing document id = {}", document.getId());
            HashMap hashMap = new HashMap();
            String str = (String) document.getMetadata().get(DOC_NAME);
            hashMap.put(REF_DOC_NAME, (str == null || str.isEmpty()) ? document.getId() : str);
            hashMap.put(CONTENT_FIELD_NAME, document.getText());
            try {
                hashMap.put(METADATA_FIELD_NAME, this.objectMapper.writeValueAsString(document.getMetadata()));
                float[] fArr = (float[]) embed.get(i);
                arrayList.add(new UpsertCollectionDataRequest.UpsertCollectionDataRequestRows().setVector(IntStream.range(0, fArr.length).mapToObj(i2 -> {
                    return Double.valueOf(fArr[i2]);
                }).toList()).setMetadata(hashMap));
            } catch (JsonProcessingException e) {
                throw new RuntimeException("Failed to serialize metadata for document id = " + document.getId(), e);
            }
        }
        try {
            this.client.upsertCollectionData(new UpsertCollectionDataRequest().setDBInstanceId(this.config.getDbInstanceId()).setRegionId(this.config.getRegionId()).setNamespace(this.config.getNamespace()).setNamespacePassword(this.config.getNamespacePassword()).setCollection(this.collectionName).setRows(arrayList));
        } catch (Exception e2) {
            throw new RuntimeException("Failed to add collection data by IDs: " + e2.getMessage(), e2);
        }
    }

    public void doDelete(List<String> list) {
        if (list.isEmpty()) {
            return;
        }
        try {
            logger.debug("delete collection data response:{}", this.client.deleteCollectionData(new DeleteCollectionDataRequest().setDBInstanceId(this.config.getDbInstanceId()).setRegionId(this.config.getRegionId()).setNamespace(this.config.getNamespace()).setNamespacePassword(this.config.getNamespacePassword()).setCollection(this.collectionName).setCollectionData((String) null).setCollectionDataFilter("refDocId IN " + ((String) list.stream().map(str -> {
                return "'" + str + "'";
            }).collect(Collectors.joining(", ", "(", ")"))))).getBody());
        } catch (Exception e) {
            throw new RuntimeException("Failed to delete collection data by IDs: " + e.getMessage(), e);
        }
    }

    public void doDelete(Filter.Expression expression) {
        try {
            logger.debug("delete collection data response:{}", this.client.deleteCollectionData(new DeleteCollectionDataRequest().setDBInstanceId(this.config.getDbInstanceId()).setRegionId(this.config.getRegionId()).setNamespace(this.config.getNamespace()).setNamespacePassword(this.config.getNamespacePassword()).setCollection(this.collectionName).setCollectionData((String) null).setCollectionDataFilter(this.filterExpressionConverter.convertExpression(expression))).getBody());
        } catch (Exception e) {
            throw new RuntimeException("Failed to delete collection data by filterExpression: " + e.getMessage(), e);
        }
    }

    public List<Document> similaritySearch(String str) {
        return similaritySearch(SearchRequest.builder().query(str).topK(this.defaultTopK.intValue()).similarityThreshold(this.defaultSimilarityThreshold.doubleValue()).build());
    }

    public List<Document> doSimilaritySearch(SearchRequest searchRequest) {
        double similarityThreshold = searchRequest.getSimilarityThreshold();
        boolean hasFilterExpression = searchRequest.hasFilterExpression();
        int topK = searchRequest.getTopK();
        try {
            QueryCollectionDataResponse queryCollectionData = this.client.queryCollectionData(new QueryCollectionDataRequest().setDBInstanceId(this.config.getDbInstanceId()).setRegionId(this.config.getRegionId()).setNamespace(this.config.getNamespace()).setNamespacePassword(this.config.getNamespacePassword()).setCollection(this.collectionName).setIncludeValues(Boolean.valueOf(hasFilterExpression)).setMetrics(this.config.getMetrics()).setVector((List) null).setContent(searchRequest.getQuery()).setTopK(Long.valueOf(topK)).setFilter(hasFilterExpression ? searchRequest.getFilterExpression() != null ? this.filterExpressionConverter.convertExpression(searchRequest.getFilterExpression()) : "" : null));
            ArrayList arrayList = new ArrayList();
            for (QueryCollectionDataResponseBody.QueryCollectionDataResponseBodyMatchesMatch queryCollectionDataResponseBodyMatchesMatch : queryCollectionData.getBody().getMatches().getMatch()) {
                if (queryCollectionDataResponseBodyMatchesMatch.getScore() != null && queryCollectionDataResponseBodyMatchesMatch.getScore().doubleValue() > similarityThreshold) {
                    Map metadata = queryCollectionDataResponseBodyMatchesMatch.getMetadata();
                    arrayList.add(new Document((String) metadata.get(CONTENT_FIELD_NAME), (Map) this.objectMapper.readValue((String) metadata.get(METADATA_FIELD_NAME), new TypeReference<HashMap<String, Object>>() { // from class: com.alibaba.cloud.ai.analyticdb.AnalyticDbVectorStore.1
                    })));
                }
            }
            return arrayList;
        } catch (Exception e) {
            throw new RuntimeException("Failed to search by full text: " + e.getMessage(), e);
        }
    }

    public void afterPropertiesSet() throws Exception {
        initialize();
        logger.debug("created AnalyticdbVector client success");
    }

    public VectorStoreObservationContext.Builder createObservationContextBuilder(String str) {
        return VectorStoreObservationContext.builder(DATA_BASE_SYSTEM, str).collectionName(this.collectionName).dimensions(Integer.valueOf(this.embeddingModel.dimensions())).namespace(this.config.getNamespace()).similarityMetric(this.config.getMetrics());
    }
}
