package com.johnsnowlabs.nlp.annotators.parser.typdep;

import com.johnsnowlabs.nlp.annotators.parser.typdep.feature.SyntacticFeatureFactory;
import com.johnsnowlabs.nlp.annotators.parser.typdep.util.FeatureVector;
import com.johnsnowlabs.nlp.annotators.parser.typdep.util.ScoreCollector;
import java.util.Arrays;

/* loaded from: input_file:com/johnsnowlabs/nlp/annotators/parser/typdep/LocalFeatureData.class */
/**
 * Per-sentence scoring state used by the typed dependency parser to label arcs.
 *
 * <p>On construction, a word-level feature vector is built for every token of the sentence and
 * projected through the model parameters ({@code projectU}, {@code projectV}, {@code projectU2},
 * {@code projectV2}, {@code projectW2}); the projections are cached for later scoring.
 * {@link #predictLabels} then assigns a label to every token for a fixed set of heads by dynamic
 * programming over the tree induced by those heads. Each label score mixes a feature-based score
 * (weight {@code gammaLabel}) with a tensor score (weight {@code 1 - gammaLabel}).
 *
 * <p>The score tables are instance fields that every prediction call overwrites, so an instance
 * should not be used for concurrent predictions.
 *
 * <p>Per the decompiler's notes, the constructor, {@link #getLabeledFeatureDifference} and
 * {@link #predictLabels} were package-private in the original source; they are kept public here so
 * existing callers keep compiling.
 */
public class LocalFeatureData {

    private final DependencyInstance dependencyInstance;
    private final DependencyPipe pipe;
    private final SyntacticFeatureFactory synFactory;
    private final Options options;
    private final Parameters parameters;
    private final int sentenceLength;
    private final int numberOfLabelTypes;
    /** Weight of the feature-based label score; the tensor score is weighted by {@code 1 - gammaL}. */
    private final float gammaL;

    // Per-token word feature vectors and their projections, indexed by token position.
    // Left package-private and non-final, as in the original.
    FeatureVector[] wordFvs;
    float[][] wpU;
    float[][] wpV;
    float[][] wpU2;
    float[][] wpV2;
    float[][] wpW2;

    /** {@code [node][label]}: best score of the subtree below {@code node} if it carries {@code label}. */
    private final float[][] scoresOrProbabilities;
    /** {@code [mod][modLabel][headLabel]}: local score of labelling {@code mod} given its head's label. */
    private final float[][][] labelScores;

    /**
     * Builds word features for every token of {@code dependencyInstance} and caches their
     * projections through the parser's current parameters.
     *
     * @param dependencyInstance the sentence to score
     * @param typedDependencyParser source of the dependency pipe, options and model parameters
     */
    public LocalFeatureData(DependencyInstance dependencyInstance, TypedDependencyParser typedDependencyParser) {
        this.dependencyInstance = dependencyInstance;
        this.pipe = typedDependencyParser.getDependencyPipe();
        this.synFactory = this.pipe.getSynFactory();
        this.options = typedDependencyParser.getOptions();
        this.parameters = typedDependencyParser.getParameters();
        this.sentenceLength = dependencyInstance.getLength();
        this.numberOfLabelTypes = this.pipe.getTypes().length;
        this.gammaL = this.options.gammaLabel;

        int firstOrderRank = this.options.rankFirstOrderTensor;
        int secondOrderRank = this.options.rankSecondOrderTensor;
        this.wordFvs = new FeatureVector[sentenceLength];
        this.wpU = new float[sentenceLength][firstOrderRank];
        this.wpV = new float[sentenceLength][firstOrderRank];
        this.wpU2 = new float[sentenceLength][secondOrderRank];
        this.wpV2 = new float[sentenceLength][secondOrderRank];
        this.wpW2 = new float[sentenceLength][secondOrderRank];
        this.scoresOrProbabilities = new float[sentenceLength][numberOfLabelTypes];
        this.labelScores = new float[sentenceLength][numberOfLabelTypes][numberOfLabelTypes];

        for (int token = 0; token < sentenceLength; token++) {
            FeatureVector features = synFactory.createWordFeatures(dependencyInstance, token);
            wordFvs[token] = features;
            // The projection arrays are allocated above and are never null here.
            parameters.projectU(features, wpU[token]);
            parameters.projectV(features, wpV[token]);
            parameters.projectU2(features, wpU2[token]);
            parameters.projectV2(features, wpV2[token]);
            parameters.projectW2(features, wpW2[token]);
        }
    }

