package examples.ml.example2;

import datasets.DenseMatrixSet;
import datasets.VectorDouble;
import datastructs.RowBuilder;
import datastructs.RowType;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import maths.errorfunctions.MSEVectorFunction;
import maths.functions.NonLinearVectorPolynomial;
import maths.functions.ScalarMonomial;
import ml.regression.NonLinearRegressor;
import optimization.GDInput;
import optimization.GradientDescent;
import tech.tablesaw.api.Table;
import utils.DefaultIterativeAlgorithmController;
import utils.IterativeAlgorithmResult;
import utils.TableDataSetLoader;

/* loaded from: input_file:examples/ml/example2/Example2.class */
/**
 * Example: fits a polynomial regression model to the car-plant data set with gradient descent
 * and prints the training result and the fitted coefficients.
 *
 * <p>The {@code "Electricity Usage"} column is the regression target; the first column remaining
 * after that column is removed is used to populate the design matrix.
 */
public class Example2 {

    /** Default location of the CSV data set, relative to the working directory. */
    private static final String DEFAULT_DATA_PATH = "src/main/resources/datasets/car_plant.csv";

    /** Name of the column used as the regression target. */
    private static final String TARGET_COLUMN = "Electricity Usage";

    /**
     * Loads the data set, trains the regressor and prints the learned coefficients.
     *
     * @param strArr optional command-line arguments; if present, {@code strArr[0]} overrides the
     *     path of the CSV file to load (otherwise {@link #DEFAULT_DATA_PATH} is used)
     * @throws IOException if the data set cannot be read
     */
    public static void main(String[] strArr) throws IOException {
        String dataPath = (strArr != null && strArr.length > 0) ? strArr[0] : DEFAULT_DATA_PATH;
        Table dataSet = TableDataSetLoader.loadDataSet(new File(dataPath));

        // Capture the target values before the target column is removed from the table.
        VectorDouble labels = new VectorDouble(dataSet, TARGET_COLUMN);
        Table features = dataSet.removeColumns(new String[]{TARGET_COLUMN}).first(dataSet.rowCount());

        // Two-column design matrix created with the value 1.0 (presumably the initial fill value,
        // giving column 0 as the intercept column — TODO confirm against DenseMatrixSet);
        // column 1 receives the first feature column.
        DenseMatrixSet designMatrix = new DenseMatrixSet(
                RowType.Type.DOUBLE_VECTOR, new RowBuilder(), features.rowCount(), 2, Double.valueOf(1.0d));
        designMatrix.setColumn(1, features.doubleColumn(0));
        // NOTE(review): presumably appends a copy of column 1 for the third monomial — verify.
        designMatrix.duplicateColumn(1);

        // Three monomial terms, all with initial coefficient 0.0.
        NonLinearVectorPolynomial hypothesis = new NonLinearVectorPolynomial(
                new ScalarMonomial(0, 0.0d), new ScalarMonomial(1, 0.0d), new ScalarMonomial(2, 0.0d));
        NonLinearRegressor regressor = new NonLinearRegressor(hypothesis);

        GDInput gdInput = new GDInput();
        gdInput.showIterations = true;
        gdInput.eta = 0.001d;
        gdInput.errF = new MSEVectorFunction(hypothesis);
        // Assumed to be (max iterations, convergence tolerance) — TODO confirm.
        gdInput.iterationContorller = new DefaultIterativeAlgorithmController(10000, 1.0E-8d);

        IterativeAlgorithmResult result =
                (IterativeAlgorithmResult) regressor.train(designMatrix, labels, new GradientDescent(gdInput));
        System.out.println(result);

        double intercept = hypothesis.getCoeff(0);
        double slope1 = hypothesis.getCoeff(1);
        double slope2 = hypothesis.getCoeff(2);
        // Previously printed the PrintStream object in place of slope 1, mislabeled coefficient 1
        // as "slope 2", and never reported coefficient 2.
        System.out.println("Intercept: " + intercept + " slope 1: " + slope1 + " slope 2: " + slope2);
    }
}
