/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.community.rag.content.aggregator;

import dev.langchain4j.community.rag.content.aggregator.strategy.EmbeddingStrategy;
import dev.langchain4j.community.rag.content.aggregator.strategy.EmbeddingStrategyFactory;
import dev.langchain4j.community.store.embedding.MmrSelector;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.scoring.ScoringModel;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.aggregator.ContentAggregator;
import dev.langchain4j.rag.content.aggregator.ReciprocalRankFuser;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link ContentAggregator} that first merges the retrieved content lists with
 * {@link ReciprocalRankFuser} and then picks the final results with {@link MmrSelector}.
 *
 * <p>Processing order in {@link #aggregate(Map)}:
 * <ol>
 *   <li>a single {@link Query} is chosen via the configured {@code querySelector};</li>
 *   <li>the content lists of each query are fused, then the per-query results are fused again;</li>
 *   <li>an {@link EmbeddingStrategy} (manual, or created by {@link EmbeddingStrategyFactory})
 *       produces the query embedding and the scored matches;</li>
 *   <li>matches below {@code minScore} (if set) are dropped and at most {@code maxResults}
 *       items are selected by {@link MmrSelector#select}.</li>
 * </ol>
 *
 * <p>A {@link ScoringModel} can be supplied but is currently only stored; the constructor logs a
 * warning that hybrid re-ranking is not implemented.
 */
public class MmrContentAggregator implements ContentAggregator {

    private static final Logger log = LoggerFactory.getLogger(MmrContentAggregator.class);

    /** Lambda handed to {@link MmrSelector} when the caller does not provide one. */
    private static final double DEFAULT_LAMBDA = 0.7;

    /**
     * Multipliers of {@code maxResults} reported as the recommended pre-MMR candidate range.
     * Declared as {@code long} so the products cannot overflow {@code int}.
     */
    private static final long MIN_CANDIDATE_FACTOR = 5L;
    private static final long MAX_CANDIDATE_FACTOR = 10L;

    /**
     * Default selector: accepts a map with a single query and returns that query; rejects maps
     * with more than one query because the query to embed would be ambiguous.
     */
    private static final Function<Map<Query, Collection<List<Content>>>, Query> DEFAULT_QUERY_SELECTOR = queryToContents -> {
        if (queryToContents.size() > 1) {
            throw new IllegalArgumentException(String.format("The 'queryToContents' contains %s queries, making MMR ambiguous. Please provide a 'querySelector' in the constructor/builder.", queryToContents.size()));
        }
        return queryToContents.keySet().iterator().next();
    };

    private final EmbeddingModel embeddingModel;
    private final ScoringModel scoringModel;
    private final Function<Map<Query, Collection<List<Content>>>, Query> querySelector;
    private final Double minScore;
    private final Integer maxResults;
    private final double lambda;
    private final boolean forceEmbeddingGeneration;
    private final EmbeddingStrategy manualStrategy;

    /**
     * Creates an aggregator with default settings (default lambda, no score threshold,
     * unlimited results, automatic strategy selection).
     *
     * @param embeddingModel the embedding model; must not be {@code null}
     */
    public MmrContentAggregator(EmbeddingModel embeddingModel) {
        this(embeddingModel, null, DEFAULT_QUERY_SELECTOR, null, Integer.MAX_VALUE, DEFAULT_LAMBDA, false, null);
    }

    /**
     * Creates an aggregator with a custom lambda and otherwise default settings.
     *
     * @param embeddingModel the embedding model; must not be {@code null}
     * @param lambda         the lambda value passed to {@link MmrSelector#select}
     */
    public MmrContentAggregator(EmbeddingModel embeddingModel, double lambda) {
        this(embeddingModel, null, DEFAULT_QUERY_SELECTOR, null, Integer.MAX_VALUE, lambda, false, null);
    }

    /**
     * Creates an aggregator with default settings and an explicit embedding-generation flag.
     *
     * @param embeddingModel           the embedding model
     * @param forceEmbeddingGeneration flag forwarded to {@link EmbeddingStrategyFactory#createStrategy}
     */
    public MmrContentAggregator(EmbeddingModel embeddingModel, boolean forceEmbeddingGeneration) {
        this(embeddingModel, null, DEFAULT_QUERY_SELECTOR, null, Integer.MAX_VALUE, DEFAULT_LAMBDA, forceEmbeddingGeneration, null);
    }

    /**
     * Creates an aggregator with default settings that always uses the given strategy.
     *
     * @param embeddingModel the embedding model handed to the strategy
     * @param strategy       the strategy to use instead of automatic selection
     */
    public MmrContentAggregator(EmbeddingModel embeddingModel, EmbeddingStrategy strategy) {
        this(embeddingModel, null, DEFAULT_QUERY_SELECTOR, null, Integer.MAX_VALUE, DEFAULT_LAMBDA, false, strategy);
    }

    /**
     * Creates a fully configured aggregator.
     *
     * @param embeddingModel           the embedding model handed to the embedding strategy; validated as
     *                                 non-null only when neither {@code forceEmbeddingGeneration} nor
     *                                 {@code manualStrategy} is set
     * @param scoringModel             optional scoring model; currently unused apart from a logged warning
     * @param querySelector            picks the query to use from the input map; defaults to a selector
     *                                 that rejects maps with more than one query
     * @param minScore                 optional lower bound on the match score; {@code null} disables filtering
     * @param maxResults               maximum number of results; {@code null} means {@link Integer#MAX_VALUE}
     * @param lambda                   the lambda value passed to {@link MmrSelector#select}
     * @param forceEmbeddingGeneration flag forwarded to {@link EmbeddingStrategyFactory#createStrategy}
     * @param manualStrategy           optional strategy that overrides automatic selection
     */
    public MmrContentAggregator(EmbeddingModel embeddingModel, ScoringModel scoringModel, Function<Map<Query, Collection<List<Content>>>, Query> querySelector, Double minScore, Integer maxResults, double lambda, boolean forceEmbeddingGeneration, EmbeddingStrategy manualStrategy) {
        // NOTE(review): the null check is skipped when embedding generation is forced or a manual
        // strategy is given. A forced-generation strategy presumably needs the model too; confirm
        // against the EmbeddingStrategy implementations before tightening this.
        this.embeddingModel = (forceEmbeddingGeneration || manualStrategy != null)
                ? embeddingModel
                : ValidationUtils.ensureNotNull(embeddingModel, "embeddingModel");
        this.scoringModel = scoringModel;
        this.querySelector = Utils.getOrDefault(querySelector, DEFAULT_QUERY_SELECTOR);
        this.minScore = minScore;
        this.maxResults = Utils.getOrDefault(maxResults, Integer.MAX_VALUE);
        this.lambda = lambda;
        this.forceEmbeddingGeneration = forceEmbeddingGeneration;
        this.manualStrategy = manualStrategy;

        if (forceEmbeddingGeneration && manualStrategy != null) {
            log.warn("Both forceEmbeddingGeneration and manualStrategy provided. Manual strategy takes precedence.");
        }
        if (scoringModel != null) {
            log.warn("ScoringModel provided but hybrid MMR-reranking is not yet implemented. Currently using cosine similarity only. TODO: Implement hybrid approach combining re-ranking scores with MMR diversity.");
        }
        if (manualStrategy != null) {
            log.info("MMR configured with manual strategy: {}", manualStrategy.getClass().getSimpleName());
        } else if (forceEmbeddingGeneration) {
            log.info("MMR configured to force embedding generation regardless of existing embeddings");
        } else {
            log.info("MMR configured to automatically select optimal embedding strategy based on content analysis");
        }
    }

    /**
     * Returns a new builder for {@link MmrContentAggregator}.
     *
     * @return an empty builder
     */
    public static MmrContentAggregatorBuilder builder() {
        return new MmrContentAggregatorBuilder();
    }

    /**
     * Fuses all retrieved contents and selects the final results with MMR.
     *
     * @param queryToContents the content lists retrieved for each query
     * @return an empty list when the map or the fused contents are empty, otherwise the contents
     *         selected by {@link MmrSelector}
     * @throws IllegalArgumentException from the default query selector when the map holds more
     *                                  than one query
     */
    @Override
    public List<Content> aggregate(Map<Query, Collection<List<Content>>> queryToContents) {
        if (queryToContents.isEmpty()) {
            return Collections.emptyList();
        }

        Query query = querySelector.apply(queryToContents);
        Map<Query, List<Content>> queryToFusedContents = fuse(queryToContents);
        List<Content> fusedContents = ReciprocalRankFuser.fuse(queryToFusedContents.values());
        if (fusedContents.isEmpty()) {
            return fusedContents;
        }

        if (maxResults < Integer.MAX_VALUE) {
            // long arithmetic: 5 * maxResults overflows int once maxResults exceeds Integer.MAX_VALUE / 5,
            // which would suppress the warning or log nonsensical bounds.
            long recommendedMin = MIN_CANDIDATE_FACTOR * maxResults;
            if (fusedContents.size() < recommendedMin) {
                long recommendedMax = MAX_CANDIDATE_FACTOR * maxResults;
                log.warn("Pre-MMR candidate count is lower than expected: {} items (recommended: 5\u201310\u00d7 maxResults, current range: {}\u2013{})", fusedContents.size(), recommendedMin, recommendedMax);
            }
        }
        return applyMmr(fusedContents, query);
    }

    /**
     * Fuses, for every query, its collection of content lists into a single list.
     *
     * @param queryToContents the content lists retrieved for each query
     * @return a map from each query to its fused content list
     */
    private Map<Query, List<Content>> fuse(Map<Query, Collection<List<Content>>> queryToContents) {
        return queryToContents.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getKey, entry -> ReciprocalRankFuser.fuse(entry.getValue())));
    }

    /**
     * Scores the contents against the query through the embedding strategy, applies the optional
     * minimum-score filter and delegates the final selection to {@link MmrSelector}.
     *
     * @param contents the fused candidate contents
     * @param query    the query chosen by the query selector
     * @return the selected contents, at most {@code maxResults} of them
     */
    private List<Content> applyMmr(List<Content> contents, Query query) {
        EmbeddingStrategy strategy;
        if (manualStrategy != null) {
            strategy = manualStrategy;
            log.debug("Using manual strategy: {}", strategy.getClass().getSimpleName());
        } else {
            strategy = EmbeddingStrategyFactory.createStrategy(contents, forceEmbeddingGeneration);
        }

        Embedding queryEmbedding = strategy.processQueryEmbedding(query, contents, embeddingModel);
        List<EmbeddingMatch<Content>> matches = strategy.processContents(contents, queryEmbedding, embeddingModel);

        if (minScore != null) {
            // Unbox once instead of once per element inside the lambda.
            double threshold = minScore;
            matches = matches.stream()
                    .filter(match -> match.score() >= threshold)
                    .collect(Collectors.toList());
        }

        int resultsToSelect = Math.min(maxResults, matches.size());
        return MmrSelector.select(queryEmbedding, matches, resultsToSelect, lambda).stream()
                .map(EmbeddingMatch::embedded)
                .collect(Collectors.toList());
    }

    /**
     * Fluent builder for {@link MmrContentAggregator}. Unset values fall back to the defaults
     * applied in {@link #build()} and in the aggregator's constructor.
     */
    public static class MmrContentAggregatorBuilder {
        private EmbeddingModel embeddingModel;
        private ScoringModel scoringModel;
        private Function<Map<Query, Collection<List<Content>>>, Query> querySelector;
        private Double minScore;
        private Integer maxResults;
        private Double lambda;
        private Boolean forceEmbeddingGeneration;
        private EmbeddingStrategy manualStrategy;

        MmrContentAggregatorBuilder() {
        }

        /** Sets the embedding model handed to the embedding strategy. */
        public MmrContentAggregatorBuilder embeddingModel(EmbeddingModel embeddingModel) {
            this.embeddingModel = embeddingModel;
            return this;
        }

        /** Sets the optional scoring model (currently unused by the aggregator apart from a warning). */
        public MmrContentAggregatorBuilder scoringModel(ScoringModel scoringModel) {
            this.scoringModel = scoringModel;
            return this;
        }

        /** Sets the function that picks the query to use from the input map. */
        public MmrContentAggregatorBuilder querySelector(Function<Map<Query, Collection<List<Content>>>, Query> querySelector) {
            this.querySelector = querySelector;
            return this;
        }

        /** Sets the minimum match score; {@code null} disables score filtering. */
        public MmrContentAggregatorBuilder minScore(Double minScore) {
            this.minScore = minScore;
            return this;
        }

        /** Sets the maximum number of results; {@code null} means {@link Integer#MAX_VALUE}. */
        public MmrContentAggregatorBuilder maxResults(Integer maxResults) {
            this.maxResults = maxResults;
            return this;
        }

        /** Sets the lambda passed to {@link MmrSelector#select}; {@code null} means the default of 0.7. */
        public MmrContentAggregatorBuilder lambda(Double lambda) {
            this.lambda = lambda;
            return this;
        }

        /** Sets the embedding-generation flag; {@code null} means {@code false}. */
        public MmrContentAggregatorBuilder forceEmbeddingGeneration(Boolean forceEmbeddingGeneration) {
            this.forceEmbeddingGeneration = forceEmbeddingGeneration;
            return this;
        }

        /** Sets a manual embedding strategy that overrides automatic selection. */
        public MmrContentAggregatorBuilder strategy(EmbeddingStrategy strategy) {
            this.manualStrategy = strategy;
            return this;
        }

        /**
         * Builds the aggregator, substituting defaults for an unset lambda and an unset
         * embedding-generation flag.
         *
         * @return a new {@link MmrContentAggregator}
         */
        public MmrContentAggregator build() {
            boolean forceGeneration = Utils.getOrDefault(this.forceEmbeddingGeneration, false);
            double effectiveLambda = Utils.getOrDefault(this.lambda, DEFAULT_LAMBDA);
            return new MmrContentAggregator(this.embeddingModel, this.scoringModel, this.querySelector, this.minScore, this.maxResults, effectiveLambda, forceGeneration, this.manualStrategy);
        }
    }
}

