package org.nlpub.watset.graph;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.math3.linear.DefaultRealMatrixChangingVisitor;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.jgrapht.Graph;

/**
 * Markov Clustering (MCL) of a graph whose edges are treated as undirected.
 *
 * <p>The graph is turned into a symmetric weighted adjacency matrix with ones on the diagonal. Its
 * columns are rescaled to sum to one. The algorithm then alternates two steps:
 * <ul>
 *   <li><em>expansion</em>: raising the matrix to the power {@code e};</li>
 *   <li><em>inflation</em>: raising every entry to the power {@code r}, then rescaling the
 *       columns again.</li>
 * </ul>
 * It stops after {@link #ITERATIONS} rounds, or earlier once a round leaves the matrix unchanged.
 * Clusters are read from the rows of the final matrix.
 *
 * @param <V> the type of vertices in the graph
 * @param <E> the type of edges in the graph
 */
public class MarkovClustering<V, E> implements Clustering<V> {
    /**
     * Maximum number of expansion/inflation rounds performed by {@link #fit()}.
     */
    public static final Integer ITERATIONS = 20;

    /**
     * The graph to cluster.
     */
    protected final Graph<V, E> graph;

    /**
     * Expansion parameter: the power the matrix is raised to in {@link #expand()}.
     */
    protected final int e;

    /**
     * Inflation parameter: the power every entry is raised to in {@link #inflate()}.
     */
    protected final double r;

    /**
     * Reusable visitor that raises every matrix entry to the power {@link #r}.
     */
    protected final InflateVisitor inflateVisitor = new InflateVisitor();

    /**
     * The current iterate. It is {@code null} before {@link #fit()} and for an empty graph.
     */
    protected RealMatrix matrix;

    /**
     * A {@code 1 x n} row of ones, used to compute column sums.
     */
    protected RealMatrix ones;

    /**
     * Maps each vertex to its row/column position in {@link #matrix}.
     */
    protected Map<V, Integer> index;

    /**
     * Matrix visitor that replaces each entry {@code x} with {@code x^r}.
     */
    public class InflateVisitor extends DefaultRealMatrixChangingVisitor {
        private InflateVisitor() {
        }

        @Override
        public double visit(int row, int column, double value) {
            return Math.pow(value, r);
        }
    }

    /**
     * Matrix visitor that divides each entry by the sum of its column.
     */
    public static class NormalizeVisitor extends DefaultRealMatrixChangingVisitor {
        private final RealMatrix sums;

        /**
         * Creates the visitor.
         *
         * @param sums a row matrix whose entry {@code (0, j)} is the divisor for column {@code j}
         */
        public NormalizeVisitor(RealMatrix sums) {
            this.sums = sums;
        }

        @Override
        public double visit(int row, int column, double value) {
            return value / sums.getEntry(0, column);
        }
    }

    /**
     * Creates a factory that builds a Markov Clustering instance for a given graph.
     *
     * @param e   the expansion parameter
     * @param r   the inflation parameter
     * @param <V> the type of vertices in the graph
     * @param <E> the type of edges in the graph
     * @return a function mapping a graph to a new, not yet fitted, clustering of it
     */
    public static <V, E> Function<Graph<V, E>, Clustering<V>> provider(int e, double r) {
        return graph -> new MarkovClustering<>(graph, e, r);
    }

    /**
     * Creates an instance of Markov Clustering for the given graph.
     *
     * @param graph the graph to cluster
     * @param e     the expansion parameter
     * @param r     the inflation parameter
     * @throws NullPointerException if {@code graph} is {@code null}
     */
    public MarkovClustering(Graph<V, E> graph, int e, double r) {
        this.graph = Objects.requireNonNull(graph);
        this.e = e;
        this.r = r;
    }

    /**
     * Runs the Markov Clustering iterations on the graph and discards any previous result.
     *
     * <p>For a graph without vertices, the method returns immediately. After that,
     * {@link #getClusters()} reports no clusters.
     */
    @Override
    public void fit() {
        index = null;
        matrix = null;
        ones = null;

        if (graph.vertexSet().isEmpty()) {
            return;
        }

        index = buildIndex();
        matrix = buildMatrix(index);

        final double[] unit = new double[matrix.getRowDimension()];
        Arrays.fill(unit, 1.0d);
        ones = MatrixUtils.createRowRealMatrix(unit);

        normalize();

        for (int round = 0; round < ITERATIONS; round++) {
            final RealMatrix previous = matrix.copy();
            expand();
            inflate();
            // Stop early once another round no longer changes the matrix.
            if (matrix.equals(previous)) {
                return;
            }
        }
    }

