package org.springframework.ai.vectorstore.redis;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.json.Path2;
import redis.clients.jedis.search.FTCreateParams;
import redis.clients.jedis.search.IndexDataType;
import redis.clients.jedis.search.Query;
import redis.clients.jedis.search.RediSearchUtil;
import redis.clients.jedis.search.Schema;
import redis.clients.jedis.search.schemafields.NumericField;
import redis.clients.jedis.search.schemafields.SchemaField;
import redis.clients.jedis.search.schemafields.TagField;
import redis.clients.jedis.search.schemafields.TextField;
import redis.clients.jedis.search.schemafields.VectorField;

/* loaded from: input_file:org/springframework/ai/vectorstore/redis/RedisVectorStore.class */
public class RedisVectorStore extends AbstractObservationVectorStore implements InitializingBean {
    // ---- Public defaults -------------------------------------------------

    /** Index name used when the builder is not given one. */
    public static final String DEFAULT_INDEX_NAME = "spring-ai-index";
    /** Name of the field that stores a document's text when the builder is not given one. */
    public static final String DEFAULT_CONTENT_FIELD_NAME = "content";
    /** Name of the field that stores a document's embedding when the builder is not given one. */
    public static final String DEFAULT_EMBEDDING_FIELD_NAME = "embedding";
    /** Key prefix used when the builder is not given one; keys are formed as prefix + document id. */
    public static final String DEFAULT_PREFIX = "embedding:";
    /** Alias under which the KNN query exposes the computed distance of each hit. */
    public static final String DISTANCE_FIELD_NAME = "vector_score";

    // ---- Internal query / schema constants -------------------------------

    // Format arguments, in order: filter expression, top-K, embedding field, parameter name, distance alias.
    private static final String QUERY_FORMAT = "%s=>[KNN %s @%s $%s AS %s]";
    // Prefix that turns a field name into a JSON path rooted at the document.
    private static final String JSON_PATH_PREFIX = "$.";
    // Value stored under the "TYPE" attribute of the vector schema field.
    private static final String VECTOR_TYPE_FLOAT32 = "FLOAT32";
    // Name of the query parameter that carries the serialized query embedding.
    private static final String EMBEDDING_PARAM_NAME = "BLOB";
    // Value stored under the "DISTANCE_METRIC" attribute of the vector schema field.
    private static final String DEFAULT_DISTANCE_METRIC = "COSINE";

    // ---- Per-instance configuration (copied from the builder) ------------

    private final JedisPooled jedis;
    // When true, afterPropertiesSet() creates the index if it is not already listed.
    private final boolean initializeSchema;
    private final String indexName;
    private final String prefix;
    private final String contentFieldName;
    private final String embeddingFieldName;
    private final Algorithm vectorAlgorithm;
    // Metadata fields that are indexed, requested as return fields, and copied into search results.
    private final List<MetadataField> metadataFields;
    private final FilterExpressionConverter filterExpressionConverter;

    /** Vector algorithm used when the builder is not given one. */
    public static final Algorithm DEFAULT_VECTOR_ALGORITHM = Algorithm.HSNW;
    // JSON path "$": documents are written at the root of the JSON value.
    private static final Path2 JSON_SET_PATH = Path2.of("$");
    private static final Logger logger = LoggerFactory.getLogger(RedisVectorStore.class);
    // A pipeline reply equal to "OK" is treated as a successful write / index creation.
    private static final Predicate<Object> RESPONSE_OK = Predicate.isEqual("OK");
    // A pipeline reply equal to 1L is treated as a successful single-key delete.
    private static final Predicate<Object> RESPONSE_DEL_OK = Predicate.isEqual(1L);

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.springframework.ai.vectorstore.redis.RedisVectorStore$1, reason: invalid class name */
    /* loaded from: input_file:org/springframework/ai/vectorstore/redis/RedisVectorStore$1.class */
    /*
     * Synthetic lookup table of the kind a compiler emits for a switch over
     * Schema.FieldType: the array is indexed by the enum constant's ordinal()
     * and holds a small case number (NUMERIC -> 1, TAG -> 2, TEXT -> 3).
     * Slots that are never assigned keep the int default of 0.
     */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$redis$clients$jedis$search$Schema$FieldType = new int[Schema.FieldType.values().length];

