package edu.umd.hooka.alignment.hmm;

import edu.umd.hooka.Alignment;
import edu.umd.hooka.AlignmentPosteriorGrid;
import edu.umd.hooka.Array2D;
import edu.umd.hooka.PhrasePair;
import edu.umd.hooka.alignment.CrossEntropyCounters;
import edu.umd.hooka.alignment.PartialCountContainer;
import edu.umd.hooka.alignment.PerplexityReporter;
import edu.umd.hooka.alignment.ZeroProbabilityException;
import edu.umd.hooka.alignment.model1.Model1;
import edu.umd.hooka.ttables.TTable;
import java.io.IOException;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reporter;

/* loaded from: input_file:edu/umd/hooka/alignment/hmm/HMM.class */
/**
 * HMM alignment model layered on top of {@link Model1}.
 *
 * <p>For each sentence pair the model builds an emission table (from the inherited
 * translation table) and a transition table (from the jump table {@link #amodel}),
 * then either accumulates expected counts with a scaled forward-backward pass
 * ({@link #baumWelch}) or extracts the single best state sequence
 * ({@link #viterbiAlign}).
 *
 * <p>State 0 is the start state; states {@code 1..l} correspond to the English
 * positions of the current sentence pair. Row 0 of the per-position tables is the
 * position before the first foreign word.
 *
 * <p>When Model 1 posteriors are supplied via {@link #setModel1Posteriors}, foreign
 * positions whose thresholded posterior alignment contains a link are treated as
 * "anchored" and their transition scores are reweighted.
 *
 * <p>The instance reuses mutable scratch tables between calls, so nothing in this
 * class makes it safe for concurrent use.
 */
public class HMM extends Model1 {
    /** Key under which the jump-count table is emitted by {@link #writePartialCounts}. */
    public static final IntWritable ACOUNT_VOC_ID = new IntWritable(999999);
    static final int MAX_LENGTH = 500;
    /** Threshold passed to {@code alignPosteriorThreshold} when deriving anchors. */
    static final float THRESH = 0.5f;
    /** Constructor argument used for every scratch table (MAX_LENGTH squared = 250000). */
    private static final int SCRATCH_SIZE = MAX_LENGTH * MAX_LENGTH;

    /** emission(j, s): translation-table value for foreign position j and state s. */
    Array2D emission = new Array2D(SCRATCH_SIZE);
    /** e_coords(j, s): first argument later passed to addTranslationCount; -1 where unset. */
    IntArray2D e_coords = new IntArray2D(SCRATCH_SIZE);
    /** e_words(j, s): zero-based English index for state s; -1 where unset. */
    IntArray2D e_words = new IntArray2D(SCRATCH_SIZE);
    /** transition(prev, next): jump-table value for moving from state prev to state next. */
    Array2D transition = new Array2D(SCRATCH_SIZE);
    /** transition_coords(prev, next): jump-table coordinate for that move; -1 where unset. */
    IntArray2D transition_coords = new IntArray2D(SCRATCH_SIZE);
    Array2D alphas = new Array2D(SCRATCH_SIZE);
    Array2D betas = new Array2D(SCRATCH_SIZE);
    Array2D viterbi = new Array2D(SCRATCH_SIZE);
    IntArray2D backtrace = new IntArray2D(SCRATCH_SIZE);
    /** Jump model the transition table is read from. */
    ATable amodel;
    /** Expected jump counts accumulated by {@link #baumWelch}. */
    ATable acounts;
    /** English length of the sentence pair last passed to {@link #buildHMMTables}; -1 before that. */
    int l = -1;
    /** Foreign length of the sentence pair last passed to {@link #buildHMMTables}; -1 before that. */
    int m = -1;
    /** Optional Model 1 posteriors used to anchor transitions; {@code null} disables anchoring. */
    AlignmentPosteriorGrid m1_post = null;

    /**
     * Sets (or clears, with {@code null}) the Model 1 posterior grid used for anchoring.
     *
     * @param alignmentPosteriorGrid posteriors for the sentence pair processed next, or {@code null}
     */
    public void setModel1Posteriors(AlignmentPosteriorGrid alignmentPosteriorGrid) {
        this.m1_post = alignmentPosteriorGrid;
    }

    /**
     * Creates a model over the given translation and jump tables. The count table is a
     * cleared clone of {@code aTable}.
     *
     * <p>NOTE: the decompiler reported that this constructor's access modifier was changed
     * from {@code protected}; it is kept {@code public} so existing callers keep compiling.
     *
     * @param tTable translation table handed to {@link Model1}
     * @param aTable jump table; must not be {@code null}
     * @param z      flag forwarded unchanged to the {@link Model1} constructor
     */
    public HMM(TTable tTable, ATable aTable, boolean z) {
        super(tTable, z);
        this.amodel = aTable;
        this.acounts = (ATable) this.amodel.clone();
        this.acounts.clear();
    }

