/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.community.rag.content.aggregator.strategy;

import dev.langchain4j.community.rag.content.aggregator.strategy.EmbeddingStrategy;
import dev.langchain4j.community.rag.content.util.EmbeddingMetadataUtils;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.ContentMetadata;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link EmbeddingStrategy} that reuses embeddings already present in content metadata and
 * generates only the missing ones with the supplied {@link EmbeddingModel}.
 *
 * <p><b>Query embedding:</b> the first non-null query embedding found in any content's metadata
 * is reused; otherwise the query text is embedded with the model.
 *
 * <p><b>Document embeddings:</b> contents whose metadata carries a document embedding are scored
 * directly. The remaining contents are embedded in a single {@code embedAll} batch. Every content
 * becomes an {@link EmbeddingMatch} scored by cosine similarity against the query embedding.
 */
public final class HybridEmbeddings implements EmbeddingStrategy {
    private static final Logger log = LoggerFactory.getLogger(HybridEmbeddings.class);

    /** Prefix for synthetic ids assigned to contents without a usable {@code EMBEDDING_ID}. */
    private static final String TEMP_EMBEDDING_ID_PREFIX = "mmr-content-";

    /**
     * Returns the query embedding, preferring one cached in content metadata.
     *
     * @param query          the query; its text is embedded only when no cached embedding exists
     * @param contents       contents whose segment metadata is searched, in order, for a query embedding
     * @param embeddingModel model used as a fallback to embed {@code query.text()}
     * @return the first cached query embedding, or a freshly generated one
     */
    @Override
    public Embedding processQueryEmbedding(Query query, List<Content> contents, EmbeddingModel embeddingModel) {
        Optional<Embedding> existingQueryEmbedding = contents.stream()
                .map(content -> EmbeddingMetadataUtils.extractQueryEmbedding(content.textSegment()))
                .filter(Objects::nonNull)
                .findFirst();
        if (existingQueryEmbedding.isPresent()) {
            log.debug("Using existing query embedding from content metadata");
            return existingQueryEmbedding.get();
        }
        log.debug("Generating query embedding as not found in content metadata");
        return embeddingModel.embed(query.text()).content();
    }

    /**
     * Scores every content against {@code queryEmbedding}.
     *
     * <p>Contents that already carry a document embedding are scored first, in their original
     * relative order. They are followed by the contents whose embeddings had to be generated,
     * also in original relative order.
     *
     * @param contents       contents to score
     * @param queryEmbedding embedding each content is compared against
     * @param embeddingModel model used to embed contents lacking a cached document embedding
     * @return a mutable list with one match per input content
     * @throws IllegalStateException if the model returns a different number of embeddings than
     *                               the number of segments it was asked to embed
     */
    @Override
    public List<EmbeddingMatch<Content>> processContents(
            List<Content> contents, Embedding queryEmbedding, EmbeddingModel embeddingModel) {
        Map<Boolean, List<Content>> partitioned =
                contents.stream().collect(Collectors.partitioningBy(this::hasEmbedding));
        // partitioningBy always populates both the true and false keys, so neither get() returns null.
        List<Content> withEmbeddings = partitioned.get(true);
        List<Content> withoutEmbeddings = partitioned.get(false);

        List<EmbeddingMatch<Content>> matches = new ArrayList<>(contents.size());
        if (!withEmbeddings.isEmpty()) {
            log.debug("Processing {} contents with existing embeddings", withEmbeddings.size());
            matches.addAll(processExistingEmbeddings(withEmbeddings, queryEmbedding));
        }
        if (!withoutEmbeddings.isEmpty()) {
            log.debug("Generating embeddings for {} contents", withoutEmbeddings.size());
            matches.addAll(generateMissingEmbeddings(withoutEmbeddings, queryEmbedding, embeddingModel));
        }
        log.debug(
                "Processed {} total contents ({} existing, {} generated)",
                contents.size(),
                withEmbeddings.size(),
                withoutEmbeddings.size());
        return matches;
    }

    /** Returns whether the content's segment metadata carries a document embedding. */
    private boolean hasEmbedding(Content content) {
        return EmbeddingMetadataUtils.extractDocumentEmbedding(content.textSegment()) != null;
    }

    /** Scores contents using the document embeddings already stored in their metadata. */
    private List<EmbeddingMatch<Content>> processExistingEmbeddings(List<Content> contents, Embedding queryEmbedding) {
        return contents.stream()
                .map(content -> toMatch(
                        content,
                        EmbeddingMetadataUtils.extractDocumentEmbedding(content.textSegment()),
                        queryEmbedding))
                .collect(Collectors.toList());
    }

    /**
     * Embeds all given contents in one batch call and scores them.
     *
     * @throws IllegalStateException if the model's result is missing or its size differs from
     *                               {@code contents.size()}. Results are paired with contents by
     *                               index, so a size mismatch would misalign them.
     */
    private List<EmbeddingMatch<Content>> generateMissingEmbeddings(
            List<Content> contents, Embedding queryEmbedding, EmbeddingModel embeddingModel) {
        List<Embedding> embeddings = embeddingModel
                .embedAll(contents.stream().map(Content::textSegment).collect(Collectors.toList()))
                .content();
        if (embeddings == null || embeddings.size() != contents.size()) {
            throw new IllegalStateException(String.format(
                    "Embedding model returned %s embeddings for %d text segments",
                    embeddings == null ? "no" : String.valueOf(embeddings.size()),
                    contents.size()));
        }
        List<EmbeddingMatch<Content>> matches = new ArrayList<>(contents.size());
        for (int i = 0; i < contents.size(); i++) {
            matches.add(toMatch(contents.get(i), embeddings.get(i), queryEmbedding));
        }
        return matches;
    }

    /** Builds a match for {@code content}, scored by cosine similarity to {@code queryEmbedding}. */
    private EmbeddingMatch<Content> toMatch(Content content, Embedding contentEmbedding, Embedding queryEmbedding) {
        double score = CosineSimilarity.between(contentEmbedding, queryEmbedding);
        return new EmbeddingMatch<>(score, getEmbeddingId(content), contentEmbedding, content);
    }

    /**
     * Returns the content's non-blank {@code EMBEDDING_ID} metadata value if present. Otherwise it
     * returns a synthetic id derived from the content's hash code. Contents with equal hash codes
     * receive the same synthetic id.
     */
    private String getEmbeddingId(Content content) {
        Object embeddingId = content.metadata().get(ContentMetadata.EMBEDDING_ID);
        if (embeddingId instanceof String && !((String) embeddingId).isBlank()) {
            return (String) embeddingId;
        }
        // Widen to long before abs: Math.abs(Integer.MIN_VALUE) overflows and stays negative.
        return TEMP_EMBEDDING_ID_PREFIX + Math.abs((long) content.hashCode());
    }
}

