package ml.classifiers;

import datastructs.I2DDataSet;
import datastructs.IVector;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import maths.functions.distances.DistanceCalculator;
import ml.classifiers.utils.ClassificationVoter;
import utils.Pair;

/* loaded from: input_file:ml/classifiers/KNNClassifier.class */
/**
 * k-nearest-neighbours classifier over the rows of a two-dimensional data set.
 *
 * <p>Typical use: construct the classifier, install a distance calculator and a voter through
 * {@link #setDistanceCalculator} and {@link #setMajorityVoter}, call {@link #train}, and then call
 * {@link #predict} once per query point.
 *
 * <p>The classifier itself only counts labels; ranking the rows by distance is delegated to the
 * voter (presumably {@code getResult(k)} yields the k closest rows -- confirm against the
 * {@code ClassificationVoter} implementation in use).
 *
 * @param <DataType> element type of the row vectors held by the data set
 * @param <DataSetType> concrete data set type
 * @param <DistanceType> concrete distance calculator type
 * @param <VoterType> concrete voter type
 */
public class KNNClassifier<DataType, DataSetType extends I2DDataSet<IVector<DataType>>, DistanceType extends DistanceCalculator, VoterType extends ClassificationVoter> {
    /** Number of neighbours requested from the voter on each prediction. */
    protected int k;

    /** When true, {@link #train} stores a copy of the data set rather than the caller's instance. */
    protected boolean copyDataset;

    /** Training rows; {@code null} until {@link #train} has been called. */
    protected DataSetType dataSet;

    /** Class label of each training row, indexed by row number. */
    protected List<Integer> labels;

    DistanceType distanceCalculator;
    VoterType majorityVoter;

    /**
     * Creates an untrained classifier.
     *
     * @param k number of neighbours that take part in the vote
     * @param copyDataset whether {@link #train} should copy the data set it is given
     */
    public KNNClassifier(int k, boolean copyDataset) {
        this.k = k;
        this.copyDataset = copyDataset;
    }

    /**
     * Returns the number of neighbours used for voting.
     *
     * @return the value of k supplied at construction
     */
    public int nNeighbors() {
        return this.k;
    }

    /**
     * Installs the object used to measure the distance between a training row and a query point.
     *
     * @param distanceCalculator the calculator to use in {@link #predict}
     */
    public void setDistanceCalculator(DistanceType distanceCalculator) {
        this.distanceCalculator = distanceCalculator;
    }

    /**
     * Installs the voter that collects (row index, distance) items during {@link #predict}.
     *
     * @param majorityVoter the voter to use in {@link #predict}
     */
    public void setMajorityVoter(VoterType majorityVoter) {
        this.majorityVoter = majorityVoter;
    }

    /**
     * Stores the training data. No model is fitted; all work happens in {@link #predict}.
     *
     * <p>The labels are stored in both modes. Previously they were only stored when
     * {@code copyDataset} was false, which left {@code labels} unset in copy mode and made every
     * subsequent prediction fail with a {@code NullPointerException}.
     *
     * @param dataSet training rows; copied via {@code copy()} when {@code copyDataset} is true
     * @param labels class label for each row of {@code dataSet}; stored by reference, not copied
     */
    @SuppressWarnings("unchecked") // copy() is assumed to return the same concrete data set type
    public void train(DataSetType dataSet, List<Integer> labels) {
        if (this.copyDataset) {
            this.dataSet = (DataSetType) dataSet.copy();
        } else {
            this.dataSet = dataSet;
        }
        this.labels = labels;
    }

    /**
     * Classifies a single point.
     *
     * <p>Every training row is handed to the voter together with its distance to {@code point};
     * the label is then chosen by {@link #getTopResult()}.
     *
     * @param point the query point, passed unchanged to the distance calculator
     * @param <PointType> type of the query point
     * @return the label that occurs most often among the neighbours returned by the voter
     * @throws IllegalStateException if the voter or the distance calculator has not been set, or
     *     if {@link #train} has not been called
     */
    public <PointType> Integer predict(PointType point) {
        if (this.majorityVoter == null) {
            throw new IllegalStateException(" Majority voter has not been set");
        }
        if (this.distanceCalculator == null) {
            throw new IllegalStateException("Distance calculator has not been set");
        }
        // Checked after the collaborator checks so their precedence and messages are unchanged.
        if (this.dataSet == null || this.labels == null) {
            throw new IllegalStateException("Classifier has not been trained");
        }
        for (int row = 0; row < this.dataSet.m(); row++) {
            this.majorityVoter.addItem(Integer.valueOf(row), this.distanceCalculator.calculate(this.dataSet.getRow(row), point));
        }
        return Integer.valueOf(getTopResult());
    }

    /**
     * Asks the voter for {@code k} results, clears the voter, and returns the most frequent label
     * among the rows the voter returned.
     *
     * <p>Only label counts are compared; distances play no part in breaking ties, so a tie is
     * resolved by the iteration order of the internal count map.
     *
     * @return the label with the highest count
     * @throws java.util.NoSuchElementException if the voter returns no results
     */
    @SuppressWarnings("unchecked") // the voter is a raw type, so its result list is unchecked
    public int getTopResult() {
        List<Pair<Integer, Double>> nearest = this.majorityVoter.getResult(this.k);
        // Reset the voter right away so the next predict() starts from an empty state.
        this.majorityVoter.clear();

        Map<Integer, Integer> votes = new HashMap<>();
        for (Pair<Integer, Double> neighbour : nearest) {
            int label = this.labels.get(neighbour.first.intValue()).intValue();
            votes.merge(label, 1, Integer::sum);
        }
        return Collections.max(votes.entrySet(), Map.Entry.comparingByValue()).getKey().intValue();
    }
}
