package org.apache.spark.examples.mllib;

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

import scala.Tuple2;

/* loaded from: input_file:org/apache/spark/examples/mllib/JavaDecisionTree.class */
public final class JavaDecisionTree {
    public static void main(String[] strArr) {
        String str = "data/mllib/sample_libsvm_data.txt";
        if (strArr.length == 1) {
            str = strArr[0];
        } else if (strArr.length > 1) {
            System.err.println("Usage: JavaDecisionTree <libsvm format data file>");
            System.exit(1);
        }
        JavaSparkContext javaSparkContext = new JavaSparkContext(new SparkConf().setAppName("JavaDecisionTree"));
        JavaRDD cache = MLUtils.loadLibSVMFile(javaSparkContext.sc(), str).toJavaRDD().cache();
        Integer valueOf = Integer.valueOf(cache.map(new Function<LabeledPoint, Double>() { // from class: org.apache.spark.examples.mllib.JavaDecisionTree.1
            public Double call(LabeledPoint labeledPoint) {
                return Double.valueOf(labeledPoint.label());
            }
        }).countByValue().size());
        HashMap hashMap = new HashMap();
        Integer num = 5;
        Integer num2 = 100;
        final DecisionTreeModel trainClassifier = DecisionTree.trainClassifier(cache, valueOf.intValue(), hashMap, "gini", num.intValue(), num2.intValue());
        System.out.println("Training error: " + Double.valueOf((1.0d * cache.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { // from class: org.apache.spark.examples.mllib.JavaDecisionTree.2
            public Tuple2<Double, Double> call(LabeledPoint labeledPoint) {
                return new Tuple2<>(Double.valueOf(trainClassifier.predict(labeledPoint.features())), Double.valueOf(labeledPoint.label()));
            }
        }).filter(new Function<Tuple2<Double, Double>, Boolean>() { // from class: org.apache.spark.examples.mllib.JavaDecisionTree.3
            public Boolean call(Tuple2<Double, Double> tuple2) {
                return Boolean.valueOf(!((Double) tuple2._1()).equals(tuple2._2()));
            }
        }).count()) / cache.count()));
        System.out.println("Learned classification tree model:\n" + trainClassifier);
        final DecisionTreeModel trainRegressor = DecisionTree.trainRegressor(cache, hashMap, "variance", num.intValue(), num2.intValue());
        System.out.println("Training Mean Squared Error: " + Double.valueOf(((Double) cache.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { // from class: org.apache.spark.examples.mllib.JavaDecisionTree.4
            public Tuple2<Double, Double> call(LabeledPoint labeledPoint) {
                return new Tuple2<>(Double.valueOf(trainRegressor.predict(labeledPoint.features())), Double.valueOf(labeledPoint.label()));
            }
        }).map(new Function<Tuple2<Double, Double>, Double>() { // from class: org.apache.spark.examples.mllib.JavaDecisionTree.6
            public Double call(Tuple2<Double, Double> tuple2) {
                Double valueOf2 = Double.valueOf(((Double) tuple2._1()).doubleValue() - ((Double) tuple2._2()).doubleValue());
                return Double.valueOf(valueOf2.doubleValue() * valueOf2.doubleValue());
            }
        }).reduce(new Function2<Double, Double, Double>() { // from class: org.apache.spark.examples.mllib.JavaDecisionTree.5
            public Double call(Double d, Double d2) {
                return Double.valueOf(d.doubleValue() + d2.doubleValue());
            }
        })).doubleValue() / cache.count()));
        System.out.println("Learned regression tree model:\n" + trainRegressor);
        javaSparkContext.stop();
    }
}
