package com.joliciel.talismane.machineLearning.maxent;

import com.joliciel.talismane.machineLearning.ClassificationEventStream;
import com.joliciel.talismane.machineLearning.ClassificationModel;
import com.joliciel.talismane.machineLearning.ClassificationModelTrainer;
import com.joliciel.talismane.machineLearning.MachineLearningModel;
import com.joliciel.talismane.machineLearning.maxent.custom.GISTrainer;
import com.joliciel.talismane.machineLearning.maxent.custom.TwoPassRealValueDataIndexer;
import com.typesafe.config.Config;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import opennlp.model.DataIndexer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/joliciel/talismane/machineLearning/maxent/MaxentModelTrainer.class */
/**
 * Trains maximum-entropy classification models with a GIS trainer.
 *
 * <p>Training is controlled by four settings: the number of iterations, the
 * feature cutoff, a Gaussian sigma and an additive smoothing value. They can
 * be set individually through the setters or all at once from a configuration
 * via {@link #setParameters(Config)}. When both smoothing and sigma are
 * positive, smoothing takes precedence and sigma is ignored.
 */
public class MaxentModelTrainer implements ClassificationModelTrainer {
    private static final Logger LOG = LoggerFactory.getLogger(MaxentModelTrainer.class);
    private int iterations;
    private int cutoff;
    private double sigma;
    private double smoothing;
    private Config config;

    /**
     * Trains a model whose only descriptor entry is the given feature
     * descriptor list, stored under
     * {@link MachineLearningModel#FEATURE_DESCRIPTOR_KEY}.
     *
     * @param classificationEventStream the events to train on
     * @param list the feature descriptors to record with the model
     * @return the trained model
     */
    @Override // com.joliciel.talismane.machineLearning.ClassificationModelTrainer
    public ClassificationModel trainModel(ClassificationEventStream classificationEventStream, List<String> list) {
        Map<String, List<String>> descriptors = new HashMap<>();
        descriptors.put(MachineLearningModel.FEATURE_DESCRIPTOR_KEY, list);
        return trainModel(classificationEventStream, descriptors);
    }

    /**
     * Indexes the event stream, runs the GIS trainer and wraps the result in
     * a {@link MaximumEntropyModel}.
     *
     * <p>The training settings (cutoff, iterations, sigma, smoothing) are
     * recorded as model attributes, followed by every attribute of the event
     * stream itself.
     *
     * @param classificationEventStream the events to train on
     * @param map the descriptors to record with the model
     * @return the trained model
     * @throws RuntimeException wrapping any {@link IOException} raised while
     *         indexing the events or training
     */
    @Override // com.joliciel.talismane.machineLearning.ClassificationModelTrainer
    public ClassificationModel trainModel(ClassificationEventStream classificationEventStream, Map<String, List<String>> map) {
        try {
            DataIndexer indexer = new TwoPassRealValueDataIndexer(new OpenNLPEventStream(classificationEventStream), this.cutoff);
            GISTrainer trainer = new GISTrainer(true);

            // Smoothing and Gaussian sigma are mutually exclusive here:
            // sigma is only applied when smoothing is not requested.
            if (getSmoothing() > 0.0d) {
                trainer.setSmoothing(true);
                trainer.setSmoothingObservation(getSmoothing());
            } else if (getSigma() > 0.0d) {
                trainer.setGaussianSigma(getSigma());
            }

            MaximumEntropyModel model = new MaximumEntropyModel(
                    trainer.trainModel(this.iterations, indexer, this.cutoff), this.config, map);

            // Record the settings used so the model carries its own provenance.
            model.addModelAttribute("cutoff", Integer.valueOf(getCutoff()));
            model.addModelAttribute("iterations", Integer.valueOf(getIterations()));
            model.addModelAttribute("sigma", Double.valueOf(getSigma()));
            model.addModelAttribute("smoothing", Double.valueOf(getSmoothing()));
            model.getModelAttributes().putAll(classificationEventStream.getAttributes());
            return model;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * @return the number of training iterations passed to the GIS trainer
     */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * @param i the number of training iterations to pass to the GIS trainer
     */
    public void setIterations(int i) {
        this.iterations = i;
    }

    /**
     * @return the cutoff passed to both the data indexer and the GIS trainer
     */
    @Override // com.joliciel.talismane.machineLearning.ClassificationModelTrainer
    public int getCutoff() {
        return this.cutoff;
    }

    /**
     * @param i the cutoff to pass to both the data indexer and the GIS trainer
     */
    @Override // com.joliciel.talismane.machineLearning.ClassificationModelTrainer
    public void setCutoff(int i) {
        this.cutoff = i;
    }

    /**
     * @return the Gaussian sigma; only applied when positive and smoothing is
     *         not positive
     */
    public double getSigma() {
        return this.sigma;
    }

    /**
     * @param d the Gaussian sigma; only applied when positive and smoothing is
     *        not positive
     */
    public void setSigma(double d) {
        this.sigma = d;
    }

    /**
     * @return the smoothing observation value; smoothing is enabled only when
     *         this is positive
     */
    public double getSmoothing() {
        return this.smoothing;
    }

    /**
     * @param d the smoothing observation value; smoothing is enabled only when
     *        this is positive
     */
    public void setSmoothing(double d) {
        this.smoothing = d;
    }

    /**
     * Stores the configuration and loads the training settings from it.
     *
     * <p>{@code cutoff} is read from {@code talismane.machineLearning};
     * {@code iterations}, {@code sigma} and {@code smoothing} are read from
     * its {@code maxent} sub-section.
     *
     * @param config the configuration to read from; also handed to every
     *        model subsequently trained
     */
    @Override // com.joliciel.talismane.machineLearning.ClassificationModelTrainer
    public void setParameters(Config config) {
        this.config = config;
        Config machineLearningConfig = config.getConfig("talismane.machineLearning");
        Config maxentConfig = machineLearningConfig.getConfig("maxent");
        setCutoff(machineLearningConfig.getInt("cutoff"));
        setIterations(maxentConfig.getInt("iterations"));
        setSigma(maxentConfig.getDouble("sigma"));
        // Smoothing has its own key. Reading it from "sigma" would make any
        // positive sigma enable smoothing and leave the Gaussian branch in
        // trainModel unreachable.
        setSmoothing(maxentConfig.getDouble("smoothing"));
    }
}
