package com.joliciel.talismane.machineLearning.perceptron;

import com.joliciel.talismane.machineLearning.ClassificationEvent;
import com.joliciel.talismane.machineLearning.ClassificationEventStream;
import com.joliciel.talismane.machineLearning.ClassificationModel;
import com.joliciel.talismane.machineLearning.ClassificationModelTrainer;
import com.joliciel.talismane.machineLearning.MachineLearningModel;
import com.joliciel.talismane.utils.LogUtils;
import com.typesafe.config.Config;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Writer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/joliciel/talismane/machineLearning/perceptron/PerceptronClassificationModelTrainer.class */
/**
 * Trains a {@link PerceptronClassificationModel} with the (optionally averaged) perceptron algorithm.
 * <p>
 * {@link #train()} makes repeated passes over {@link #eventFile}, one {@link PerceptronEvent} per line.
 * Whenever the highest-scoring outcome differs from the expected one, each of the event's feature
 * values is added to the expected outcome's weight and subtracted from the predicted outcome's weight.
 * After selected iterations the current weights are summed into {@link #totalFeatureWeights}. At the
 * end, the model weights are replaced by the average of those sums.
 * <p>
 * The event file is deleted once training finishes or fails.
 * NOTE(review): it is presumably written by {@link #prepareData}, whose body could not be decompiled.
 */
public class PerceptronClassificationModelTrainer implements ClassificationModelTrainer {
    private static final Logger LOG = LoggerFactory.getLogger(PerceptronClassificationModelTrainer.class);

    /** When averaging at intervals, every iteration up to and including this one is averaged. */
    private static final int DENSE_AVERAGING_LIMIT = 20;
    /** When averaging at intervals, perfect-square iterations up to and including this one are averaged. */
    private static final int SPARSE_AVERAGING_LIMIT = 196;

    private int iterations;
    private int cutoff;
    private double tolerance;
    private PerceptronScoring scoring;
    /** Running sum of the weights over all averaged iterations, indexed [feature][outcome]. */
    private double[][] totalFeatureWeights;
    private PerceptronModelParameters params;
    /** File holding one serialised {@link PerceptronEvent} per line; read on every iteration. */
    private File eventFile;
    private PerceptronDecisionMaker decisionMaker;
    private Map<String, List<String>> descriptors;
    private ClassificationEventStream corpusEventStream;
    /** Optional callback receiving intermediate models; null when training without an observer. */
    private PerceptronModelTrainerObserver observer;
    /** Iterations (1-based) at which {@link #observer} is sent an averaged intermediate model. */
    private List<Integer> observationPoints;
    private boolean averageAtIntervals = false;
    private Config config;

    /**
     * A single training event in index form: an outcome index plus parallel lists of feature indexes
     * and feature values.
     */
    public static final class PerceptronEvent {
        List<Integer> featureIndexes;
        List<Double> featureValues;
        int outcomeIndex;

        /**
         * Builds an event from a classification event, resolving feature and outcome indexes through
         * the given model parameters (outcome indexes are created on demand).
         */
        public PerceptronEvent(ClassificationEvent classificationEvent, PerceptronModelParameters perceptronModelParameters) {
            this.featureIndexes = new ArrayList<>();
            this.featureValues = new ArrayList<>();
            // NOTE(review): the trailing 'true' presumably allows new feature indexes to be created — confirm.
            perceptronModelParameters.prepareData(classificationEvent.getFeatureResults(), this.featureIndexes, this.featureValues, true);
            this.outcomeIndex = perceptronModelParameters.getOrCreateOutcomeIndex(classificationEvent.getClassification());
        }

