package org.nlpub.watset.graph;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.jgrapht.Graph;
import org.nlpub.watset.util.Maximizer;
import org.nlpub.watset.util.Neighbors;

/**
 * Chinese Whispers graph clustering (Biemann, 2006).
 * <p>
 * Every vertex starts with its own unique integer label. On each pass the vertices are visited
 * in a freshly shuffled order and every vertex adopts the label with the highest total weight
 * among its neighbours, as computed by {@link #score(Graph, Map, NodeWeighting, Object)}.
 * Fitting stops after {@link #getIterations()} passes or as soon as a pass changes no label.
 * Vertices that end up sharing a label form one cluster.
 * <p>
 * {@link #fit()} mutates the {@code labels} and {@code steps} fields; this class performs no
 * synchronization of its own.
 *
 * @param <V> the type of vertices
 * @param <E> the type of edges
 */
public class ChineseWhispers<V, E> implements Clustering<V> {
    /**
     * Default maximum number of passes, used by the two-argument constructor.
     */
    public static final int ITERATIONS = 20;

    /** The graph being clustered. */
    protected final Graph<V, E> graph;

    /** Computes how much a neighbour contributes to a vertex's score for the neighbour's label. */
    protected final NodeWeighting<V, E> weighting;

    /** Maximum number of passes performed by {@link #fit()}. */
    protected final int iterations;

    /** Source of randomness for shuffling the visiting order and for argmax tie-breaking. */
    protected final Random random;

    /** Current label of every vertex; {@code null} until {@link #fit()} has been called. */
    protected Map<V, Integer> labels;

    /** Number of passes in the last {@link #fit()} that changed at least one label. */
    protected int steps;

    /**
     * Creates a factory that builds a clustering with the default number of iterations and a
     * fresh {@link Random}.
     *
     * @param nodeWeighting the node weighting to use
     * @param <V>           the type of vertices
     * @param <E>           the type of edges
     * @return a function mapping a graph to a new {@code ChineseWhispers} instance
     */
    public static <V, E> Function<Graph<V, E>, Clustering<V>> provider(NodeWeighting<V, E> nodeWeighting) {
        return graph -> new ChineseWhispers<V, E>(graph, nodeWeighting);
    }

    /**
     * Creates a factory that builds a clustering with the given iteration limit and randomness.
     *
     * @param nodeWeighting the node weighting to use
     * @param i             the maximum number of passes
     * @param random        the random number generator shared by every created instance
     * @param <V>           the type of vertices
     * @param <E>           the type of edges
     * @return a function mapping a graph to a new {@code ChineseWhispers} instance
     */
    public static <V, E> Function<Graph<V, E>, Clustering<V>> provider(NodeWeighting<V, E> nodeWeighting, int i, Random random) {
        return graph -> new ChineseWhispers<V, E>(graph, nodeWeighting, i, random);
    }

    /**
     * Creates a Chinese Whispers clustering.
     *
     * @param graph         the graph to cluster
     * @param nodeWeighting the node weighting to use when scoring labels
     * @param iterations    the maximum number of passes; a non-positive value performs none
     * @param random        the random number generator
     * @throws NullPointerException if {@code graph}, {@code nodeWeighting} or {@code random} is null
     */
    public ChineseWhispers(Graph<V, E> graph, NodeWeighting<V, E> nodeWeighting, int iterations, Random random) {
        this.graph = Objects.requireNonNull(graph, "graph");
        this.weighting = Objects.requireNonNull(nodeWeighting, "nodeWeighting");
        this.iterations = iterations;
        this.random = Objects.requireNonNull(random, "random");
    }

    /**
     * Creates a Chinese Whispers clustering with {@link #ITERATIONS} passes and a fresh
     * {@link Random}.
     *
     * @param graph         the graph to cluster
     * @param nodeWeighting the node weighting to use when scoring labels
     * @throws NullPointerException if {@code graph} or {@code nodeWeighting} is null
     */
    public ChineseWhispers(Graph<V, E> graph, NodeWeighting<V, E> nodeWeighting) {
        this(graph, nodeWeighting, ITERATIONS, new Random());
    }