    /**
     * Equivalent to {@code new HMM(tTable, aTable, false)}.
     *
     * @param tTable translation table handed to {@link Model1}
     * @param aTable jump table; must not be {@code null}
     */
    public HMM(TTable tTable, ATable aTable) {
        this(tTable, aTable, false);
    }

    /**
     * Emits the superclass's partial counts, then the accumulated jump counts under
     * {@link #ACOUNT_VOC_ID}, and finally clears the jump counts.
     *
     * @param outputCollector sink for the partial-count containers
     * @throws IOException if the collector fails
     */
    @Override // edu.umd.hooka.alignment.model1.Model1Base, edu.umd.hooka.alignment.AlignmentModel
    public void writePartialCounts(OutputCollector<IntWritable, PartialCountContainer> outputCollector) throws IOException {
        super.writePartialCounts(outputCollector);
        PartialCountContainer jumpCounts = new PartialCountContainer();
        jumpCounts.setContent(this.acounts);
        outputCollector.collect(ACOUNT_VOC_ID, jumpCounts);
        this.acounts.clear();
    }

    /**
     * Fills the emission and transition tables (and their coordinate maps) for one
     * sentence pair and records its lengths in {@link #l} and {@link #m}.
     *
     * <p>Cells for state 0 and for foreign row 0 are not written here; later code reads
     * them, so this presumably relies on {@code resize} leaving unset cells at zero
     * (TODO confirm against Array2D).
     *
     * @param phrasePair the sentence pair to prepare tables for
     */
    public void buildHMMTables(PhrasePair phrasePair) {
        int[] eWords = phrasePair.getE().getWords();
        int[] fWords = phrasePair.getF().getWords();
        this.l = eWords.length;
        this.m = fWords.length;

        this.emission.resize(this.m + 1, this.l + 1);
        this.e_coords.resize(this.m + 1, this.l + 1);
        this.e_words.resize(this.m + 1, this.l + 1);
        this.e_words.fill(-1);
        this.e_coords.fill(-1);
        for (int state = 1; state <= this.l; state++) {
            int eWord = eWords[state - 1];
            for (int fPos = 1; fPos <= this.m; fPos++) {
                int fWord = fWords[fPos - 1];
                this.e_coords.set(fPos, state, state);
                this.emission.set(fPos, state, this.tmodel.get(eWord, fWord));
                this.e_words.set(fPos, state, state - 1);
            }
        }

        this.transition.resize(this.l + 1, this.l + 1);
        this.transition_coords.resize(this.l + 1, this.l + 1);
        this.transition_coords.fill(-1);
        // Transitions into state 0 are never defined: it is only a start state.
        for (int prev = 0; prev <= this.l; prev++) {
            for (int next = 1; next <= this.l; next++) {
                int jump = next - prev;
                this.transition_coords.set(prev, next, this.amodel.getCoord(jump, (char) this.l));
                this.transition.set(prev, next, this.amodel.get(jump, (char) this.l));
            }
        }
    }

    /** @return the second dimension of the current transition table (English length + 1) */
    public final int getNumStates() {
        return this.transition.getSize2();
    }

    /**
     * @param i source state
     * @param i2 destination state
     * @return the transition table entry for moving from {@code i} to {@code i2}
     */
    public final float getTransitionProb(int i, int i2) {
        return this.transition.get(i, i2);
    }

    /**
     * @param i foreign position (1-based)
     * @param i2 state
     * @return the emission table entry for that position and state
     */
    public final float getEmissionProb(int i, int i2) {
        return this.emission.get(i, i2);
    }

    /**
     * Adds this model's accumulated jump counts into {@code aTable}.
     *
     * @param aTable table to add into
     */
    public final void addPartialJumpCountsToATable(ATable aTable) {
        aTable.plusEquals(this.acounts);
    }