        /**
         * Parses an event from the line format produced by {@link #write(Writer)}:
         * {@code outcome index1 value1 index2 value2 ...}, separated by single spaces.
         * A trailing unpaired token, if any, is ignored.
         */
        public PerceptronEvent(String str) {
            String[] tokens = str.split(" ");
            this.outcomeIndex = Integer.parseInt(tokens[0]);
            int featureCount = (tokens.length - 1) / 2;
            this.featureIndexes = new ArrayList<>(featureCount);
            this.featureValues = new ArrayList<>(featureCount);
            for (int f = 0; f < featureCount; f++) {
                int pos = 1 + 2 * f;
                this.featureIndexes.add(Integer.parseInt(tokens[pos]));
                this.featureValues.add(Double.parseDouble(tokens[pos + 1]));
            }
        }

        /**
         * Copies an event, mapping each feature index {@code i} to {@code newIndexes[i]}; features
         * whose new index is negative are dropped together with their values.
         */
        public PerceptronEvent(PerceptronEvent source, int[] newIndexes) {
            this.featureIndexes = new ArrayList<>();
            this.featureValues = new ArrayList<>();
            for (int i = 0; i < source.featureIndexes.size(); i++) {
                int newIndex = newIndexes[source.featureIndexes.get(i)];
                if (newIndex >= 0) {
                    this.featureIndexes.add(newIndex);
                    this.featureValues.add(source.featureValues.get(i));
                }
            }
            this.outcomeIndex = source.outcomeIndex;
        }

        public List<Integer> getFeatureIndexes() {
            return this.featureIndexes;
        }

        public List<Double> getFeatureValues() {
            return this.featureValues;
        }

        public int getOutcomeIndex() {
            return this.outcomeIndex;
        }

        /**
         * Writes this event as one space-separated line terminated by {@code '\n'}, then flushes the
         * writer. The format is readable by {@link #PerceptronEvent(String)}.
         */
        public void write(Writer writer) throws IOException {
            StringBuilder line = new StringBuilder().append(this.outcomeIndex);
            for (int i = 0; i < this.featureIndexes.size(); i++) {
                line.append(' ').append(this.featureIndexes.get(i))
                    .append(' ').append(this.featureValues.get(i));
            }
            line.append('\n');
            writer.write(line.toString());
            writer.flush();
        }
    }

    /** Names and value types of the perceptron training parameters. */
    public enum PerceptronModelParameter {
        Iterations(Integer.class),
        Cutoff(Integer.class),
        Tolerance(Double.class),
        AverageAtIntervals(Boolean.class);

        private final Class<?> parameterType;

        PerceptronModelParameter(Class<?> parameterType) {
            this.parameterType = parameterType;
        }

        public Class<?> getParameterType() {
            return this.parameterType;
        }
    }

    /**
     * Prepares the training data from the corpus event stream.
     * <p>
     * NOTE(review): the decompiler (JADX) could not recover this method: "Code decompiled incorrectly",
     * 894 instructions skipped. It therefore always throws {@link UnsupportedOperationException}, which
     * makes {@link #trainModel} and {@link #trainModelsWithObserver} unusable as-is. It presumably
     * writes {@link #eventFile}, allocates {@link #totalFeatureWeights} and applies the cutoff.
     * TODO: restore the body from the original source.
     */
    void prepareData(com.joliciel.talismane.machineLearning.ClassificationEventStream r11) {
        throw new UnsupportedOperationException("Method not decompiled: com.joliciel.talismane.machineLearning.perceptron.PerceptronClassificationModelTrainer.prepareData(com.joliciel.talismane.machineLearning.ClassificationEventStream):void");
    }

