package org.grouplens.lenskit.knn.user;

import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.longs.Long2ObjectMap;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import java.util.Collection;
import javax.annotation.Nonnull;
import javax.inject.Inject;
import org.grouplens.lenskit.basic.AbstractItemScorer;
import org.grouplens.lenskit.data.dao.UserEventDAO;
import org.grouplens.lenskit.data.event.Event;
import org.grouplens.lenskit.data.history.History;
import org.grouplens.lenskit.data.history.RatingVectorUserHistorySummarizer;
import org.grouplens.lenskit.data.history.UserHistory;
import org.grouplens.lenskit.transform.normalize.UserVectorNormalizer;
import org.grouplens.lenskit.vectors.MutableSparseVector;
import org.grouplens.lenskit.vectors.SparseVector;
import org.grouplens.lenskit.vectors.VectorEntry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Item scorer implementing user-user collaborative filtering.
 *
 * <p>For each item to be scored, the {@link NeighborhoodFinder} supplies a collection of
 * neighboring users. The item's score is the similarity-weighted mean of those neighbors'
 * normalized ratings. The result is then mapped back onto the target user's own scale with the
 * transformation produced by the {@link UserVectorNormalizer}.
 */
public class UserUserItemScorer extends AbstractItemScorer {
    /**
     * Smallest total absolute neighbor similarity for which an item receives a score. Items
     * whose neighbors weigh less than this are unset in the output.
     */
    private static final double MINIMUM_SIMILARITY = 0.001;

    private static final Logger logger = LoggerFactory.getLogger(UserUserItemScorer.class);

    private final UserEventDAO dao;
    protected final NeighborhoodFinder neighborhoodFinder;
    protected final UserVectorNormalizer normalizer;

    /**
     * Construct a new user-user item scorer.
     *
     * @param dao    source of the event history of the user being scored
     * @param finder finder used to locate the neighboring users for each item
     * @param norm   normalizer applied to neighbor rating vectors; its transformation for the
     *               target user is reversed on the final scores
     */
    @Inject
    public UserUserItemScorer(UserEventDAO dao, NeighborhoodFinder finder,
                              UserVectorNormalizer norm) {
        this.dao = dao;
        this.neighborhoodFinder = finder;
        this.normalizer = norm;
    }

    /**
     * Normalize the rating vector of every distinct neighbor appearing in the given
     * neighborhoods.
     *
     * <p>The same user may appear in several neighborhoods. Each user's vector is normalized
     * only once, the first time that user is encountered.
     *
     * @param neighborhoods the neighborhoods whose members should be normalized
     * @return a map from neighbor user ID to that neighbor's normalized rating vector
     */
    protected Long2ObjectMap<SparseVector> normalizeNeighborRatings(
            Collection<? extends Collection<Neighbor>> neighborhoods) {
        Long2ObjectMap<SparseVector> normed = new Long2ObjectOpenHashMap<SparseVector>();
        for (Neighbor nbr : Iterables.concat(neighborhoods)) {
            if (!normed.containsKey(nbr.user)) {
                // NOTE(review): a null target presumably makes the normalizer allocate a fresh
                // vector rather than writing in place; confirm against UserVectorNormalizer.
                normed.put(nbr.user,
                           normalizer.normalize(nbr.user, nbr.vector, (MutableSparseVector) null));
            }
        }
        return normed;
    }

    /**
     * Score items for a user.
     *
     * <p>Every entry of {@code scores} is visited, whether currently set or not
     * ({@link VectorEntry.State#EITHER}). An item whose neighbors' total absolute similarity is
     * below {@link #MINIMUM_SIMILARITY} is unset. Any other item is set to the
     * similarity-weighted mean of its neighbors' normalized ratings. Finally, all scores are
     * denormalized using the transformation the normalizer builds from the user's own rating
     * vector.
     *
     * @param user   the ID of the user to score for
     * @param scores vector whose key domain names the items to score; receives the scores
     */
    public void score(long user, @Nonnull MutableSparseVector scores) {
        UserHistory<? extends Event> history = dao.getEventsForUser(user);
        if (history == null) {
            // The DAO has no history for this user; fall back to History.forUser
            // (presumably an empty history for the user).
            history = History.forUser(user);
        }
        logger.trace("Predicting for user {} with {} events", user, history.size());

        Long2ObjectMap<? extends Collection<Neighbor>> neighborhoods =
                neighborhoodFinder.findNeighbors(history, scores.keyDomain());
        Long2ObjectMap<SparseVector> normedRatings =
                normalizeNeighborRatings(neighborhoods.values());

        for (VectorEntry entry : scores.fast(VectorEntry.State.EITHER)) {
            final long item = entry.getKey();
            double weightedSum = 0;
            double totalWeight = 0;

            Collection<Neighbor> nbrs = neighborhoods.get(item);
            if (nbrs != null) {
                for (Neighbor nbr : nbrs) {
                    // The denominator uses absolute similarity, so negatively similar neighbors
                    // still count as evidence while pulling the weighted sum the other way.
                    totalWeight += Math.abs(nbr.similarity);
                    weightedSum += nbr.similarity * normedRatings.get(nbr.user).get(item);
                }
            }

            if (totalWeight >= MINIMUM_SIMILARITY) {
                logger.trace("Total neighbor weight for item {} is {}", item, totalWeight);
                scores.set(entry, weightedSum / totalWeight);
            } else {
                // Too little neighbor evidence to produce a meaningful score.
                scores.unset(entry);
            }
        }

        // Scores were computed in normalized space; map them back to the user's rating scale.
        normalizer.makeTransformation(history.getUserId(),
                                      RatingVectorUserHistorySummarizer.makeRatingVector(history))
                  .unapply(scores);
    }
}
