/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.spark;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.feature.ColumnPruner;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.ResultFeature;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.OutputField;
import org.jpmml.evaluator.PMMLAttributes;
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.spark.ColumnExploder;
import org.jpmml.evaluator.spark.ColumnProducer;
import org.jpmml.evaluator.spark.OutputColumnProducer;
import org.jpmml.evaluator.spark.PMMLTransformer;
import org.jpmml.evaluator.spark.ProbabilityColumnProducer;
import org.jpmml.evaluator.spark.TargetColumnProducer;
import scala.collection.immutable.Set;

public class TransformerBuilder {
    private Evaluator evaluator = null;
    private List<ColumnProducer<? extends ResultField>> columnProducers = new ArrayList<ColumnProducer<? extends ResultField>>();
    private boolean exploded = false;

    public TransformerBuilder(Evaluator evaluator) {
        this.setEvaluator(evaluator);
    }

    public TransformerBuilder withTargetCols() {
        Evaluator evaluator = this.getEvaluator();
        List targetFields = evaluator.getTargetFields();
        for (TargetField targetField : targetFields) {
            this.columnProducers.add(new TargetColumnProducer(targetField, null));
        }
        return this;
    }

    public TransformerBuilder withOutputCols() {
        Evaluator evaluator = this.getEvaluator();
        List outputFields = evaluator.getOutputFields();
        for (OutputField outputField : outputFields) {
            this.columnProducers.add(new OutputColumnProducer(outputField, null));
        }
        return this;
    }

    public TransformerBuilder withLabelCol(String columnName) {
        Evaluator evaluator = this.getEvaluator();
        TargetField targetField = TransformerBuilder.getTargetField(evaluator);
        this.columnProducers.add(new TargetColumnProducer(targetField, columnName));
        return this;
    }

    public TransformerBuilder withProbabilityCol(String columnName) {
        return this.withProbabilityCol(columnName, null);
    }

    public TransformerBuilder withProbabilityCol(String columnName, List<String> labels) {
        Evaluator evaluator = this.getEvaluator();
        TargetField targetField = TransformerBuilder.getTargetField(evaluator);
        List<OutputField> probabilityOutputFields = TransformerBuilder.getProbabilityFields(evaluator, targetField);
        List targetCategories = probabilityOutputFields.stream().map(probabilityOutputField -> {
            org.dmg.pmml.OutputField pmmlOutputField = probabilityOutputField.getField();
            String value = pmmlOutputField.getValue();
            if (value == null) {
                throw new MissingAttributeException((PMMLObject)pmmlOutputField, PMMLAttributes.OUTPUTFIELD_VALUE);
            }
            return value;
        }).collect(Collectors.toList());
        if (!(labels == null || labels.size() == targetCategories.size() && labels.containsAll(targetCategories))) {
            throw new IllegalArgumentException("Model has an incompatible set of probability-type output fields (expected " + labels + ", got " + targetCategories + ")");
        }
        this.columnProducers.add(new ProbabilityColumnProducer(targetField, columnName, labels != null ? labels : targetCategories));
        return this;
    }

    public TransformerBuilder exploded(boolean exploded) {
        this.exploded = exploded;
        return this;
    }

    public Transformer build() {
        Evaluator evaluator = this.getEvaluator();
        PMMLTransformer pmmlTransformer = new PMMLTransformer(evaluator, this.columnProducers);
        if (this.exploded) {
            ColumnExploder columnExploder = new ColumnExploder(pmmlTransformer.getOutputCol());
            ColumnPruner columnPruner = new ColumnPruner((Set)new Set.Set1((Object)pmmlTransformer.getOutputCol()));
            PipelineModel pipelineModel = new PipelineModel(null, new Transformer[]{pmmlTransformer, columnExploder, columnPruner});
            return pipelineModel;
        }
        return pmmlTransformer;
    }

    private Evaluator getEvaluator() {
        return this.evaluator;
    }

    private void setEvaluator(Evaluator evaluator) {
        this.evaluator = evaluator;
    }

    private static TargetField getTargetField(Evaluator evaluator) {
        List targetFields = evaluator.getTargetFields();
        if (targetFields.size() < 1) {
            throw new IllegalArgumentException("Model does not have a target field");
        }
        if (targetFields.size() > 1) {
            throw new IllegalArgumentException("Model has multiple target fields (" + targetFields + ")");
        }
        return (TargetField)targetFields.get(0);
    }

    private static List<OutputField> getProbabilityFields(Evaluator evaluator, final TargetField targetField) {
        List outputFields = evaluator.getOutputFields();
        Predicate<OutputField> predicate = new Predicate<OutputField>(){

            @Override
            public boolean test(OutputField outputField) {
                org.dmg.pmml.OutputField pmmlOutputField = outputField.getField();
                ResultFeature resultFeature = pmmlOutputField.getResultFeature();
                switch (resultFeature) {
                    case PROBABILITY: {
                        FieldName targetFieldName = pmmlOutputField.getTargetField();
                        return Objects.equals(targetFieldName, null) || Objects.equals(targetFieldName, targetField.getName());
                    }
                }
                return false;
            }
        };
        List<OutputField> probabilityOutputFields = outputFields.stream().filter(predicate).collect(Collectors.toList());
        if (probabilityOutputFields.size() < 1) {
            throw new IllegalArgumentException("Model does not have probability-type output fields");
        }
        return probabilityOutputFields;
    }
}