    /**
     * Runs the label-propagation passes until convergence or the iteration limit.
     * <p>
     * After this call {@code labels} holds a label for every vertex and {@code steps} holds the
     * number of passes that changed at least one label.
     */
    @Override
    public void fit() {
        final List<V> nodes = new ArrayList<>(graph.vertexSet());

        labels = new HashMap<>(nodes.size());

        // Every vertex starts in its own singleton cluster.
        int label = 0;
        for (final V node : graph.vertexSet()) {
            labels.put(node, label++);
        }

        // A pass without any change means convergence; that pass is not counted in steps.
        for (steps = 0; steps < iterations; steps++) {
            Collections.shuffle(nodes, random);

            if (step(nodes) == 0) {
                break;
            }
        }
    }

    /**
     * Performs one pass over the given vertices, relabelling each in the given order.
     * <p>
     * Labels are updated in place, so later vertices in the pass already see the new labels
     * of earlier ones. A vertex for which no best label is found keeps its current label.
     *
     * @param nodes the vertices to visit, in visiting order
     * @return the number of vertices whose label changed
     */
    protected int step(List<V> nodes) {
        int changes = 0;

        for (final V node : nodes) {
            final Map<Integer, Double> scores = score(graph, labels, weighting, node);

            // NOTE(review): assumes argmaxRandom yields an empty Optional when there are no
            // candidates (e.g. a vertex without neighbours) and breaks ties using `random`.
            final int updated = Maximizer.argmaxRandom(scores.entrySet().iterator(), Map.Entry::getValue, random)
                    .map(Map.Entry::getKey)
                    .orElseGet(() -> labels.get(node));

            if (labels.put(node, updated) != updated) {
                changes++;
            }
        }

        return changes;
    }

    /**
     * Groups vertices by their current label.
     *
     * @return one set of vertices per distinct label
     * @throws NullPointerException if {@link #fit()} has not been called yet
     */
    @Override
    public Collection<Collection<V>> getClusters() {
        Objects.requireNonNull(labels, "call fit() first");

        final Map<Integer, List<Map.Entry<V, Integer>>> groups = labels.entrySet().stream()
                .collect(Collectors.groupingBy(Map.Entry::getValue));

        final List<Collection<V>> clusters = new ArrayList<>(groups.size());

        for (final List<Map.Entry<V, Integer>> group : groups.values()) {
            clusters.add(group.stream().map(Map.Entry::getKey).collect(Collectors.toSet()));
        }

        return clusters;
    }

    /**
     * Sums, per label, the weights contributed by the neighbours of a vertex.
     *
     * @param graph         the graph
     * @param map           the current vertex labels; every neighbour must have a label
     * @param nodeWeighting the weighting giving each neighbour's contribution
     * @param v             the vertex whose neighbourhood is scored
     * @return a map from label to the total weight of neighbours carrying that label
     * @throws NullPointerException if a neighbour of {@code v} has no label in {@code map}
     */
    protected Map<Integer, Double> score(Graph<V, E> graph, Map<V, Integer> map, NodeWeighting<V, E> nodeWeighting, V v) {
        final Map<Integer, Double> scores = new HashMap<>();

        Neighbors.neighborIterator(graph, v).forEachRemaining(neighbor -> {
            // Explicit unboxing keeps the fail-fast NPE for an unlabelled neighbour.
            final int label = map.get(neighbor);
            scores.merge(label, nodeWeighting.apply(graph, map, v, neighbor), Double::sum);
        });

        return scores;
    }

    /**
     * @return the maximum number of passes performed by {@link #fit()}
     */
    public int getIterations() {
        return iterations;
    }

    /**
     * @return the number of passes in the last {@link #fit()} that changed at least one label
     */
    public int getSteps() {
        return steps;
    }
}
