package examples.ml.example6;

import datasets.DenseMatrixSet;
import datasets.VectorDouble;
import datastructs.RowBuilder;
import datastructs.RowType;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import maths.errorfunctions.LogisticMSEVectorFunction;
import maths.functions.LinearVectorPolynomial;
import maths.functions.SigmoidFunction;
import ml.classifiers.LogisticRegressionClassifier;
import optimization.GDInput;
import optimization.GradientDescent;
import tech.tablesaw.api.Table;
import tech.tablesaw.columns.Column;
import utils.DefaultIterativeAlgorithmController;
import utils.IterativeAlgorithmResult;
import utils.Pair;
import utils.PairBuilder;
import utils.TableDataSetLoader;

/* loaded from: input_file:examples/ml/example6/Example6.class */
/**
 * Example 6: trains a logistic-regression classifier with gradient descent on the
 * reduced iris data set and classifies one sample point.
 */
public class Example6 {

    /**
     * Loads the reduced iris CSV and converts it into a feature matrix and a label vector.
     *
     * <p>Labels are taken from the {@code species} column: {@code "Iris-setosa"} maps to
     * {@code 0.0} and {@code "Iris-versicolor"} maps to {@code 1.0}. The remaining columns
     * are copied into matrix columns {@code 1..n}; matrix column {@code 0} is never
     * overwritten and keeps the {@code 1.0} passed to the {@code DenseMatrixSet}
     * constructor (presumably the fill value, acting as the intercept term — confirm
     * against {@code DenseMatrixSet}).
     *
     * @return a pair holding the feature matrix (first) and the label vector (second)
     * @throws IOException if the loader fails to read the CSV file
     * @throws IllegalArgumentException if a row carries a missing or unrecognised species label
     */
    public static Pair<DenseMatrixSet, VectorDouble> createDataSet() throws IOException, IllegalArgumentException {
        Table dataSet = TableDataSetLoader.loadDataSet(new File("src/main/resources/datasets/iris_dataset_reduced.csv"));

        // Encode the species names as numeric class indices.
        Column<?> species = dataSet.column("species");
        int labelCount = species.size();
        VectorDouble labels = new VectorDouble(labelCount);
        for (int i = 0; i < labelCount; i++) {
            String name = (String) species.get(i);
            // Constant-first equals: a missing (null) label ends up in the
            // "Unknown class" branch instead of raising a NullPointerException.
            if ("Iris-setosa".equals(name)) {
                labels.set(i, 0.0d);
            } else if ("Iris-versicolor".equals(name)) {
                labels.set(i, 1.0d);
            } else {
                throw new IllegalArgumentException("Unknown class");
            }
        }

        // Drop the label column and keep every remaining row as the feature table.
        Table features = dataSet.removeColumns(new String[]{"species"}).first(dataSet.rowCount());
        int featureCount = features.columnCount();

        // One extra column in front of the features; it is left at the constructor value 1.0.
        DenseMatrixSet matrix = new DenseMatrixSet(RowType.Type.DOUBLE_VECTOR, new RowBuilder(), features.rowCount(), featureCount + 1, Double.valueOf(1.0d));

        // Copy every feature column, shifted by one to leave matrix column 0 untouched.
        // Driven by the table width so the copy always matches the allocated matrix.
        for (int c = 0; c < featureCount; c++) {
            matrix.setColumn(c + 1, features.doubleColumn(c));
        }
        return PairBuilder.makePair(matrix, labels);
    }

    /**
     * Runs the example: builds the data set, trains the classifier, prints the training
     * result and the first four hypothesis coefficients, then predicts the class of a
     * single sample point.
     *
     * @param strArr command-line arguments (unused)
     * @throws IOException if the data set cannot be loaded
     * @throws IllegalArgumentException if the data set contains an unknown species label
     */
    public static void main(String[] strArr) throws IOException, IllegalArgumentException {
        Pair<DenseMatrixSet, VectorDouble> dataSet = createDataSet();
        System.out.println("Number of rows: " + dataSet.first.m());
        System.out.println("Number of labels: " + dataSet.second.size());

        // Hypothesis: sigmoid applied to a linear polynomial.
        SigmoidFunction hypothesis = new SigmoidFunction(new LinearVectorPolynomial(4));

        // Gradient-descent configuration.
        GDInput gdInput = new GDInput();
        gdInput.showIterations = true;
        gdInput.eta = 0.01d;
        gdInput.errF = new LogisticMSEVectorFunction(hypothesis);
        gdInput.iterationContorller = new DefaultIterativeAlgorithmController(100000, 1.0E-8d);

        LogisticRegressionClassifier classifier = new LogisticRegressionClassifier(hypothesis, new GradientDescent(gdInput));
        IterativeAlgorithmResult result = (IterativeAlgorithmResult) classifier.train(dataSet.first, dataSet.second);
        System.out.println(" ");
        System.out.println(result);

        // Report each fitted coefficient under its own label. The previous version
        // concatenated the PrintStream object into the message and dropped two values.
        double intercept = hypothesis.getCoeff(0);
        double slope1 = hypothesis.getCoeff(1);
        double slope2 = hypothesis.getCoeff(2);
        double slope3 = hypothesis.getCoeff(3);
        System.out.println("Intercept: " + intercept + " slope1: " + slope1 + " slope2: " + slope2 + " slope3: " + slope3);

        // Sample point; the leading 1.0 lines up with matrix column 0 of the training data.
        VectorDouble point = new VectorDouble(Double.valueOf(1.0d), Double.valueOf(5.7d), Double.valueOf(2.8d), Double.valueOf(4.1d), Double.valueOf(1.3d));
        System.out.println("Point " + point + " has class index " + classifier.predict(point));
    }
}