    /**
     * Returns the label features of {@code gold} minus those of the predicted labelling.
     *
     * <p>For every token {@code mod >= 1}, the order-1 label features are added (gold with weight
     * +1, prediction with weight -1) when the labels of {@code mod} differ. The order-2 label
     * features are added the same way when the labels of {@code mod} or of its gold head differ.
     * The loop bound is this object's sentence length.
     *
     * @param gold instance whose heads and label ids serve as the reference
     * @param predictedHeads predicted head index of each token
     * @param predictedLabels predicted label id of each token
     * @return the accumulated feature difference
     */
    public FeatureVector getLabeledFeatureDifference(DependencyInstance gold, int[] predictedHeads, int[] predictedLabels) {
        FeatureVector difference = new FeatureVector();
        int[] goldHeads = gold.getHeads();
        int[] goldLabels = gold.getDependencyLabelIds();
        for (int mod = 1; mod < dependencyInstance.getLength(); mod++) {
            // NOTE(review): the order-2 check compares labels at the *gold* head index for both
            // trees; confirm this is intended when predicted and gold heads differ.
            int head = goldHeads[mod];
            boolean modLabelDiffers = goldLabels[mod] != predictedLabels[mod];
            if (modLabelDiffers) {
                difference.addEntries(getLabelFeature(goldHeads, goldLabels, mod, 1));
                difference.addEntries(getLabelFeature(predictedHeads, predictedLabels, mod, 1), -1.0f);
            }
            if (modLabelDiffers || goldLabels[head] != predictedLabels[head]) {
                difference.addEntries(getLabelFeature(goldHeads, goldLabels, mod, 2));
                difference.addEntries(getLabelFeature(predictedHeads, predictedLabels, mod, 2), -1.0f);
            }
        }
        return difference;
    }

    /** Builds the label features of the given order for token {@code mod} under heads/labels. */
    private FeatureVector getLabelFeature(int[] heads, int[] labels, int mod, int order) {
        FeatureVector features = new FeatureVector();
        synFactory.createLabelFeatures(features, dependencyInstance, heads, labels, mod, order);
        return features;
    }

    /**
     * Fills {@link #labelScores} for the given heads, runs the bottom-up DP and decodes labels.
     *
     * <p>{@code labels} is used as scratch space while scoring and holds the decoded labels of
     * every node reached from node 0 on return. {@code labels[0]} is reset to this instance's
     * label id of token 0.
     */
    private void predictLabelsDP(int[] heads, int[] labels, boolean addLoss, DependencyArcList arcList) {
        // Label 0 is only considered when the loss term is added.
        int startLabel = addLoss ? 0 : 1;
        int[] xPosTagIds = dependencyInstance.getXPosTagIds();
        boolean[][][] pruneLabel = pipe.getPruneLabel();
        int[] goldLabels = dependencyInstance.getDependencyLabelIds();

        for (int mod = 1; mod < sentenceLength; mod++) {
            int head = heads[mod];
            int direction = head > mod ? 1 : 2;
            int grandParent = heads[head];
            int headDirection = grandParent > head ? 1 : 2;

            for (int modLabel = startLabel; modLabel < numberOfLabelTypes; modLabel++) {
                if (!pruneLabel[xPosTagIds[head]][xPosTagIds[mod]][modLabel]) {
                    // Pruned label for this (head tag, mod tag) pair: never selectable.
                    Arrays.fill(labelScores[mod][modLabel], Float.NEGATIVE_INFINITY);
                    continue;
                }
                // The theta score reads labels[], so set the candidate label before scoring.
                labels[mod] = modLabel;
                float firstOrderScore = firstOrderLabelScore(heads, labels, head, mod, modLabel, direction);
                float loss = addLoss && goldLabels[mod] != modLabel ? 1.0f : 0.0f;

                for (int headLabel = startLabel; headLabel < numberOfLabelTypes; headLabel++) {
                    // Reset for every head label; stays 0 when the head has no head of its own (-1).
                    float secondOrderScore = 0.0f;
                    if (grandParent != -1) {
                        if (pruneLabel[xPosTagIds[grandParent]][xPosTagIds[head]][headLabel]) {
                            labels[head] = headLabel;
                            secondOrderScore = secondOrderLabelScore(heads, labels, grandParent, head, mod,
                                    headLabel, modLabel, headDirection, direction);
                        } else {
                            secondOrderScore = Float.NEGATIVE_INFINITY;
                        }
                    }
                    labelScores[mod][modLabel][headLabel] = firstOrderScore + secondOrderScore + loss;
                }
            }
        }

        treeDP(0, arcList, startLabel);
        labels[0] = goldLabels[0];
        computeDependencyLabels(0, arcList, labels, startLabel);
    }

