package examples.ml.example4;

import datasets.DenseMatrixSet;
import datasets.VectorDouble;
import datastructs.RowBuilder;
import datastructs.RowType;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import maths.functions.distances.EuclideanVectorCalculator;
import ml.classifiers.ThreadedKNNClassifier;
import ml.classifiers.utils.ClassificationVoter;
import parallel.partitioners.MatrixRowPartitionPolicy;
import parallel.partitioners.RangePartitioner;
import tech.tablesaw.api.Table;
import tech.tablesaw.columns.Column;
import utils.Pair;
import utils.PairBuilder;
import utils.TableDataSetLoader;

/* loaded from: input_file:examples/ml/example4/Example4.class */
/**
 * Example: classify a single point with a threaded k-nearest-neighbours
 * classifier trained on the Iris data set.
 *
 * <p>NOTE(review): {@code DenseMatrixSet}, {@code Column} and
 * {@code ThreadedKNNClassifier} are used as raw types below because their
 * generic signatures are declared outside this file; tighten them once the
 * declarations have been checked.
 */
public class Example4 {

    /** Location of the Iris CSV file, relative to the working directory. */
    private static final String DATA_SET_PATH = "src/main/resources/datasets/iris_data.csv";

    /** Name of the CSV column that holds the class label. */
    private static final String LABEL_COLUMN = "species";

    /**
     * Third argument handed to {@code RangePartitioner.partition}.
     * Presumably the number of row partitions -- TODO confirm against RangePartitioner.
     */
    private static final int PARTITION_COUNT = 4;

    /** Number of threads in the fixed pool handed to the classifier. */
    private static final int THREAD_COUNT = 4;

    /**
     * Loads the Iris CSV file and splits it into a feature matrix and a list of
     * integer class indices (one per row of the {@code species} column).
     *
     * <p>The {@code species} column is removed before the remaining columns are
     * used to initialise the matrix, and a row partition policy built from
     * {@code RangePartitioner.partition(0, rows, PARTITION_COUNT)} is attached
     * to the matrix.
     *
     * @return a pair whose first element is the feature matrix and whose second
     *         element is the list of class indices
     * @throws IOException declared for failures while loading the CSV file
     * @throws IllegalArgumentException if a label is {@code null} or is not one
     *         of the three known Iris species names
     */
    public static Pair<DenseMatrixSet<Double>, List<Integer>> createDataSet() throws IOException, IllegalArgumentException {
        Table table = TableDataSetLoader.loadDataSet(new File(DATA_SET_PATH));

        // Translate the textual species labels into integer class indices.
        List<Integer> labels = new ArrayList<>();
        Column speciesColumn = table.column(LABEL_COLUMN);
        for (int row = 0; row < speciesColumn.size(); row++) {
            labels.add(toClassIndex((String) speciesColumn.get(row)));
        }

        // Drop the label column so that only the feature columns feed the matrix.
        Table features = table.removeColumns(LABEL_COLUMN).first(table.rowCount());

        DenseMatrixSet dataSet = new DenseMatrixSet(RowType.Type.DOUBLE_VECTOR, new RowBuilder());
        dataSet.initializeFrom(features);
        dataSet.setPartitionPolicy(
                new MatrixRowPartitionPolicy(RangePartitioner.partition(0, dataSet.m(), PARTITION_COUNT)));
        return PairBuilder.makePair(dataSet, labels);
    }

    /**
     * Maps an Iris species name to its integer class index.
     *
     * @param species the label read from the data set; may be {@code null}
     * @return 0 for "Iris-setosa", 1 for "Iris-versicolor", 2 for "Iris-virginica"
     * @throws IllegalArgumentException if {@code species} is {@code null} or unknown
     */
    private static int toClassIndex(String species) {
        // Constant-first equals: a null label is reported as an unknown class
        // instead of surfacing as a NullPointerException.
        if ("Iris-setosa".equals(species)) {
            return 0;
        }
        if ("Iris-versicolor".equals(species)) {
            return 1;
        }
        if ("Iris-virginica".equals(species)) {
            return 2;
        }
        throw new IllegalArgumentException("Unknown class");
    }

    /**
     * Trains the classifier on the Iris data set and prints the predicted class
     * index for one hard-coded point.
     *
     * @param strArr command-line arguments (unused)
     * @throws IOException if the data set cannot be loaded
     * @throws IllegalArgumentException if the data set contains an unknown label
     */
    public static void main(String[] strArr) throws IOException, IllegalArgumentException {
        Pair<DenseMatrixSet<Double>, List<Integer>> dataSet = createDataSet();

        ExecutorService executor = Executors.newFixedThreadPool(THREAD_COUNT);
        try {
            System.out.println("Number of rows: " + dataSet.first.m());
            System.out.println("Number of labels: " + dataSet.second.size());

            // NOTE(review): 3 is presumably k (neighbour count); the meaning of the
            // boolean flag is not visible from this file -- confirm against ThreadedKNNClassifier.
            ThreadedKNNClassifier classifier = new ThreadedKNNClassifier(3, false, executor);
            classifier.setDistanceCalculator(new EuclideanVectorCalculator());
            classifier.setMajorityVoter(new ClassificationVoter());
            classifier.train(dataSet.first, dataSet.second);

            VectorDouble point = new VectorDouble(Double.valueOf(5.9d), Double.valueOf(3.0d), Double.valueOf(5.1d), Double.valueOf(1.8d));
            System.out.println("Point " + point + " has class index " + classifier.predict(point));
        } finally {
            // Always release the pool: its worker threads are non-daemon and would
            // otherwise keep the JVM alive if training or prediction throws.
            executor.shutdown();
        }
    }
}