        static {
            // Each assignment is wrapped separately so that a constant missing at
            // run time (NoSuchFieldError) only leaves its own slot unassigned.
            try {
                $SwitchMap$redis$clients$jedis$search$Schema$FieldType[Schema.FieldType.NUMERIC.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // intentionally ignored: slot stays 0
            }
            try {
                $SwitchMap$redis$clients$jedis$search$Schema$FieldType[Schema.FieldType.TAG.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // intentionally ignored: slot stays 0
            }
            try {
                $SwitchMap$redis$clients$jedis$search$Schema$FieldType[Schema.FieldType.TEXT.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // intentionally ignored: slot stays 0
            }
        }
    }

    /* loaded from: input_file:org/springframework/ai/vectorstore/redis/RedisVectorStore$Algorithm.class */
    /**
     * Vector indexing algorithm selectable through the builder.
     *
     * <p>NOTE(review): {@code HSNW} looks like a transposition of "HNSW"; the
     * constant name is part of the public API, so it is left as declared.
     */
    public enum Algorithm {
        /** Selects the FLAT vector algorithm. */
        FLAT,
        /** Selects the HNSW vector algorithm (see the spelling note on the enum). */
        HSNW
    }

    /* loaded from: input_file:org/springframework/ai/vectorstore/redis/RedisVectorStore$Builder.class */
    public static class Builder extends AbstractVectorStoreBuilder<Builder> {
        private final JedisPooled jedis;
        private String indexName;
        private String prefix;
        private String contentFieldName;
        private String embeddingFieldName;
        private Algorithm vectorAlgorithm;
        private List<MetadataField> metadataFields;
        private boolean initializeSchema;

        private Builder(JedisPooled jedisPooled, EmbeddingModel embeddingModel) {
            super(embeddingModel);
            this.indexName = RedisVectorStore.DEFAULT_INDEX_NAME;
            this.prefix = RedisVectorStore.DEFAULT_PREFIX;
            this.contentFieldName = RedisVectorStore.DEFAULT_CONTENT_FIELD_NAME;
            this.embeddingFieldName = RedisVectorStore.DEFAULT_EMBEDDING_FIELD_NAME;
            this.vectorAlgorithm = RedisVectorStore.DEFAULT_VECTOR_ALGORITHM;
            this.metadataFields = new ArrayList();
            this.initializeSchema = false;
            Assert.notNull(jedisPooled, "JedisPooled must not be null");
            this.jedis = jedisPooled;
        }

        public Builder indexName(String str) {
            if (StringUtils.hasText(str)) {
                this.indexName = str;
            }
            return this;
        }

        public Builder prefix(String str) {
            if (StringUtils.hasText(str)) {
                this.prefix = str;
            }
            return this;
        }

        public Builder contentFieldName(String str) {
            if (StringUtils.hasText(str)) {
                this.contentFieldName = str;
            }
            return this;
        }

        public Builder embeddingFieldName(String str) {
            if (StringUtils.hasText(str)) {
                this.embeddingFieldName = str;
            }
            return this;
        }

        public Builder vectorAlgorithm(@Nullable Algorithm algorithm) {
            if (algorithm != null) {
                this.vectorAlgorithm = algorithm;
            }
            return this;
        }

        public Builder metadataFields(MetadataField... metadataFieldArr) {
            return metadataFields(Arrays.asList(metadataFieldArr));
        }

        public Builder metadataFields(@Nullable List<MetadataField> list) {
            if (list != null && !list.isEmpty()) {
                this.metadataFields = new ArrayList(list);
            }
            return this;
        }

        public Builder initializeSchema(boolean z) {
            this.initializeSchema = z;
            return this;
        }

        /* renamed from: build, reason: merged with bridge method [inline-methods] */
        public RedisVectorStore m5build() {
            return new RedisVectorStore(this);
        }
    }

    /* loaded from: input_file:org/springframework/ai/vectorstore/redis/RedisVectorStore$MetadataField.class */
    /**
     * Describes one metadata attribute that is indexed and returned with search hits.
     *
     * <p>Declared as a record: the canonical constructor, the {@code name()} and
     * {@code fieldType()} accessors, and component-based {@code equals},
     * {@code hashCode} and {@code toString} are generated by the compiler.
     *
     * @param name metadata key, also used as the schema field alias
     * @param fieldType search schema type of the field
     */
    public record MetadataField(String name, Schema.FieldType fieldType) {

        /**
         * Creates a field of type {@code TEXT}.
         * @param name metadata key
         * @return the new field descriptor
         */
        public static MetadataField text(String name) {
            return new MetadataField(name, Schema.FieldType.TEXT);
        }

        /**
         * Creates a field of type {@code NUMERIC}.
         * @param name metadata key
         * @return the new field descriptor
         */
        public static MetadataField numeric(String name) {
            return new MetadataField(name, Schema.FieldType.NUMERIC);
        }

        /**
         * Creates a field of type {@code TAG}.
         * @param name metadata key
         * @return the new field descriptor
         */
        public static MetadataField tag(String name) {
            return new MetadataField(name, Schema.FieldType.TAG);
        }
    }

    /**
     * Creates a store from a fully configured builder.
     *
     * @param builder source of the Redis client and all naming / schema options
     * @throws IllegalArgumentException if the builder carries no Redis client
     */
    protected RedisVectorStore(Builder builder) {
        super(builder);
        Assert.notNull(builder.jedis, "JedisPooled must not be null");

        // Connection and index bootstrap behaviour.
        this.jedis = builder.jedis;
        this.initializeSchema = builder.initializeSchema;

        // Naming of the index, keys and document fields.
        this.indexName = builder.indexName;
        this.prefix = builder.prefix;
        this.contentFieldName = builder.contentFieldName;
        this.embeddingFieldName = builder.embeddingFieldName;

        // Schema options; the filter converter is built from the same metadata fields.
        this.vectorAlgorithm = builder.vectorAlgorithm;
        this.metadataFields = builder.metadataFields;
        this.filterExpressionConverter = new RedisFilterExpressionConverter(this.metadataFields);
    }

    /**
     * Exposes the Redis client this store was built with.
     *
     * @return the underlying {@link JedisPooled} instance
     */
    public JedisPooled getJedis() {
        final JedisPooled client = this.jedis;
        return client;
    }

    /**
     * Embeds the given documents and writes each one as a JSON value under
     * {@code prefix + id}, using a single pipeline.
     *
     * <p>Each JSON value holds the embedding, the text, and then the document's
     * metadata entries. Metadata is added last, so a metadata key equal to the
     * embedding or content field name overwrites that entry.
     *
     * @param list documents to store
     * @throws RuntimeException if any pipelined write returns a reply other than {@code "OK"}
     */
    public void doAdd(List<Document> list) {
        try (Pipeline pipeline = this.jedis.pipelined()) {
            List<?> embeddings = this.embeddingModel.embed(list, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);

            // Embeddings are positional: walk both lists in lock-step instead of
            // calling list.indexOf(document) per document, which is O(n^2).
            int position = 0;
            for (Document document : list) {
                Map<String, Object> fields = new HashMap<>();
                fields.put(this.embeddingFieldName, embeddings.get(position++));
                fields.put(this.contentFieldName, document.getText());
                fields.putAll(document.getMetadata());
                pipeline.jsonSetWithEscape(key(document.getId()), JSON_SET_PATH, fields);
            }

            Optional<Object> failure = pipeline.syncAndReturnAll().stream().filter(Predicate.not(RESPONSE_OK)).findAny();
            if (failure.isPresent()) {
                String message = MessageFormat.format("Could not add document: {0}", failure.get());
                if (logger.isErrorEnabled()) {
                    logger.error(message);
                }
                throw new RuntimeException(message);
            }
        }
    }

    /**
     * Builds the Redis key for a document id.
     *
     * @param str document id
     * @return the configured prefix followed by {@code str}
     */
    private String key(String str) {
        final String redisKey = this.prefix + str;
        return redisKey;
    }

    /**
     * Deletes the JSON values stored for the given document ids using one pipeline.
     *
     * <p>This is best-effort: a reply other than {@code 1} for any key is logged at
     * error level, and no exception is raised for it.
     *
     * @param list document ids (without the key prefix)
     */
    public void doDelete(List<String> list) {
        try (Pipeline pipeline = this.jedis.pipelined()) {
            for (String id : list) {
                pipeline.jsonDel(key(id));
            }
            Optional<Object> failure = pipeline.syncAndReturnAll().stream().filter(Predicate.not(RESPONSE_DEL_OK)).findAny();
            if (failure.isPresent() && logger.isErrorEnabled()) {
                logger.error("Could not delete document: {}", failure.get());
            }
        }
    }

    /**
     * Deletes every document returned by a search for the given filter expression.
     *
     * <p>The expression is converted to a native query, the index is searched, and
     * the keys of the hits are deleted through one pipeline.
     *
     * <p>NOTE(review): the search is issued without an explicit limit. If the
     * client or server applies a default page size, only the first page of matches
     * is deleted per call — confirm against the Jedis / RediSearch defaults.
     *
     * @param expression filter selecting the documents to delete; must not be {@code null}
     * @throws IllegalArgumentException if {@code expression} is {@code null}
     * @throws IllegalStateException if the search or any delete fails
     */
    protected void doDelete(Filter.Expression expression) {
        Assert.notNull(expression, "Filter expression must not be null");
        try {
            String nativeFilter = this.filterExpressionConverter.convertExpression(expression);

            // Search hits already carry the full Redis key. Using it directly avoids
            // stripping the prefix with String.replace, which would also remove the
            // prefix text from the middle of an id and then target the wrong key.
            List<String> matchingKeys = new ArrayList<>();
            for (redis.clients.jedis.search.Document match : this.jedis.ftSearch(this.indexName, nativeFilter).getDocuments()) {
                matchingKeys.add(match.getId());
            }
            if (matchingKeys.isEmpty()) {
                return;
            }

            // try-with-resources closes the pipeline on the failure path as well.
            try (Pipeline pipeline = this.jedis.pipelined()) {
                for (String redisKey : matchingKeys) {
                    pipeline.jsonDel(redisKey);
                }
                Optional<Object> failure = pipeline.syncAndReturnAll().stream().filter(Predicate.not(RESPONSE_DEL_OK)).findAny();
                if (failure.isPresent()) {
                    logger.error("Could not delete document: {}", failure.get());
                    throw new IllegalStateException("Failed to delete some documents");
                }
            }
            logger.debug("Deleted {} documents matching filter expression", matchingKeys.size());
        } catch (Exception e) {
            logger.error("Failed to delete documents by filter", e);
            throw new IllegalStateException("Failed to delete documents by filter", e);
        }
    }

    /**
     * Runs a KNN vector search for the request's query text and returns the hits
     * whose similarity is at least the request's threshold.
     *
     * <p>The query combines the optional native filter with a KNN clause over the
     * embedding field, asks for the metadata, embedding, content and distance
     * fields, sorts ascending by distance, and limits the result to top-K.
     *
     * @param searchRequest query text, top-K, similarity threshold and optional filter
     * @return matching documents, in the order returned by the search
     * @throws IllegalArgumentException if top-K is not positive or the threshold is outside [0, 1]
     */
    public List<Document> doSimilaritySearch(SearchRequest searchRequest) {
        Assert.isTrue(searchRequest.getTopK() > 0, "The number of documents to be returned must be greater than zero");
        Assert.isTrue(searchRequest.getSimilarityThreshold() >= 0.0d && searchRequest.getSimilarityThreshold() <= 1.0d, "The similarity score is bounded between 0 and 1; least to most similar respectively.");

        String queryString = String.format(QUERY_FORMAT, nativeExpressionFilter(searchRequest), searchRequest.getTopK(), this.embeddingFieldName, EMBEDDING_PARAM_NAME, DISTANCE_FIELD_NAME);

        // Fields to return: every configured metadata field, then the fixed ones.
        List<String> returnFields = new ArrayList<>();
        for (MetadataField field : this.metadataFields) {
            returnFields.add(field.name());
        }
        returnFields.add(this.embeddingFieldName);
        returnFields.add(this.contentFieldName);
        returnFields.add(DISTANCE_FIELD_NAME);

        Query query = new Query(queryString)
            .addParam(EMBEDDING_PARAM_NAME, RediSearchUtil.toByteArray(this.embeddingModel.embed(searchRequest.getQuery())))
            .returnFields(returnFields.toArray(new String[0]))
            .setSortBy(DISTANCE_FIELD_NAME, true)
            .limit(0, searchRequest.getTopK())
            .dialect(2);

        return this.jedis.ftSearch(this.indexName, query)
            .getDocuments()
            .stream()
            .filter(match -> similarityScore(match) >= searchRequest.getSimilarityThreshold())
            .map(this::toDocument)
            .toList();
    }

    /**
     * Converts a search hit into a Spring AI {@link Document}.
     *
     * <p>The id is the Redis key with the first {@code prefix.length()} characters
     * removed. The text is the content field, or an empty string when the hit has
     * none. Metadata holds every configured metadata field present on the hit, plus
     * the distance (1 - similarity) under both {@code DISTANCE_FIELD_NAME} and the
     * standard distance metadata key. The similarity becomes the document score.
     *
     * @param document search hit returned by the index
     * @return the equivalent Spring AI document
     */
    private Document toDocument(redis.clients.jedis.search.Document document) {
        String id = document.getId().substring(this.prefix.length());
        String text = document.hasProperty(this.contentFieldName) ? document.getString(this.contentFieldName) : "";

        Map<String, Object> metadata = this.metadataFields.stream()
            .map(MetadataField::name)
            .filter(document::hasProperty)
            .collect(Collectors.toMap(Function.identity(), document::getString));

        // Parse the score once; the distance is stored under both metadata keys.
        float similarity = similarityScore(document);
        float distance = 1.0f - similarity;
        metadata.put(DISTANCE_FIELD_NAME, distance);
        metadata.put(DocumentMetadata.DISTANCE.value(), distance);

        return Document.builder().id(id).text(text).metadata(metadata).score((double) similarity).build();
    }

    /**
     * Derives a similarity score from the distance reported for a search hit.
     *
     * <p>Computed as {@code (2 - distance) / 2}, where the distance is parsed from
     * the hit's {@code DISTANCE_FIELD_NAME} property.
     * NOTE(review): this maps a distance in [0, 2] onto [1, 0]; presumably that is
     * the range of the configured COSINE metric — confirm.
     *
     * @param document search hit carrying the distance property
     * @return the similarity derived from the hit's distance
     */
    private float similarityScore(redis.clients.jedis.search.Document document) {
        String rawDistance = document.getString(DISTANCE_FIELD_NAME);
        float distance = Float.parseFloat(rawDistance);
        return (2.0f - distance) / 2.0f;
    }

    /**
     * Produces the filter part of the KNN query.
     *
     * @param searchRequest request that may carry a filter expression
     * @return {@code "*"} when the request has no filter, otherwise the converted
     * expression wrapped in parentheses
     */
    private String nativeExpressionFilter(SearchRequest searchRequest) {
        Filter.Expression filter = searchRequest.getFilterExpression();
        if (filter == null) {
            return "*";
        }
        String converted = this.filterExpressionConverter.convertExpression(filter);
        return "(" + converted + ")";
    }

    /**
     * Creates the search index when schema initialization is enabled and the index
     * name is not already listed by the server.
     *
     * <p>The index is created over JSON values whose keys start with the configured
     * prefix, using the fields from {@code schemaFields()}.
     *
     * @throws RuntimeException if index creation returns a reply other than {@code "OK"}
     */
    public void afterPropertiesSet() {
        if (!this.initializeSchema) {
            return;
        }
        boolean indexAlreadyListed = this.jedis.ftList().contains(this.indexName);
        if (indexAlreadyListed) {
            return;
        }

        FTCreateParams createParams = FTCreateParams.createParams().on(IndexDataType.JSON).addPrefix(this.prefix);
        String reply = this.jedis.ftCreate(this.indexName, createParams, schemaFields());
        if (RESPONSE_OK.test(reply)) {
            return;
        }
        throw new RuntimeException(MessageFormat.format("Could not create index: {0}", reply));
    }

    /**
     * Assembles the schema for the search index.
     *
     * <p>The result holds, in order: a text field for the content, a vector field
     * for the embedding (dimension taken from the embedding model, COSINE metric,
     * FLOAT32 type), and one field per configured metadata field.
     *
     * @return the schema fields to pass to index creation
     */
    private Iterable<SchemaField> schemaFields() {
        Map<String, Object> vectorAttributes = new HashMap<>();
        vectorAttributes.put("DIM", this.embeddingModel.dimensions());
        vectorAttributes.put("DISTANCE_METRIC", DEFAULT_DISTANCE_METRIC);
        vectorAttributes.put("TYPE", VECTOR_TYPE_FLOAT32);

        List<SchemaField> fields = new ArrayList<>();
        fields.add(TextField.of(jsonPath(this.contentFieldName)).as(this.contentFieldName).weight(1.0d));
        fields.add(VectorField.builder()
            .fieldName(jsonPath(this.embeddingFieldName))
            .algorithm(vectorAlgorithm())
            .attributes(vectorAttributes)
            .as(this.embeddingFieldName)
            .build());

        if (CollectionUtils.isEmpty(this.metadataFields)) {
            return fields;
        }
        for (MetadataField metadataField : this.metadataFields) {
            fields.add(schemaField(metadataField));
        }
        return fields;
    }

    /**
     * Maps a metadata field descriptor to the matching search schema field.
     *
     * <p>The schema field reads from the JSON path of the metadata key and is
     * aliased to the plain key name.
     *
     * @param metadataField descriptor holding the key name and field type
     * @return a numeric, tag or text schema field
     * @throws IllegalArgumentException if the field type is not NUMERIC, TAG or TEXT
     */
    private SchemaField schemaField(MetadataField metadataField) {
        String name = metadataField.name();
        Schema.FieldType type = metadataField.fieldType();
        String path = jsonPath(name);
        // Switch on the enum itself rather than on an ordinal-indexed lookup table.
        return switch (type) {
            case NUMERIC -> NumericField.of(path).as(name);
            case TAG -> TagField.of(path).as(name);
            case TEXT -> TextField.of(path).as(name);
            default -> throw new IllegalArgumentException(MessageFormat.format("Field {0} has unsupported type {1}", name, type));
        };
    }

    /**
     * Translates the configured {@link Algorithm} into the Jedis enum.
     *
     * @return {@code HNSW} when the store is configured with {@link Algorithm#HSNW},
     * {@code FLAT} for any other value
     */
    private VectorField.VectorAlgorithm vectorAlgorithm() {
        if (this.vectorAlgorithm == Algorithm.HSNW) {
            return VectorField.VectorAlgorithm.HNSW;
        }
        return VectorField.VectorAlgorithm.FLAT;
    }

    /**
     * Builds the JSON path of a top-level field.
     *
     * @param str field name
     * @return {@code "$."} followed by {@code str}
     */
    private String jsonPath(String str) {
        return JSON_PATH_PREFIX + str;
    }

    /**
     * Creates the observation context builder for an operation on this store.
     *
     * <p>The context is tagged with the Redis provider value, the index name as the
     * collection name, the embedding model's dimensions, the embedding field name
     * and the COSINE similarity metric.
     *
     * @param str name of the operation being observed
     * @return the populated context builder
     */
    public VectorStoreObservationContext.Builder createObservationContextBuilder(String str) {
        String provider = VectorStoreProvider.REDIS.value();
        VectorStoreObservationContext.Builder contextBuilder = VectorStoreObservationContext.builder(provider, str);
        contextBuilder = contextBuilder.collectionName(this.indexName);
        contextBuilder = contextBuilder.dimensions(Integer.valueOf(this.embeddingModel.dimensions()));
        contextBuilder = contextBuilder.fieldName(this.embeddingFieldName);
        contextBuilder = contextBuilder.similarityMetric(VectorStoreSimilarityMetric.COSINE.value());
        return contextBuilder;
    }

    /**
     * Returns the underlying Redis client, typed as the caller requests.
     *
     * <p>The cast to {@code T} is unchecked. A caller that asks for a type other
     * than {@link JedisPooled} will see a {@link ClassCastException} where it uses
     * the value, not here.
     *
     * @param <T> type the caller expects the client to have
     * @return an {@link Optional} holding the {@link JedisPooled} instance
     */
    public <T> Optional<T> getNativeClient() {
        // Optional.of(this.jedis) is an Optional<JedisPooled>, which is not
        // assignable to Optional<T>; an explicit unchecked cast is required.
        @SuppressWarnings("unchecked")
        T client = (T) this.jedis;
        return Optional.of(client);
    }

    /**
     * Entry point for configuring a {@link RedisVectorStore}.
     *
     * @param jedisPooled Redis client the store will use
     * @param embeddingModel model used to embed documents and queries
     * @return a new builder holding the default settings
     */
    public static Builder builder(JedisPooled jedisPooled, EmbeddingModel embeddingModel) {
        final Builder storeBuilder = new Builder(jedisPooled, embeddingModel);
        return storeBuilder;
    }
}