    /**
     * Runs one forward-backward pass over a training pair and reports its statistics.
     * Pairs with an empty side, or with a side whose length is not below
     * {@code amodel.getMaxDist() - 1}, are skipped silently.
     *
     * @param phrasePair the training sentence pair
     * @param reporter   counter sink; may be {@code null}, in which case nothing is reported
     */
    @Override // edu.umd.hooka.alignment.model1.Model1, edu.umd.hooka.alignment.AlignmentModel
    public void processTrainingInstance(PhrasePair phrasePair, Reporter reporter) {
        int eLen = phrasePair.getE().size();
        int fLen = phrasePair.getF().size();
        int limit = this.amodel.getMaxDist() - 1;
        if (eLen >= limit || fLen >= limit || eLen == 0 || fLen == 0) {
            return;
        }
        buildHMMTables(phrasePair);
        float logProb = baumWelch(phrasePair, null);
        if (reporter != null) {
            // Reporter.incrCounter takes a long; float -> long is narrowing and needs an explicit cast.
            reporter.incrCounter(CrossEntropyCounters.LOGPROB, (long) (-logProb));
            reporter.incrCounter(CrossEntropyCounters.WORDCOUNT, phrasePair.getF().size());
        }
    }

    /**
     * Scaled forward-backward pass over the tables prepared by {@link #buildHMMTables}.
     *
     * <p>If {@code alignmentPosteriorGrid} is non-null, state posteriors are added into it
     * and no counts are collected. Otherwise translation counts are added through
     * {@code addTranslationCount} and jump counts are added to {@link #acounts}; jump counts
     * are skipped at anchored foreign positions.
     *
     * @param phrasePair             the sentence pair whose tables are currently loaded
     * @param alignmentPosteriorGrid grid to receive posteriors, or {@code null} to collect counts
     * @return the sum of the natural logs of the per-position values returned by
     *         {@code normalizeColumn} (presumably the scaling factors, which would make this the
     *         sentence log-likelihood), or {@code 0.0f} if a column could not be normalized
     * @throws RuntimeException if adding a translation count fails; the original failure is the cause
     */
    public final float baumWelch(PhrasePair phrasePair, AlignmentPosteriorGrid alignmentPosteriorGrid) {
        initializeCountTableForSentencePair(phrasePair);
        final int numObs = phrasePair.getF().getWords().length + 1;
        final int numStates = getNumStates();
        final int eLen = phrasePair.getE().getWords().length;
        final float[] scale = new float[numObs];
        this.alphas.resize(numObs + 1, numStates);
        this.betas.resize(numObs + 1, numStates);
        this.alphas.set(0, 0, 1.0f);
        scale[0] = 1.0f;
        final Alignment anchors = this.m1_post != null ? this.m1_post.alignPosteriorThreshold(THRESH) : null;

        // Forward pass. The anchor weights depend only on the foreign position, so they are
        // computed once per position rather than once per state.
        for (int j = 1; j < numObs; j++) {
            final boolean anchored = anchors != null && anchors.isFAligned(j - 1);
            final float onWeight = anchored ? anchorWeight(anchors, j - 1, eLen) : 1.0f;
            final float offWeight = 1.0f - onWeight;
            for (int s = 0; s < numStates; s++) {
                final boolean onAnchor = anchored && isOnAnchor(anchors, j - 1, s, eLen);
                float sum = 0.0f;
                for (int prev = 0; prev < numStates; prev++) {
                    float trans = getTransitionProb(prev, s);
                    if (anchored) {
                        trans = onAnchor ? onWeight : trans * offWeight;
                    }
                    sum += this.alphas.get(j - 1, prev) * trans;
                }
                this.alphas.set(j, s, sum * getEmissionProb(j, s));
            }
            try {
                scale[j] = this.alphas.normalizeColumn(j);
            } catch (ZeroProbabilityException e) {
                notifyUnalignablePair(phrasePair, e.getMessage());
                return 0.0f;
            }
        }

        // Backward pass.
        for (int s = 1; s < numStates; s++) {
            this.betas.set(numObs - 1, s, 1.0f);
        }
        for (int j = numObs - 2; j >= 1; j--) {
            // NOTE(review): anchoring here is keyed on position j - 1 and on the SOURCE state s,
            // which is not the mirror image of the forward pass. Preserved as found; confirm intent.
            final boolean anchored = anchors != null && anchors.isFAligned(j - 1);
            final float onWeight = anchored ? anchorWeight(anchors, j - 1, eLen) : 1.0f;
            final float offWeight = 1.0f - onWeight;
            for (int s = 0; s < numStates; s++) {
                final boolean onAnchor = anchored && isOnAnchor(anchors, j - 1, s, eLen);
                float sum = 0.0f;
                for (int next = 0; next < numStates; next++) {
                    float trans = getTransitionProb(s, next);
                    if (anchored) {
                        trans = onAnchor ? onWeight : trans * offWeight;
                    }
                    sum += this.betas.get(j + 1, next) * trans * getEmissionProb(j + 1, next);
                }
                this.betas.set(j, s, sum / scale[j]);
            }
        }

        // State posteriors: either written to the caller's grid or added as translation counts.
        final float[] norm = new float[numObs];
        for (int j = 1; j < numObs; j++) {
            float total = 0.0f;
            for (int s = 0; s < numStates; s++) {
                total += this.betas.get(j, s) * this.alphas.get(j, s);
            }
            norm[j] = total;
            for (int s = 0; s < numStates; s++) {
                int eCoord = this.e_coords.get(j, s);
                if (eCoord == -1) {
                    continue;
                }
                float posterior = (this.betas.get(j, s) * this.alphas.get(j, s)) / total;
                if (alignmentPosteriorGrid != null) {
                    int ePos = s <= eLen ? s : 0;
                    if (s != 0) {
                        alignmentPosteriorGrid.setAlignmentPointPosterior(j - 1, ePos,
                                alignmentPosteriorGrid.getAlignmentPointPosterior(j - 1, ePos) + posterior);
                    }
                } else {
                    try {
                        addTranslationCount(eCoord, j - 1, posterior);
                    } catch (Exception e2) {
                        // Keep the original failure as the cause so its stack trace is not lost.
                        throw new RuntimeException("J=" + numObs + ", numStates=" + numStates + ": Failed to add (" + eCoord + "," + (j - 1) + ") += " + posterior + " s=" + s + " pp=" + phrasePair + "\n E:\n" + this.e_coords, e2);
                    }
                }
            }
        }

        // Expected jump counts, only when collecting counts. Anchored positions contribute nothing.
        if (alignmentPosteriorGrid == null) {
            for (int j = 1; j < numObs - 1; j++) {
                if (anchors != null && anchors.isFAligned(j - 1)) {
                    continue;
                }
                for (int prev = 0; prev < numStates; prev++) {
                    for (int next = 0; next < numStates; next++) {
                        int coord = this.transition_coords.get(prev, next);
                        if (coord == -1) {
                            continue;
                        }
                        float trans = getTransitionProb(prev, next);
                        this.acounts.add(coord, (char) eLen, ((((this.alphas.get(j, prev) * trans) * this.emission.get(j + 1, next)) / scale[j + 1]) * this.betas.get(j + 1, next)) / norm[j + 1]);
                    }
                }
            }
        }

        float logProb = 0.0f;
        for (float factor : scale) {
            logProb = (float) (logProb + Math.log(factor));
        }
        return logProb;
    }