    /**
     * Runs up to {@link #getIterations()} perceptron passes over {@link #eventFile}, then replaces the
     * model weights in {@link #params} with their average over the averaged iterations.
     * <p>
     * Training stops early once the accuracy differs by less than {@link #getTolerance()} from each of
     * the three previous iterations' accuracies. Those previous accuracies are seeded with 0.0.
     *
     * @throws RuntimeException wrapping any {@link IOException} raised while reading the event file
     */
    void train() {
        double prevAccuracy1 = 0.0d;
        double prevAccuracy2 = 0.0d;
        double prevAccuracy3 = 0.0d;
        // Number of iterations whose weights have been added to totalFeatureWeights.
        int averagingCount = 0;
        for (int iteration = 1; iteration <= this.iterations; iteration++) {
            LOG.debug("Iteration " + iteration);
            int[] eventsAndErrors = trainOneIteration();
            int eventCount = eventsAndErrors[0];
            int errorCount = eventsAndErrors[1];

            boolean average = true;
            if (isAverageAtIntervals()) {
                average = isAveragingIteration(iteration);
                if (average) {
                    LOG.debug("Averaging at iteration: " + iteration);
                }
            }
            if (average) {
                addCurrentWeightsToTotal();
                averagingCount++;
            }

            if (this.observer != null && this.observationPoints != null && this.observationPoints.contains(iteration)) {
                // Send the observer an averaged snapshot without disturbing the live weights.
                PerceptronModelParameters snapshot = this.params.m36clone();
                writeAveragedWeights(snapshot, averagingCount);
                this.observer.onNextModel(getModel(snapshot, iteration), iteration);
            }

            // Divide as double: integer division would truncate the accuracy to 0 or 1.
            // An empty event file yields NaN, which never satisfies the tolerance check below.
            double accuracy = (double) (eventCount - errorCount) / eventCount;
            LOG.debug("Accuracy: " + accuracy);
            if (Math.abs(accuracy - prevAccuracy1) < this.tolerance
                    && Math.abs(accuracy - prevAccuracy2) < this.tolerance
                    && Math.abs(accuracy - prevAccuracy3) < this.tolerance) {
                LOG.info("Accuracy change < " + this.tolerance + " for 3 iterations: exiting after " + iteration + " iterations");
                break;
            }
            prevAccuracy3 = prevAccuracy2;
            prevAccuracy2 = prevAccuracy1;
            prevAccuracy1 = accuracy;
        }
        if (averagingCount > 0) {
            writeAveragedWeights(this.params, averagingCount);
        } else {
            // Dividing by zero would overwrite every weight with NaN.
            LOG.warn("No training iterations were run: leaving perceptron weights unaveraged");
        }
    }

    /**
     * Makes one pass over the event file, applying a perceptron update for every misclassified event.
     *
     * @return a two-element array: {@code {eventCount, errorCount}}
     * @throws RuntimeException wrapping any {@link IOException} raised while reading the event file
     */
    private int[] trainOneIteration() {
        int eventCount = 0;
        int errorCount = 0;
        try (Scanner scanner = new Scanner(new BufferedReader(new InputStreamReader(new FileInputStream(this.eventFile), "UTF-8")))) {
            while (scanner.hasNextLine()) {
                PerceptronEvent event = new PerceptronEvent(scanner.nextLine());
                eventCount++;
                double[] scores = this.decisionMaker.predict(event.getFeatureIndexes(), event.getFeatureValues());
                int predicted = argMax(scores);
                int actual = event.getOutcomeIndex();
                if (actual != predicted) {
                    double[][] weights = this.params.getFeatureWeights();
                    for (int j = 0; j < event.getFeatureIndexes().size(); j++) {
                        double[] outcomeWeights = weights[event.getFeatureIndexes().get(j)];
                        double value = event.getFeatureValues().get(j);
                        outcomeWeights[actual] += value;
                        outcomeWeights[predicted] -= value;
                    }
                    errorCount++;
                }
            }
            // Scanner swallows read errors; surface them instead of silently ending the pass early.
            IOException readError = scanner.ioException();
            if (readError != null) {
                throw readError;
            }
        } catch (IOException e) {
            LogUtils.logError(LOG, e);
            throw new RuntimeException(e);
        }
        return new int[] {eventCount, errorCount};
    }

    /** Returns the index of the largest score; on ties, the lowest such index. */
    private static int argMax(double[] scores) {
        int best = 0;
        double bestScore = scores[0];
        for (int k = 1; k < scores.length; k++) {
            if (scores[k] > bestScore) {
                bestScore = scores[k];
                best = k;
            }
        }
        return best;
    }

