package examples.ml.example3;

import datasets.DenseMatrixSet;
import datasets.VectorDouble;
import datastructs.RowBuilder;
import datastructs.RowType;
import java.util.ArrayList;
import maths.functions.distances.EuclideanVectorCalculator;
import ml.classifiers.KNNClassifier;
import ml.classifiers.utils.ClassificationVoter;

/* loaded from: input_file:examples/ml/example3/Example3.class */
/**
 * Example: two-class k-nearest-neighbours classification on a small 2-D data set.
 *
 * <p>Twelve labelled (x, y) points are loaded into a {@link DenseMatrixSet}. Rows before
 * {@link #FIRST_CLASS_ONE_ROW} are labelled class {@code 0}; the remaining rows are labelled
 * class {@code 1}. A {@link KNNClassifier} configured with a {@link EuclideanVectorCalculator}
 * and a {@link ClassificationVoter} is trained on that data. It then classifies each point in
 * {@link #QUERY_POINTS}, and the predicted class index of every query is printed.
 */
public class Example3 {

    /** Training points as (x, y) pairs; entry {@code i} becomes row {@code i} of the data set. */
    private static final double[][] TRAINING_POINTS = {
        {1.0, 3.0},
        {1.5, 2.0},
        {2.0, 1.0},
        {2.5, 4.0},
        {3.0, 1.5},
        {3.5, 2.5},
        {5.0, 5.0},
        {5.5, 4.0},
        {6.0, 6.0},
        {6.5, 4.5},
        {7.0, 1.5},
        {8.0, 2.5},
    };

    /** Rows with an index below this value get label 0; all later rows get label 1. */
    private static final int FIRST_CLASS_ONE_ROW = 6;

    /** Unlabelled (x, y) points whose class is predicted and printed after training. */
    private static final double[][] QUERY_POINTS = {
        {3.1, 2.2},
        {9.1, 6.2},
    };

    /**
     * Builds the training set, trains the classifier and prints a prediction for every query.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        DenseMatrixSet dataSet = new DenseMatrixSet(RowType.Type.DOUBLE_VECTOR, new RowBuilder());
        dataSet.create(TRAINING_POINTS.length, 2);
        for (int row = 0; row < TRAINING_POINTS.length; row++) {
            double[] point = TRAINING_POINTS[row];
            dataSet.set(row, Double.valueOf(point[0]), Double.valueOf(point[1]));
        }

        // One label per training row, in row order.
        ArrayList<Integer> labels = new ArrayList<>(dataSet.m());
        for (int row = 0; row < dataSet.m(); row++) {
            labels.add(row < FIRST_CLASS_ONE_ROW ? 0 : 1);
        }

        // NOTE(review): the constructor arguments are unchanged from the original. The first is
        // presumably k (neighbours consulted per prediction). With an even k such as 2, a 1-1
        // vote tie is possible, so confirm how ClassificationVoter resolves ties. The meaning of
        // the boolean flag is not visible here; verify it against KNNClassifier.
        KNNClassifier classifier = new KNNClassifier(2, false);
        classifier.setDistanceCalculator(new EuclideanVectorCalculator());
        classifier.setMajorityVoter(new ClassificationVoter());
        classifier.train(dataSet, labels);

        // Report every prediction. The original discarded the result of the second query.
        for (double[] query : QUERY_POINTS) {
            VectorDouble point = new VectorDouble(Double.valueOf(query[0]), Double.valueOf(query[1]));
            System.out.println("Point " + point + " has class index " + classifier.predict(point));
        }
    }
}
