package org.grouplens.lenskit.knn.user;

import com.google.common.base.Preconditions;
import it.unimi.dsi.fastutil.longs.Long2ObjectMap;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongBidirectionalIterator;
import it.unimi.dsi.fastutil.longs.LongCollection;
import it.unimi.dsi.fastutil.longs.LongIterator;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import it.unimi.dsi.fastutil.longs.LongSet;
import java.io.Serializable;
import java.util.Collection;
import java.util.Iterator;
import java.util.PriorityQueue;
import javax.annotation.Nonnull;
import javax.inject.Inject;
import org.grouplens.lenskit.data.dao.ItemEventDAO;
import org.grouplens.lenskit.data.dao.UserEventDAO;
import org.grouplens.lenskit.data.event.Event;
import org.grouplens.lenskit.data.event.Rating;
import org.grouplens.lenskit.data.event.Ratings;
import org.grouplens.lenskit.data.history.RatingVectorUserHistorySummarizer;
import org.grouplens.lenskit.data.history.UserHistory;
import org.grouplens.lenskit.knn.NeighborhoodSize;
import org.grouplens.lenskit.transform.normalize.UserVectorNormalizer;
import org.grouplens.lenskit.vectors.MutableSparseVector;
import org.grouplens.lenskit.vectors.SparseVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Neighborhood finder that scans every user who rated at least one of the requested items.
 *
 * <p>For each candidate user it computes a similarity to the active user. It then keeps, per
 * requested item, a bounded priority queue of at most {@code neighborhoodSize} neighbors who
 * rated that item. Per-user rating vectors are cached and rebuilt when the user's rating count
 * or latest rating timestamp changes.
 *
 * <p>NOTE(review): {@link CacheEntry} is not {@link Serializable}. Serializing an instance whose
 * cache is non-empty will presumably fail. Confirm whether serialization is actually used.
 */
public class SimpleNeighborhoodFinder implements NeighborhoodFinder, Serializable {
    private static final long serialVersionUID = -6324767320394518347L;
    private static final Logger logger = LoggerFactory.getLogger(SimpleNeighborhoodFinder.class);

    private final UserEventDAO userDAO;
    private final ItemEventDAO itemDAO;
    private final int neighborhoodSize;
    private final UserSimilarity similarity;
    private final UserVectorNormalizer normalizer;
    /** Cached rating vectors keyed by user ID; guarded by {@code this} (see getUserRatingVector). */
    private final Long2ObjectMap<CacheEntry> userVectorCache =
            new Long2ObjectOpenHashMap<CacheEntry>(500);

    /**
     * Cached rating vector for one user, with the staleness markers used to validate it.
     */
    public static class CacheEntry {
        final long user;
        final SparseVector ratings;
        /** Largest rating timestamp seen when the entry was built (-1 if there were no ratings). */
        final long lastRatingTimestamp;
        /** Number of ratings the user had when the entry was built. */
        final int ratingCount;

        CacheEntry(long user, SparseVector ratings, long lastRatingTimestamp, int ratingCount) {
            this.user = user;
            this.ratings = ratings;
            this.lastRatingTimestamp = lastRatingTimestamp;
            this.ratingCount = ratingCount;
        }
    }

    /**
     * Creates a neighborhood finder.
     *
     * @param userDAO    source of per-user rating histories
     * @param itemDAO    source of the users who rated each item
     * @param nnbrs      maximum number of neighbors kept per item
     * @param similarity user-user similarity function
     * @param normalizer normalizer applied to both the active user's and candidates' vectors
     */
    @Inject
    public SimpleNeighborhoodFinder(UserEventDAO userDAO, ItemEventDAO itemDAO,
                                    @NeighborhoodSize int nnbrs,
                                    UserSimilarity similarity,
                                    UserVectorNormalizer normalizer) {
        this.userDAO = userDAO;
        this.itemDAO = itemDAO;
        this.neighborhoodSize = nnbrs;
        this.similarity = similarity;
        this.normalizer = normalizer;
    }