    /**
     * Averaging schedule used when {@link #isAverageAtIntervals()} is true. Every iteration up to 20 is
     * averaged, then only the perfect squares 25, 36, ..., 196, and nothing beyond 196.
     */
    private static boolean isAveragingIteration(int iteration) {
        if (iteration <= DENSE_AVERAGING_LIMIT) {
            return true;
        }
        if (iteration > SPARSE_AVERAGING_LIMIT) {
            return false;
        }
        int root = (int) Math.round(Math.sqrt(iteration));
        return root * root == iteration;
    }

    /** Adds the current model weights, outcome by outcome, into {@link #totalFeatureWeights}. */
    private void addCurrentWeightsToTotal() {
        double[][] current = this.params.getFeatureWeights();
        int outcomeCount = this.params.getOutcomeCount();
        for (int f = 0; f < current.length; f++) {
            double[] total = this.totalFeatureWeights[f];
            double[] weights = current[f];
            for (int o = 0; o < outcomeCount; o++) {
                total[o] += weights[o];
            }
        }
    }

    /**
     * Overwrites {@code target}'s weights with {@link #totalFeatureWeights} divided by
     * {@code averagingCount}.
     */
    private void writeAveragedWeights(PerceptronModelParameters target, int averagingCount) {
        double[][] weights = target.getFeatureWeights();
        int outcomeCount = target.getOutcomeCount();
        for (int f = 0; f < weights.length; f++) {
            double[] total = this.totalFeatureWeights[f];
            double[] averaged = weights[f];
            for (int o = 0; o < outcomeCount; o++) {
                averaged[o] = total[o] / averagingCount;
            }
        }
    }

    public int getIterations() {
        return this.iterations;
    }

    public void setIterations(int i) {
        this.iterations = i;
    }

    @Override // com.joliciel.talismane.machineLearning.ClassificationModelTrainer
    public int getCutoff() {
        return this.cutoff;
    }

    @Override // com.joliciel.talismane.machineLearning.ClassificationModelTrainer
    public void setCutoff(int i) {
        this.cutoff = i;
    }

    public double getTolerance() {
        return this.tolerance;
    }

    public void setTolerance(double d) {
        this.tolerance = d;
    }

    public boolean isAverageAtIntervals() {
        return this.averageAtIntervals;
    }

    public void setAverageAtIntervals(boolean z) {
        this.averageAtIntervals = z;
    }

    /**
     * Trains a model, handing intermediate averaged models to {@code perceptronModelTrainerObserver}
     * at each iteration listed in {@code list2}. The feature descriptors are stored under
     * {@link MachineLearningModel#FEATURE_DESCRIPTOR_KEY}.
     */
    public void trainModelsWithObserver(ClassificationEventStream classificationEventStream, List<String> list, PerceptronModelTrainerObserver perceptronModelTrainerObserver, List<Integer> list2) {
        Map<String, List<String>> descriptorMap = new HashMap<>();
        descriptorMap.put(MachineLearningModel.FEATURE_DESCRIPTOR_KEY, list);
        trainModelsWithObserver(classificationEventStream, descriptorMap, perceptronModelTrainerObserver, list2);
    }

    /**
     * Trains a model, handing intermediate averaged models to {@code perceptronModelTrainerObserver}
     * at each iteration listed in {@code list}. The event file is deleted even if training fails.
     */
    public void trainModelsWithObserver(ClassificationEventStream classificationEventStream, Map<String, List<String>> map, PerceptronModelTrainerObserver perceptronModelTrainerObserver, List<Integer> list) {
        trainAndCleanUp(classificationEventStream, map, perceptronModelTrainerObserver, list);
    }