    /**
     * Returns the clusters found by the last call to {@link #fit()}.
     *
     * <p>Each row of the final matrix yields one cluster: the vertices whose columns hold a positive
     * value in that row. Rows without positive values are skipped, and identical clusters are
     * reported once.
     *
     * @return the clusters; an empty collection if the graph has no vertices
     * @throws NullPointerException if {@link #fit()} has not been called on a non-empty graph
     */
    @Override
    public Collection<Collection<V>> getClusters() {
        // This check must come before the null checks. fit() intentionally leaves index and matrix
        // null for an empty graph, so the null checks would otherwise reject a valid fitted state.
        if (graph.vertexSet().isEmpty()) {
            return Collections.emptySet();
        }

        Objects.requireNonNull(index, "call fit() first");
        Objects.requireNonNull(matrix, "call fit() first");

        final Map<Integer, V> vertices = index.entrySet().stream()
                .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));

        final Collection<Collection<V>> clusters = new HashSet<>();

        for (int row = 0; row < matrix.getRowDimension(); row++) {
            final Collection<V> cluster = new HashSet<>();

            for (int column = 0; column < matrix.getColumnDimension(); column++) {
                if (matrix.getEntry(row, column) > 0.0d) {
                    cluster.add(vertices.get(column));
                }
            }

            if (!cluster.isEmpty()) {
                clusters.add(cluster);
            }
        }

        return clusters;
    }

    /**
     * Assigns each vertex a distinct position in {@code 0..n-1}.
     *
     * <p>Positions follow the iteration order of the graph's vertex set.
     *
     * @return the map from vertices to their matrix row/column positions
     */
    protected Map<V, Integer> buildIndex() {
        final Map<V, Integer> positions = new HashMap<>();
        int position = 0;
        for (final V vertex : graph.vertexSet()) {
            positions.put(vertex, position++);
        }
        return positions;
    }

    /**
     * Builds the symmetric weighted adjacency matrix of the graph, starting from the identity.
     *
     * <p>Every edge between two distinct vertices writes its weight to both {@code (source, target)}
     * and {@code (target, source)}. Loop edges are skipped, so every diagonal entry stays 1. If
     * several edges join the same pair of vertices, the weight of the last one visited is kept.
     *
     * @param positions the vertex-to-position map, as produced by {@link #buildIndex()}
     * @return a new square matrix of size {@code |V| x |V|}
     */
    protected RealMatrix buildMatrix(Map<V, Integer> positions) {
        final RealMatrix adjacency = MatrixUtils.createRealIdentityMatrix(graph.vertexSet().size());

        for (final E edge : graph.edgeSet()) {
            final int source = positions.get(graph.getEdgeSource(edge));
            final int target = positions.get(graph.getEdgeTarget(edge));

            if (source != target) {
                final double weight = graph.getEdgeWeight(edge);
                adjacency.setEntry(source, target, weight);
                adjacency.setEntry(target, source, weight);
            }
        }

        return adjacency;
    }

    /**
     * Divides every entry of the current matrix by the sum of its column, in place.
     */
    protected void normalize() {
        // (1 x n) * (n x n) gives the row vector of column sums.
        final RealMatrix sums = ones.multiply(matrix);
        matrix.walkInOptimizedOrder(new NormalizeVisitor(sums));
    }

    /**
     * Performs the expansion step: replaces the matrix with its {@code e}-th power.
     */
    protected void expand() {
        matrix = matrix.power(e);
    }

    /**
     * Performs the inflation step: raises every entry to the power {@code r}, then rescales the
     * columns so that each sums to one.
     */
    protected void inflate() {
        matrix.walkInOptimizedOrder(inflateVisitor);
        // Rescale after the entry-wise power, not before it. Then both the next expansion and
        // the final result operate on a column-stochastic matrix, as the MCL inflation operator
        // requires.
        normalize();
    }
}