    /**
     * Finds, for each requested item, up to {@code neighborhoodSize} users who rated it and have a
     * finite similarity to the active user.
     *
     * @param user  the active user's history
     * @param items the items for which neighborhoods are wanted
     * @return map from item ID to its neighbors; items with no eligible neighbors are absent
     * @throws NullPointerException if {@code user} or {@code items} is null
     */
    @Override
    public Long2ObjectMap<? extends Collection<Neighbor>> findNeighbors(
            @Nonnull UserHistory<? extends Event> user, @Nonnull LongSet items) {
        Preconditions.checkNotNull(user, "user profile");
        Preconditions.checkNotNull(items, "item set");

        Long2ObjectMap<PriorityQueue<Neighbor>> heaps =
                new Long2ObjectOpenHashMap<PriorityQueue<Neighbor>>(items.size());

        final long userId = user.getUserId();
        SparseVector rawVector = RatingVectorUserHistorySummarizer.makeRatingVector(user);
        SparseVector userVector =
                normalizer.normalize(userId, rawVector, (MutableSparseVector) null).freeze();

        LongSet candidates = findRatingUsers(userId, items);
        logger.trace("Found {} candidate neighbors", Integer.valueOf(candidates.size()));

        LongIterator candIter = candidates.iterator();
        while (candIter.hasNext()) {
            final long candidate = candIter.nextLong();
            SparseVector candVector = getUserRatingVector(candidate);
            double sim = similarity.similarity(
                    userId, userVector, candidate,
                    normalizer.normalize(candidate, candVector, (MutableSparseVector) null));
            // Non-finite similarities cannot be ranked meaningfully; skip such candidates.
            if (Double.isNaN(sim) || Double.isInfinite(sim)) {
                continue;
            }

            Neighbor neighbor = new Neighbor(candidate, candVector, sim);
            LongIterator itemIter = candVector.keySet().iterator();
            while (itemIter.hasNext()) {
                long item = itemIter.nextLong();
                if (items.contains(item)) {
                    offerNeighbor(heaps, item, neighbor);
                }
            }
        }
        return heaps;
    }

    /**
     * Adds a neighbor to an item's bounded queue, evicting the queue head once it exceeds
     * {@code neighborhoodSize}.
     *
     * <p>Assumes {@code Neighbor.SIMILARITY_COMPARATOR} orders by ascending similarity, so the head
     * is the least similar neighbor. TODO confirm against {@code Neighbor}.
     */
    private void offerNeighbor(Long2ObjectMap<PriorityQueue<Neighbor>> heaps,
                               long item, Neighbor neighbor) {
        PriorityQueue<Neighbor> heap = heaps.get(item);
        if (heap == null) {
            heap = new PriorityQueue<Neighbor>(neighborhoodSize + 1, Neighbor.SIMILARITY_COMPARATOR);
            heaps.put(item, heap);
        }
        heap.add(neighbor);
        if (heap.size() > neighborhoodSize) {
            assert heap.size() == neighborhoodSize + 1;
            heap.remove();
        }
    }

    /**
     * Collects every user who rated at least one of the given items, excluding {@code user}.
     *
     * @param user  the active user, removed from the result
     * @param items the items whose raters are gathered
     * @return a fresh set of candidate user IDs
     */
    private LongSet findRatingUsers(long user, LongCollection items) {
        LongSet users = new LongOpenHashSet(100);
        LongIterator iter = items.iterator();
        while (iter.hasNext()) {
            LongSet itemUsers = itemDAO.getUsersForItem(iter.nextLong());
            // The DAO may return null for items it knows nothing about.
            if (itemUsers != null) {
                users.addAll(itemUsers);
            }
        }
        users.remove(user);
        return users;
    }

    /**
     * Returns a user's rating vector, rebuilding the cached copy if it is missing or stale.
     *
     * <p>An entry is stale if the user's rating count or latest rating timestamp differs from the
     * values recorded when it was built. The timestamp is always computed, because it is both the
     * freshness check and the value stored in a new entry. Synchronized because the cache is a
     * plain, non-concurrent map.
     *
     * @param user the user ID
     * @return the user's rating vector
     */
    private synchronized SparseVector getUserRatingVector(long user) {
        UserHistory<Rating> ratings = userDAO.getEventsForUser(user, Rating.class);

        long lastTimestamp = -1;
        for (Rating rating : ratings) {
            lastTimestamp = Math.max(lastTimestamp, rating.getTimestamp());
        }

        CacheEntry entry = userVectorCache.get(user);
        if (entry == null
                || entry.ratingCount != ratings.size()
                || entry.lastRatingTimestamp != lastTimestamp) {
            entry = new CacheEntry(user, Ratings.userRatingVector(ratings),
                                   lastTimestamp, ratings.size());
            userVectorCache.put(user, entry);
        }
        return entry.ratings;
    }
}
