package weka.classifiers.meta;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.RandomizableParallelIteratedSingleClassifierEnhancer;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.classifiers.trees.RandomTree;
import weka.core.BatchPredictor;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.PartitionGenerator;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/* loaded from: input_file:weka/classifiers/meta/RandomCommittee.class */
/**
 * Meta classifier that trains several copies of a {@link Randomizable} base
 * classifier on the same data, each with a different seed, and averages their
 * predictions.
 *
 * <p>The default base classifier is {@link RandomTree}. When the base
 * classifier is a {@link BatchPredictor}, batch prediction is delegated to the
 * ensemble members and spread over a thread pool; when it is a
 * {@link PartitionGenerator}, the partitions of all members are concatenated.
 */
public class RandomCommittee extends RandomizableParallelIteratedSingleClassifierEnhancer implements WeightedInstancesHandler, PartitionGenerator {
    /** Serialization version identifier. */
    static final long serialVersionUID = -9204394360557300093L;

    /** Training data served by {@link #getTrainingSet(int)}; non-null only while a build is running. */
    protected Instances m_data;

    /** Creates a committee that uses {@link RandomTree} as its base classifier. */
    public RandomCommittee() {
        this.m_Classifier = new RandomTree();
    }

    /**
     * Returns the class name of the default base classifier.
     *
     * @return the fully qualified name of {@link RandomTree}
     */
    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.RandomTree";
    }

    /**
     * Returns a short description of this classifier.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Class for building an ensemble of randomizable base classifiers. Each base classifiers is built using a different random number seed (but based one the same data). The final prediction is a straight average of the predictions generated by the individual base classifiers.";
    }

    /**
     * Builds the committee: copies the base classifier once per iteration,
     * assigns every copy its own seed and trains all copies on the same data.
     *
     * <p>If the base classifier does not handle instance weights and the
     * weights are not all identical, the data is first resampled according to
     * the weights.
     *
     * @param instances the training data; it is copied, not modified
     * @throws IllegalArgumentException if the base classifier is not {@link Randomizable}
     * @throws Exception if the capability test or the training of a member fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        this.m_data = new Instances(instances);
        try {
            super.buildClassifier(this.m_data);
            if (!(this.m_Classifier instanceof Randomizable)) {
                throw new IllegalArgumentException("Base learner must implement Randomizable!");
            }
            this.m_Classifiers = AbstractClassifier.makeCopies(this.m_Classifier, this.m_NumIterations);
            Random random = this.m_data.getRandomNumberGenerator(this.m_Seed);
            if (!(this.m_Classifier instanceof WeightedInstancesHandler) && !this.m_data.allInstanceWeightsIdentical()) {
                this.m_data = this.m_data.resampleWithWeights(random);
            }
            for (int i = 0; i < this.m_Classifiers.length; i++) {
                ((Randomizable) this.m_Classifiers[i]).setSeed(random.nextInt());
            }
            buildClassifiers();
        } finally {
            // Release the training data even when the build fails, so a broken
            // build does not keep the copied data set reachable from this object.
            this.m_data = null;
        }
    }

    /**
     * Returns the training data for one ensemble member. Every member receives
     * the same data set.
     *
     * @param i the index of the member (unused)
     * @return the data prepared by {@link #buildClassifier(Instances)}
     * @throws Exception never thrown by this implementation
     */
    @Override
    protected synchronized Instances getTrainingSet(int i) throws Exception {
        return this.m_data;
    }

    /**
     * Averages the predictions of all ensemble members for one instance.
     *
     * <p>For a numeric class the result holds the mean of all non-missing
     * member predictions, or a missing value if every member predicted
     * missing. For a nominal class the member distributions are summed and
     * normalised; an all-zero sum is returned unchanged.
     *
     * @param instance the instance to classify
     * @return the averaged prediction or class distribution
     * @throws Exception if a member fails to produce a prediction
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] sums = new double[instance.numClasses()];
        boolean numericClass = instance.classAttribute().isNumeric();
        double numPredictions = 0.0d;
        // Iterate over the classifiers that were actually built instead of
        // m_NumIterations, which may have been changed after training.
        for (int i = 0; i < this.m_Classifiers.length; i++) {
            if (numericClass) {
                double prediction = this.m_Classifiers[i].classifyInstance(instance);
                if (!Utils.isMissingValue(prediction)) {
                    sums[0] += prediction;
                    numPredictions += 1.0d;
                }
            } else {
                double[] distribution = this.m_Classifiers[i].distributionForInstance(instance);
                for (int c = 0; c < distribution.length; c++) {
                    sums[c] += distribution[c];
                }
            }
        }
        if (numericClass) {
            sums[0] = numPredictions == 0.0d ? Utils.missingValue() : sums[0] / numPredictions;
            return sums;
        }
        if (!Utils.eq(Utils.sum(sums), 0.0d)) {
            Utils.normalize(sums);
        }
        return sums;
    }

    /**
     * Returns the tip text for the batch size property.
     *
     * @return the tip text
     */
    @Override
    public String batchSizeTipText() {
        return "Batch size to use if base learner is a BatchPredictor";
    }

    /**
     * Sets the batch size, forwarding it to the base classifier when that is a
     * {@link BatchPredictor}.
     *
     * @param str the batch size
     */
    @Override
    public void setBatchSize(String str) {
        if (getClassifier() instanceof BatchPredictor) {
            ((BatchPredictor) getClassifier()).setBatchSize(str);
        } else {
            super.setBatchSize(str);
        }
    }

    /**
     * Returns the batch size, taken from the base classifier when that is a
     * {@link BatchPredictor}.
     *
     * @return the batch size
     */
    @Override
    public String getBatchSize() {
        if (getClassifier() instanceof BatchPredictor) {
            return ((BatchPredictor) getClassifier()).getBatchSize();
        }
        return super.getBatchSize();
    }

    /**
     * Averages the predictions of all ensemble members for a batch of instances.
     *
     * <p>If the base classifier is not a {@link BatchPredictor}, each instance
     * is handled by {@link #distributionForInstance(Instance)}. Otherwise the
     * members are split into contiguous ranges that are evaluated on a thread
     * pool, and the partial sums are combined in submission order.
     *
     * @param instances the instances to classify
     * @return one row per instance: a single averaged value (or missing value)
     *         for a numeric class, a normalised distribution for a nominal class
     * @throws Exception if any member fails to produce its predictions or the
     *         calling thread is interrupted while waiting
     */
    @Override
    public double[][] distributionsForInstances(final Instances instances) throws Exception {
        if (!(getClassifier() instanceof BatchPredictor)) {
            double[][] predictions = new double[instances.numInstances()][];
            for (int i = 0; i < predictions.length; i++) {
                predictions[i] = distributionForInstance(instances.instance(i));
            }
            return predictions;
        }

        final boolean numericClass = instances.classAttribute().isNumeric();
        // Numeric class: column 0 accumulates the sum, column 1 the count.
        double[][] totals = new double[instances.numInstances()][numericClass ? 2 : instances.numClasses()];

        // A non-positive slot count cannot size a thread pool and would divide by
        // zero below, so fall back to the number of available processors.
        // NOTE(review): confirm this matches how the superclass interprets 0.
        int requestedSlots = this.m_numExecutionSlots > 0 ? this.m_numExecutionSlots : Runtime.getRuntime().availableProcessors();
        // Never start more tasks than there are classifiers to evaluate.
        int numTasks = Math.max(1, Math.min(requestedSlots, this.m_Classifiers.length));
        int chunkSize = this.m_Classifiers.length / numTasks;

        ExecutorService pool = Executors.newFixedThreadPool(numTasks);
        try {
            // An ordered list keeps the floating-point summation order deterministic.
            List<Future<double[][]>> pending = new ArrayList<Future<double[][]>>(numTasks);
            for (int t = 0; t < numTasks; t++) {
                final int from = t * chunkSize;
                // The last task also takes the remainder of the integer division.
                final int to = t < numTasks - 1 ? from + chunkSize : this.m_Classifiers.length;
                pending.add(pool.submit(new Callable<double[][]>() {
                    @Override
                    public double[][] call() throws Exception {
                        return RandomCommittee.this.sumBatchPredictions(instances, from, to, numericClass);
                    }
                }));
            }
            for (Future<double[][]> future : pending) {
                double[][] partial = future.get();
                for (int row = 0; row < partial.length; row++) {
                    for (int col = 0; col < partial[row].length; col++) {
                        totals[row][col] += partial[row][col];
                    }
                }
            }
        } catch (ExecutionException e) {
            // Surface the member's own failure instead of returning partial sums
            // that would look like valid predictions after normalisation.
            Throwable cause = e.getCause();
            if (cause instanceof Exception) {
                throw (Exception) cause;
            }
            throw e;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw e;
        } finally {
            // Always release the worker threads; after a failure this also stops
            // tasks that are still queued or running.
            pool.shutdownNow();
        }

        if (!numericClass) {
            for (int row = 0; row < totals.length; row++) {
                double sum = Utils.sum(totals[row]);
                if (!Utils.eq(sum, 0.0d)) {
                    Utils.normalize(totals[row], sum);
                }
            }
            return totals;
        }
        double[][] averages = new double[totals.length][1];
        for (int row = 0; row < totals.length; row++) {
            averages[row][0] = totals[row][1] == 0.0d ? Utils.missingValue() : totals[row][0] / totals[row][1];
        }
        return averages;
    }

    /**
     * Sums the batch predictions of the ensemble members in {@code [from, to)}.
     *
     * @param instances the instances to classify
     * @param from index of the first member to evaluate (inclusive)
     * @param to index one past the last member to evaluate (exclusive)
     * @param numericClass whether the class attribute is numeric
     * @return one row per instance; for a numeric class the sum of non-missing
     *         predictions and their count, otherwise the summed distributions
     * @throws Exception if a member fails to produce its predictions
     */
    private double[][] sumBatchPredictions(Instances instances, int from, int to, boolean numericClass) throws Exception {
        double[][] sums = new double[instances.numInstances()][numericClass ? 2 : instances.numClasses()];
        for (int m = from; m < to; m++) {
            double[][] predictions = ((BatchPredictor) this.m_Classifiers[m]).distributionsForInstances(instances);
            for (int row = 0; row < predictions.length; row++) {
                if (numericClass) {
                    if (!Utils.isMissingValue(predictions[row][0])) {
                        sums[row][0] += predictions[row][0];
                        sums[row][1] += 1.0d;
                    }
                } else {
                    for (int col = 0; col < predictions[row].length; col++) {
                        sums[row][col] += predictions[row][col];
                    }
                }
            }
        }
        return sums;
    }

    /**
     * Reports whether batch prediction is more efficient than instance-wise
     * prediction, as answered by the base classifier when that is a
     * {@link BatchPredictor}.
     *
     * @return the base classifier's answer, or the superclass default
     */
    @Override
    public boolean implementsMoreEfficientBatchPrediction() {
        if (getClassifier() instanceof BatchPredictor) {
            return ((BatchPredictor) getClassifier()).implementsMoreEfficientBatchPrediction();
        }
        return super.implementsMoreEfficientBatchPrediction();
    }

    /**
     * Returns the textual descriptions of all ensemble members.
     *
     * @return the member descriptions, or a notice if no model has been built
     */
    public String toString() {
        if (this.m_Classifiers == null) {
            return "RandomCommittee: No model built yet.";
        }
        StringBuilder text = new StringBuilder("All the base classifiers: \n\n");
        for (int i = 0; i < this.m_Classifiers.length; i++) {
            text.append(this.m_Classifiers[i].toString()).append("\n\n");
        }
        return text.toString();
    }

    /**
     * Fails unless the base classifier is a {@link PartitionGenerator}.
     *
     * @throws Exception if the base classifier cannot generate a partition
     */
    private void requirePartitionGenerator() throws Exception {
        if (!(this.m_Classifier instanceof PartitionGenerator)) {
            throw new Exception("Classifier: " + getClassifierSpec() + " cannot generate a partition");
        }
    }

    /**
     * Builds the committee so that it can be used to generate a partition.
     *
     * @param instances the training data
     * @throws Exception if the base classifier cannot generate a partition or the build fails
     */
    @Override
    public void generatePartition(Instances instances) throws Exception {
        requirePartitionGenerator();
        buildClassifier(instances);
    }

    /**
     * Concatenates the membership values of all ensemble members, in member order.
     *
     * @param instance the instance to compute membership values for
     * @return the concatenated membership values
     * @throws Exception if the base classifier cannot generate a partition or a member fails
     */
    @Override
    public double[] getMembershipValues(Instance instance) throws Exception {
        requirePartitionGenerator();
        List<double[]> parts = new ArrayList<double[]>(this.m_Classifiers.length);
        int total = 0;
        for (int i = 0; i < this.m_Classifiers.length; i++) {
            double[] values = ((PartitionGenerator) this.m_Classifiers[i]).getMembershipValues(instance);
            total += values.length;
            parts.add(values);
        }
        double[] result = new double[total];
        int offset = 0;
        for (double[] values : parts) {
            System.arraycopy(values, 0, result, offset, values.length);
            offset += values.length;
        }
        return result;
    }

    /**
     * Returns the total number of partition elements over all ensemble members.
     *
     * @return the sum of the members' element counts
     * @throws Exception if the base classifier cannot generate a partition or a member fails
     */
    @Override
    public int numElements() throws Exception {
        requirePartitionGenerator();
        int total = 0;
        for (int i = 0; i < this.m_Classifiers.length; i++) {
            total += ((PartitionGenerator) this.m_Classifiers[i]).numElements();
        }
        return total;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision extracted from the embedded revision keyword
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 15800 $");
    }

    /**
     * Runs this classifier from the command line.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new RandomCommittee(), strArr);
    }
}