    /**
     * Computes HMM alignment posteriors for one sentence pair.
     *
     * @param phrasePair the sentence pair
     * @return a new grid filled by {@link #baumWelch}
     */
    @Override // edu.umd.hooka.alignment.model1.Model1, edu.umd.hooka.alignment.AlignmentModel
    public AlignmentPosteriorGrid computeAlignmentPosteriors(PhrasePair phrasePair) {
        AlignmentPosteriorGrid grid = new AlignmentPosteriorGrid(phrasePair);
        buildHMMTables(phrasePair);
        baumWelch(phrasePair, grid);
        return grid;
    }

    /**
     * Finds the best-scoring state sequence in log space and converts it to an alignment.
     *
     * <p>If no state is reachable at some foreign position, every state at that position is
     * given score 0 and linked back to the best state of the previous position; positions
     * whose score on the chosen path is not negative produce no alignment link.
     *
     * @param phrasePair         the sentence pair to align
     * @param perplexityReporter receives the best final score and the foreign length; may be
     *                           {@code null}, in which case nothing is reported
     * @return the alignment read off the backtrace
     * @throws ZeroProbabilityException if the backtrace reaches state 0 or an undefined state
     * @throws RuntimeException if adding an alignment link fails; the original failure is the cause
     */
    @Override // edu.umd.hooka.alignment.model1.Model1, edu.umd.hooka.alignment.AlignmentModel
    public Alignment viterbiAlign(PhrasePair phrasePair, PerplexityReporter perplexityReporter) {
        buildHMMTables(phrasePair);
        Alignment alignment = new Alignment(phrasePair.getF().size(), phrasePair.getE().size());
        final int numObs = phrasePair.getF().size() + 1;
        final int numStates = getNumStates();
        this.viterbi.resize(numObs, numStates);
        this.backtrace.resize(numObs, numStates);
        this.viterbi.fill(Float.NEGATIVE_INFINITY);
        this.viterbi.set(0, 0, 0.0f);
        final int eLen = phrasePair.getE().getWords().length;
        final Alignment anchors = this.m1_post != null ? this.m1_post.alignPosteriorThreshold(THRESH) : null;

        for (int j = 1; j < numObs; j++) {
            boolean anyReachable = false;
            for (int s = 1; s < numStates; s++) {
                double logEmission = Math.log(this.emission.get(j, s));
                if (logEmission == Double.NEGATIVE_INFINITY) {
                    continue;
                }
                // Anchor weights do not depend on the predecessor state, so compute them once here.
                final boolean anchored = anchors != null && anchors.isFAligned(j - 1);
                final float onWeight = anchored ? anchorWeight(anchors, j - 1, eLen) : 1.0f;
                final float offWeight = 1.0f - onWeight;
                final boolean onAnchor = anchored && isOnAnchor(anchors, j - 1, s, eLen);

                float best = Float.NEGATIVE_INFINITY;
                int bestPrev = -1;
                for (int prev = 0; prev < numStates; prev++) {
                    float trans = getTransitionProb(prev, s);
                    if (anchored) {
                        trans = onAnchor ? onWeight : trans * offWeight;
                    }
                    float score = (float) (this.viterbi.get(j - 1, prev) + Math.log(trans) + logEmission);
                    if (score > best) {
                        best = score;
                        bestPrev = prev;
                    }
                }
                this.viterbi.set(j, s, best);
                this.backtrace.set(j, s, bestPrev);
                if (best != Float.NEGATIVE_INFINITY) {
                    anyReachable = true;
                }
            }
            if (!anyReachable) {
                // Dead end: carry the best previous state forward with a neutral score of 0.
                float bestPrevScore = Float.NEGATIVE_INFINITY;
                int bestPrevState = -1;
                for (int s = 1; s < numStates; s++) {
                    if (this.viterbi.get(j - 1, s) > bestPrevScore) {
                        bestPrevScore = this.viterbi.get(j - 1, s);
                        bestPrevState = s;
                    }
                }
                for (int s = 1; s < numStates; s++) {
                    this.viterbi.set(j, s, 0.0f);
                    this.backtrace.set(j, s, bestPrevState);
                }
            }
        }

        float bestFinal = Float.NEGATIVE_INFINITY;
        int state = -1;
        for (int s = 1; s < numStates; s++) {
            if (this.viterbi.get(numObs - 1, s) > bestFinal) {
                bestFinal = this.viterbi.get(numObs - 1, s);
                state = s;
            }
        }
        if (perplexityReporter != null) {
            perplexityReporter.addFactor(bestFinal, numObs - 1);
        }

        for (int j = numObs - 1; j > 0; j--) {
            if (state <= 0) {
                throw new ZeroProbabilityException("  Error f=" + j + " e=" + state + "  sentence + \n" + this.viterbi + "\n" + this.emission + "\n" + this.transition + "\n" + this.backtrace);
            }
            if (this.viterbi.get(j, state) < 0.0d) {
                try {
                    int fIndex = j - 1;
                    int eIndex = this.e_words.get(j, state);
                    if (eIndex >= 0) {
                        alignment.align(fIndex, eIndex);
                    }
                } catch (RuntimeException e) {
                    // Keep the original failure as the cause so its stack trace is not lost.
                    throw new RuntimeException("Caught " + e + "\nvit(f,e)=" + this.viterbi.get(j, state) + "  size(f,e)=" + phrasePair.getF().size() + "," + phrasePair.getE().size() + " Error f=" + j + " e=" + state + "  sentence + \n" + this.viterbi + "\n" + this.emission + "\n" + this.transition + "\n" + this.backtrace + "\n" + this.e_words, e);
                }
            }
            state = this.backtrace.get(j, state);
        }
        return alignment;
    }

    /**
     * Returns the square root of the Model 1 posterior of the anchor link at foreign index
     * {@code f}: the posterior of the last English index {@code e < eLen} for which
     * {@code anchors.aligned(f, e)} holds, or 0 if there is none.
     */
    private float anchorWeight(Alignment anchors, int f, int eLen) {
        float posterior = 0.0f;
        for (int e = 0; e < eLen; e++) {
            if (anchors.aligned(f, e)) {
                posterior = this.m1_post.getAlignmentPointPosterior(f, e + 1);
            }
        }
        return (float) Math.sqrt(posterior);
    }

    /** Tells whether {@code state} is a real English state (1..eLen) that is linked to foreign index {@code f}. */
    private static boolean isOnAnchor(Alignment anchors, int f, int state, int eLen) {
        return state <= eLen && state > 0 && anchors.aligned(f, state - 1);
    }
}
