package com.johnsnowlabs.nlp.annotators.parser.typdep;

import com.johnsnowlabs.nlp.annotators.parser.typdep.util.Utils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/johnsnowlabs/nlp/annotators/parser/typdep/LowRankTensor.class */
/**
 * Accumulates a sparse tensor as a list of (coordinates, value) entries and approximates it by a
 * sum of {@code rank} rank-one components.
 *
 * <p>Entries are only appended by {@link #add}; because the decomposition is linear in the entry
 * values, entries that share the same coordinates are effectively summed.
 */
public class LowRankTensor {
    /** Upper bound on power-iteration sweeps per component. */
    private static final int MAX_ITER = 1000;
    /** Convergence threshold on the change of sigma between sweeps, also the "nearly zero" bound. */
    private static final double EPS = 1.0E-6d;

    private static final Logger logger = LoggerFactory.getLogger("TypedDependencyParser");

    /** Number of tensor modes (length of each coordinate array). */
    private final int dim;
    /** Number of rank-one components extracted by {@link #decompose}. */
    private final int rank;
    /** Length of the factor vector for each mode. */
    private final int[] N;
    private final List<MatEntry> list = new ArrayList<>();

    /**
     * Creates an empty tensor.
     *
     * @param sizes length of the factor vector for each mode; its length defines the number of
     *     modes. The array is copied.
     * @param rank number of rank-one components to extract in {@link #decompose}
     */
    public LowRankTensor(int[] sizes, int rank) {
        this.N = sizes.clone();
        this.dim = sizes.length;
        this.rank = rank;
    }

    /**
     * Appends one non-zero entry of the tensor.
     *
     * @param x coordinates of the entry, one index per mode (presumably {@code x.length == dim}
     *     with {@code x[k] < N[k]}; not validated here)
     * @param val value of the entry
     */
    public void add(int[] x, float val) {
        list.add(new MatEntry(x, val));
    }

    /**
     * Approximates the accumulated tensor by {@code rank} rank-one components and writes their
     * factor vectors into {@code param}.
     *
     * <p>Components are extracted one at a time, each by alternating power iteration. In every
     * sweep, each mode's factor is recomputed from the tensor contracted with all other factors,
     * minus the contribution of the components already extracted. Factors of all modes but the last
     * are normalized; the norm of the last factor ("sigma") is kept in it as the component's weight
     * and drives the convergence test.
     *
     * @param param one matrix per mode. Matrix {@code k} is presumably {@code N[k]} rows by at
     *     least {@code rank} columns, with {@code param.size() == dim} (TODO confirm against
     *     callers). On return, column {@code r} of matrix {@code k} holds the mode-{@code k} factor
     *     of component {@code r}.
     */
    public void decompose(ArrayList<float[][]> param) {
        // components.get(k)[r] is the mode-k factor vector of the r-th extracted component.
        List<float[][]> components = new ArrayList<>(param.size());
        for (float[][] matrix : param) {
            components.add(new float[rank][matrix.length]);
        }

        for (int r = 0; r < rank; r++) {
            List<float[]> factors = new ArrayList<>(dim);
            for (int k = 0; k < dim; k++) {
                factors.add(Utils.getRandomUnitVector(N[k]));
            }

            double sigma = 0.0d;
            double lastSigma = Double.POSITIVE_INFINITY;
            int iter;
            for (iter = 0; iter < MAX_ITER; iter++) {
                for (int k = 0; k < dim; k++) {
                    float[] factor = factors.get(k);
                    updateFactor(k, r, factors, components);
                    if (k < dim - 1) {
                        Utils.normalize(factor);
                    } else {
                        // The last factor stays unnormalized so it carries the component weight.
                        sigma = Math.sqrt(Utils.squaredSum(factor));
                    }
                }
                if (lastSigma != Double.POSITIVE_INFINITY && Math.abs(sigma - lastSigma) < EPS) {
                    break;
                }
                lastSigma = sigma;
            }

            if (iter >= MAX_ITER) {
                // SLF4J substitutes "{}" placeholders only; printf-style specifiers were printed verbatim.
                logger.warn("Power method didn't converge. rankFirstOrderTensor={} sigma={}", r, sigma);
            }
            // NOTE(review): emitted at WARN level but only when DEBUG is enabled; kept as-is, confirm intent.
            if (Math.abs(sigma) <= EPS && logger.isDebugEnabled()) {
                logger.warn(String.format("Power method has nearly-zero sigma. rankFirstOrderTensor=%d%n", r));
            }
            if (logger.isDebugEnabled()) {
                logger.debug(String.format("norm: %.2f", sigma));
            }

            for (int k = 0; k < dim; k++) {
                components.get(k)[r] = factors.get(k);
            }
        }

        // Transpose each rank-by-n component matrix into the caller's n-by-rank matrix.
        for (int k = 0; k < param.size(); k++) {
            float[][] target = param.get(k);
            float[][] source = components.get(k);
            for (int row = 0; row < target.length; row++) {
                for (int r = 0; r < rank; r++) {
                    target[row][r] = source[r][row];
                }
            }
        }
    }

    /**
     * Recomputes the factor of {@code mode} in place for component {@code component}: the tensor
     * contracted with the current factors of every other mode, minus the same contraction of each
     * previously extracted component.
     */
    private void updateFactor(int mode, int component, List<float[]> factors, List<float[][]> components) {
        float[] factor = factors.get(mode);
        Arrays.fill(factor, 0, N[mode], 0.0f);

        for (MatEntry entry : list) {
            double value = entry.val;
            for (int l = 0; l < dim; l++) {
                if (l != mode) {
                    value *= factors.get(l)[entry.x[l]];
                }
            }
            factor[entry.x[mode]] += value;
        }

        // Deflation: remove what earlier components already explain along this mode.
        for (int j = 0; j < component; j++) {
            double overlap = 1.0d;
            for (int l = 0; l < dim; l++) {
                if (l != mode) {
                    overlap *= Utils.dot(factors.get(l), components.get(l)[j]);
                }
            }
            float[] previous = components.get(mode)[j];
            for (int p = 0; p < N[mode]; p++) {
                factor[p] -= overlap * previous[p];
            }
        }
    }
}