    /** Mixes the feature-based and tensor scores for labelling arc {@code head -> mod} with {@code modLabel}. */
    private float firstOrderLabelScore(int[] heads, int[] labels, int head, int mod, int modLabel, int direction) {
        float score = 0.0f;
        if (gammaL > 0.0f) {
            score += gammaL * getLabelScoreTheta(heads, labels, mod, 1);
        }
        if (gammaL < 1.0f) {
            score += (1.0f - gammaL) * parameters.dotProductL(wpU[head], wpV[mod], modLabel, direction);
        }
        return score;
    }

    /** Mixes the feature-based and tensor scores for the label pair on {@code grandParent -> head -> mod}. */
    private float secondOrderLabelScore(int[] heads, int[] labels, int grandParent, int head, int mod,
                                        int headLabel, int modLabel, int headDirection, int direction) {
        float score = 0.0f;
        if (gammaL > 0.0f) {
            score += gammaL * getLabelScoreTheta(heads, labels, mod, 2);
        }
        if (gammaL < 1.0f) {
            score += (1.0f - gammaL) * parameters.dotProduct2L(wpU2[grandParent], wpV2[head], wpW2[mod],
                    headLabel, modLabel, headDirection, direction);
        }
        return score;
    }

    /** Scores the label features of the given order for {@code mod} against the label parameters. */
    private float getLabelScoreTheta(int[] heads, int[] labels, int mod, int order) {
        ScoreCollector collector = new ScoreCollector(parameters.getParamsL());
        synFactory.createLabelFeatures(collector, dependencyInstance, heads, labels, mod, order);
        return collector.getScore();
    }

    /**
     * Bottom-up pass: recursively scores the subtree of {@code node}.
     *
     * <p>For each label of {@code node}, sums over its arcs in {@code arcList} the best value of
     * child-subtree score plus the local label score. The arcs are presumably the node's
     * dependents; confirm this against {@code DependencyArcList}.
     */
    private void treeDP(int node, DependencyArcList arcList, int startLabel) {
        float[] nodeScores = scoresOrProbabilities[node];
        Arrays.fill(nodeScores, 0.0f);
        int firstArc = arcList.startIndex(node);
        int lastArc = arcList.endIndex(node);
        for (int arc = firstArc; arc < lastArc; arc++) {
            int child = arcList.get(arc);
            treeDP(child, arcList, startLabel);
            float[] childScores = scoresOrProbabilities[child];
            for (int label = startLabel; label < numberOfLabelTypes; label++) {
                float best = childScores[startLabel] + labelScores[child][startLabel][label];
                for (int childLabel = startLabel + 1; childLabel < numberOfLabelTypes; childLabel++) {
                    float candidate = childScores[childLabel] + labelScores[child][childLabel][label];
                    if (candidate > best) {
                        best = candidate;
                    }
                }
                nodeScores[label] += best;
            }
        }
    }

    /**
     * Top-down pass: given {@code labels[node]}, picks the best label for each child of
     * {@code node} and recurses.
     *
     * <p>If every candidate scores {@code -inf}, the child keeps its current entry in
     * {@code labels}.
     */
    private void computeDependencyLabels(int node, DependencyArcList arcList, int[] labels, int startLabel) {
        int nodeLabel = labels[node];
        int firstArc = arcList.startIndex(node);
        int lastArc = arcList.endIndex(node);
        for (int arc = firstArc; arc < lastArc; arc++) {
            int child = arcList.get(arc);
            int bestLabel = 0;
            float best = Float.NEGATIVE_INFINITY;
            for (int label = startLabel; label < numberOfLabelTypes; label++) {
                float candidate = scoresOrProbabilities[child][label] + labelScores[child][label][nodeLabel];
                if (candidate > best) {
                    best = candidate;
                    bestLabel = label;
                }
            }
            if (best != Float.NEGATIVE_INFINITY) {
                labels[child] = bestLabel;
            }
            computeDependencyLabels(child, arcList, labels, startLabel);
        }
    }

    /**
     * Predicts a label for every token given fixed heads, writing the result into {@code labels}.
     *
     * @param heads head index of each token; a head of -1 is treated as "no head" when scoring
     *     grandparent context
     * @param labels output array, also used as scratch space during scoring
     * @param addLoss when true, label 0 is also considered and 1 is added to the score of every
     *     label that differs from this instance's label id (presumably loss-augmented decoding
     *     during training)
     */
    public void predictLabels(int[] heads, int[] labels, boolean addLoss) {
        predictLabelsDP(heads, labels, addLoss, new DependencyArcList(heads));
    }
}