    @Override // com.joliciel.talismane.machineLearning.ClassificationModelTrainer
    public ClassificationModel trainModel(ClassificationEventStream classificationEventStream, List<String> list) {
        Map<String, List<String>> descriptorMap = new HashMap<>();
        descriptorMap.put(MachineLearningModel.FEATURE_DESCRIPTOR_KEY, list);
        return trainModel(classificationEventStream, descriptorMap);
    }

    /**
     * Trains and returns a model without an observer. Any observer left over from a previous
     * {@link #trainModelsWithObserver} call is cleared. The event file is deleted even if training
     * fails.
     */
    @Override // com.joliciel.talismane.machineLearning.ClassificationModelTrainer
    public ClassificationModel trainModel(ClassificationEventStream classificationEventStream, Map<String, List<String>> map) {
        trainAndCleanUp(classificationEventStream, map, null, null);
        return getModel(this.params, getIterations());
    }

    /** Resets the training state, prepares data and trains, always deleting the event file afterwards. */
    private void trainAndCleanUp(ClassificationEventStream eventStream, Map<String, List<String>> descriptorMap,
            PerceptronModelTrainerObserver trainerObserver, List<Integer> points) {
        this.params = new PerceptronModelParameters();
        this.decisionMaker = new PerceptronDecisionMaker(this.params, getScoring());
        this.descriptors = descriptorMap;
        this.observer = trainerObserver;
        this.observationPoints = points;
        this.corpusEventStream = eventStream;
        try {
            prepareData(eventStream);
            train();
        } finally {
            deleteEventFile();
        }
    }

    /** Deletes {@link #eventFile} if set, logging a warning if it could not be removed. */
    private void deleteEventFile() {
        if (this.eventFile != null && !this.eventFile.delete() && this.eventFile.exists()) {
            LOG.warn("Could not delete temporary event file " + this.eventFile.getAbsolutePath());
        }
    }

    /**
     * Wraps the given parameters in a {@link PerceptronClassificationModel}. The model is annotated with
     * this trainer's settings, the number of iterations it reflects, and the corpus event stream's
     * attributes.
     *
     * @param perceptronModelParameters the weights to wrap
     * @param iterationCount the iteration the weights correspond to, recorded as the "iterations" attribute
     */
    ClassificationModel getModel(PerceptronModelParameters perceptronModelParameters, int iterationCount) {
        PerceptronClassificationModel model = new PerceptronClassificationModel(perceptronModelParameters, this.config, this.descriptors);
        model.addModelAttribute("cutoff", Integer.valueOf(getCutoff()));
        model.addModelAttribute("iterations", Integer.valueOf(iterationCount));
        model.addModelAttribute("tolerance", Double.valueOf(getTolerance()));
        model.addModelAttribute("averageAtIntervals", Boolean.valueOf(isAverageAtIntervals()));
        model.addModelAttribute("scoring", getScoring());
        model.getModelAttributes().putAll(this.corpusEventStream.getAttributes());
        return model;
    }

    /**
     * Reads the training settings from {@code talismane.machineLearning.cutoff} and
     * {@code talismane.machineLearning.perceptron.{iterations, tolerance, averageAtIntervals, scoring}}.
     */
    @Override // com.joliciel.talismane.machineLearning.ClassificationModelTrainer
    public void setParameters(Config config) {
        this.config = config;
        Config machineLearningConfig = config.getConfig("talismane.machineLearning");
        Config perceptronConfig = machineLearningConfig.getConfig("perceptron");
        setCutoff(machineLearningConfig.getInt("cutoff"));
        setIterations(perceptronConfig.getInt("iterations"));
        setTolerance(perceptronConfig.getDouble("tolerance"));
        setAverageAtIntervals(perceptronConfig.getBoolean("averageAtIntervals"));
        setScoring(PerceptronScoring.valueOf(perceptronConfig.getString("scoring")));
    }

    public PerceptronScoring getScoring() {
        return this.scoring;
    }

    public void setScoring(PerceptronScoring perceptronScoring) {
        this.scoring = perceptronScoring;
    }
}
