package org.jpmml.evaluator.naive_bayes;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.math3.util.Precision;
import org.dmg.pmml.ContinuousDistribution;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Discretize;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Extension;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.GaussianDistribution;
import org.dmg.pmml.HasType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.PoissonDistribution;
import org.dmg.pmml.naive_bayes.BayesInput;
import org.dmg.pmml.naive_bayes.BayesInputs;
import org.dmg.pmml.naive_bayes.BayesOutput;
import org.dmg.pmml.naive_bayes.NaiveBayesModel;
import org.dmg.pmml.naive_bayes.PMMLAttributes;
import org.dmg.pmml.naive_bayes.PMMLElements;
import org.dmg.pmml.naive_bayes.PairCounts;
import org.dmg.pmml.naive_bayes.TargetValueCount;
import org.dmg.pmml.naive_bayes.TargetValueCounts;
import org.dmg.pmml.naive_bayes.TargetValueStat;
import org.dmg.pmml.naive_bayes.TargetValueStats;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.DiscretizationUtil;
import org.jpmml.evaluator.DistributionUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValueUtil;
import org.jpmml.evaluator.FieldValues;
import org.jpmml.evaluator.Functions;
import org.jpmml.evaluator.InvalidAttributeException;
import org.jpmml.evaluator.MapHolder;
import org.jpmml.evaluator.MisplacedElementException;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.MissingElementException;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.NumberUtil;
import org.jpmml.evaluator.PMMLUtil;
import org.jpmml.evaluator.ProbabilityDistribution;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueUtil;
import org.jpmml.evaluator.VerificationUtil;
import org.jpmml.model.XPathUtil;

/* loaded from: input_file:org/jpmml/evaluator/naive_bayes/NaiveBayesModelEvaluator.class */
/**
 * Evaluator for PMML {@code NaiveBayesModel} elements.
 *
 * <p>Per-target-category scores start from the {@code BayesOutput} target value counts (priors).
 * Each non-missing input field then contributes either a continuous likelihood
 * ({@code TargetValueStats}) or a discrete likelihood ({@code PairCounts}, optionally after
 * discretization via a {@code DerivedField}). Scores are accumulated by adding logarithms, and
 * the result is normalized with {@link ValueUtil#normalizeSoftMax}.
 */
public class NaiveBayesModelEvaluator extends ModelEvaluator<NaiveBayesModel> {

    /** Per-instance memo of {@link #bayesInputCache} for this evaluator's model (transient). */
    private transient List<BayesInput> bayesInputs;

    /** Per-instance memo of {@link #fieldCountSumCache} for this evaluator's model (transient). */
    private transient Map<FieldName, Map<Object, Number>> fieldCountSums;

    /** Shared cache: model -> effective BayesInput list (direct children plus extension content). */
    private static final LoadingCache<NaiveBayesModel, List<BayesInput>> bayesInputCache = CacheUtil.buildLoadingCache(new CacheLoader<NaiveBayesModel, List<BayesInput>>() {

        @Override
        public List<BayesInput> load(NaiveBayesModel naiveBayesModel) {
            return ImmutableList.copyOf(NaiveBayesModelEvaluator.parseBayesInputs(naiveBayesModel));
        }
    });

    /** Shared cache: model -> (input field -> (target category -> summed PairCounts count)). */
    private static final LoadingCache<NaiveBayesModel, Map<FieldName, Map<Object, Number>>> fieldCountSumCache = CacheUtil.buildLoadingCache(new CacheLoader<NaiveBayesModel, Map<FieldName, Map<Object, Number>>>() {

        @Override
        public Map<FieldName, Map<Object, Number>> load(NaiveBayesModel naiveBayesModel) {
            return ImmutableMap.copyOf(NaiveBayesModelEvaluator.calculateFieldCountSums(naiveBayesModel));
        }
    });

    /**
     * Creates an evaluator for the {@code NaiveBayesModel} located in {@code pmml}.
     *
     * @param pmml the PMML document
     */
    public NaiveBayesModelEvaluator(PMML pmml) {
        this(pmml, PMMLUtil.findModel(pmml, NaiveBayesModel.class));
    }

    /**
     * Creates an evaluator for the given model and eagerly validates its structure.
     *
     * @param pmml the enclosing PMML document
     * @param naiveBayesModel the model to evaluate
     * @throws MissingElementException if {@code BayesInputs}, {@code BayesOutput} or the output's
     *     {@code TargetValueCounts} are absent or empty
     */
    public NaiveBayesModelEvaluator(PMML pmml, NaiveBayesModel naiveBayesModel) {
        super(pmml, naiveBayesModel);
        this.bayesInputs = null;
        this.fieldCountSums = null;

        BayesInputs bayesInputs = naiveBayesModel.getBayesInputs();
        if (bayesInputs == null) {
            throw new MissingElementException((PMMLObject) naiveBayesModel, PMMLElements.NAIVEBAYESMODEL_BAYESINPUTS);
        }
        // BayesInput elements may also be supplied through Extension content (see parseBayesInputs)
        if (!bayesInputs.hasBayesInputs() && !bayesInputs.hasExtensions()) {
            throw new MissingElementException((PMMLObject) bayesInputs, PMMLElements.BAYESINPUTS_BAYESINPUTS);
        }

        BayesOutput bayesOutput = naiveBayesModel.getBayesOutput();
        if (bayesOutput == null) {
            throw new MissingElementException((PMMLObject) naiveBayesModel, PMMLElements.NAIVEBAYESMODEL_BAYESOUTPUT);
        }

        TargetValueCounts targetValueCounts = bayesOutput.getTargetValueCounts();
        if (targetValueCounts == null) {
            throw new MissingElementException((PMMLObject) bayesOutput, PMMLElements.BAYESOUTPUT_TARGETVALUECOUNTS);
        }
        if (!targetValueCounts.hasTargetValueCounts()) {
            throw new MissingElementException((PMMLObject) targetValueCounts, PMMLElements.TARGETVALUECOUNTS_TARGETVALUECOUNTS);
        }
    }

    @Override
    public String getSummary() {
        return "Naive Bayes model";
    }

    /**
     * Scores every target category and returns the normalized classification result.
     *
     * <p>Input fields whose value is missing, whether before or after discretization, are skipped
     * and contribute nothing to any category.
     *
     * @throws MissingAttributeException if {@code BayesOutput@field}, {@code NaiveBayesModel@threshold}
     *     or a {@code BayesInput@field} is absent
     * @throws InvalidAttributeException if {@code BayesOutput@field} does not name the target field
     */
    @Override
    protected <V extends Number> Map<FieldName, ? extends Classification<?, V>> evaluateClassification(final ValueFactory<V> valueFactory, EvaluationContext evaluationContext) {
        NaiveBayesModel naiveBayesModel = getModel();
        BayesOutput bayesOutput = naiveBayesModel.getBayesOutput();
        TargetField targetField = getTargetField();

        FieldName outputName = bayesOutput.getField();
        if (outputName == null) {
            throw new MissingAttributeException((PMMLObject) bayesOutput, PMMLAttributes.BAYESOUTPUT_FIELD);
        }
        if (!Objects.equals(targetField.getFieldName(), outputName)) {
            throw new InvalidAttributeException(bayesOutput, PMMLAttributes.BAYESOUTPUT_FIELD, outputName);
        }

        // "Multiplying" a category by a probability adds that probability's logarithm to its score
        ProbabilityMap<Object, V> probabilities = new ProbabilityMap<Object, V>() {

            @Override
            public ValueFactory<V> getValueFactory() {
                return valueFactory;
            }

            @Override
            public void multiply(Object targetCategory, Number probability) {
                // NOTE(review): ln2() is presumably a base-2 logarithm; confirm that
                // ValueUtil.normalizeSoftMax exponentiates with the same base.
                ensureValue(targetCategory).add((Value<? extends Number>) getValueFactory().newValue(probability).ln2());
            }
        };

        calculatePriorProbabilities(probabilities, getTargetValueCounts(bayesOutput));

        Number threshold = naiveBayesModel.getThreshold();
        if (threshold == null) {
            throw new MissingAttributeException((PMMLObject) naiveBayesModel, PMMLAttributes.NAIVEBAYESMODEL_THRESHOLD);
        }

        Map<FieldName, Map<Object, Number>> fieldCountSums = getFieldCountSums();

        for (BayesInput bayesInput : getBayesInputs()) {
            FieldName inputName = bayesInput.getField();
            if (inputName == null) {
                throw new MissingAttributeException((PMMLObject) bayesInput, PMMLAttributes.BAYESINPUT_FIELD);
            }

            FieldValue value = evaluationContext.evaluate(inputName);
            if (FieldValueUtil.isMissing(value)) {
                continue;
            }

            TargetValueStats targetValueStats = getTargetValueStats(bayesInput);
            if (targetValueStats != null) {
                calculateContinuousProbabilities(probabilities, targetValueStats, threshold, value);
                continue;
            }

            DerivedField derivedField = bayesInput.getDerivedField();
            if (derivedField != null) {
                value = discretize(derivedField, value);

                // A missing discretized value is skipped just like a missing raw value; passing it on
                // would make getTargetValueCounts dereference it (getDataType()/equalsValue()).
                if (FieldValueUtil.isMissing(value)) {
                    continue;
                }
            }

            TargetValueCounts targetValueCounts = getTargetValueCounts(bayesInput, value);
            if (targetValueCounts != null) {
                calculateDiscreteProbabilities(probabilities, targetValueCounts, threshold, fieldCountSums.get(inputName));
            }
        }

        ValueUtil.normalizeSoftMax(probabilities);

        return TargetUtil.evaluateClassification(targetField, new ProbabilityDistribution(probabilities));
    }

    /**
     * Applies the {@code Discretize} expression of {@code derivedField} to {@code value}.
     *
     * @return the discretized value cast to the derived field's type, or
     *     {@link FieldValues#MISSING_VALUE} if discretization produced a missing value
     * @throws MisplacedElementException if the derived field's expression is not {@code Discretize}
     */
    private FieldValue discretize(DerivedField derivedField, FieldValue value) {
        Expression expression = ExpressionUtil.ensureExpression(derivedField);
        if (!(expression instanceof Discretize)) {
            throw new MisplacedElementException(expression);
        }

        FieldValue result = DiscretizationUtil.discretize((Discretize) expression, value);
        if (FieldValueUtil.isMissing(result)) {
            return FieldValues.MISSING_VALUE;
        }

        return result.cast((HasType<?>) derivedField);
    }

    /**
     * Multiplies each category by the density of its Gaussian or Poisson distribution at
     * {@code fieldValue}, using {@code threshold} as the floor for the result.
     */
    private void calculateContinuousProbabilities(ProbabilityMap<Object, ?> probabilityMap, TargetValueStats targetValueStats, Number threshold, FieldValue fieldValue) {
        Number x = fieldValue.asNumber();

        for (TargetValueStat targetValueStat : targetValueStats) {
            Object targetCategory = targetValueStat.getValue();
            if (targetCategory == null) {
                throw new MissingAttributeException((PMMLObject) targetValueStat, PMMLAttributes.TARGETVALUESTAT_VALUE);
            }

            ContinuousDistribution distribution = targetValueStat.getContinuousDistribution();
            if (distribution == null) {
                throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(targetValueStat.getClass()) + "/<ContinuousDistribution>"), (PMMLObject) targetValueStat);
            }
            if (!(distribution instanceof GaussianDistribution) && !(distribution instanceof PoissonDistribution)) {
                throw new MisplacedElementException(distribution);
            }

            Number probability = Double.valueOf(DistributionUtil.probability(distribution, x));
            // Densities below the threshold are replaced by the threshold itself
            if (NumberUtil.compare(probability, threshold) < 0) {
                probability = threshold;
            }

            probabilityMap.multiply(targetCategory, probability);
        }
    }

    /**
     * Multiplies each category by {@code count / countSums[category]}. A (near-)zero count is
     * replaced by {@code threshold} so that no category's score collapses to log(0).
     */
    private void calculateDiscreteProbabilities(ProbabilityMap<Object, ?> probabilityMap, TargetValueCounts targetValueCounts, Number threshold, Map<?, Number> countSums) {
        for (TargetValueCount targetValueCount : targetValueCounts) {
            Object targetCategory = targetValueCount.getValue();
            if (targetCategory == null) {
                throw new MissingAttributeException((PMMLObject) targetValueCount, PMMLAttributes.TARGETVALUECOUNT_VALUE);
            }

            Number count = targetValueCount.getCount();
            if (count == null) {
                throw new MissingAttributeException((PMMLObject) targetValueCount, PMMLAttributes.TARGETVALUECOUNT_COUNT);
            }

            Number probability;
            if (VerificationUtil.isZero(count, Precision.EPSILON)) {
                probability = threshold;
            } else {
                probability = Functions.DIVIDE.evaluate(count, NumberUtil.asDouble(countSums.get(targetCategory)));
            }

            probabilityMap.multiply(targetCategory, probability);
        }
    }

    /** Seeds every target category with its prior count from {@code BayesOutput}. */
    private void calculatePriorProbabilities(ProbabilityMap<Object, ?> probabilityMap, TargetValueCounts targetValueCounts) {
        for (TargetValueCount targetValueCount : targetValueCounts) {
            Object targetCategory = targetValueCount.getValue();
            if (targetCategory == null) {
                throw new MissingAttributeException((PMMLObject) targetValueCount, PMMLAttributes.TARGETVALUECOUNT_VALUE);
            }

            Number count = targetValueCount.getCount();
            if (count == null) {
                throw new MissingAttributeException((PMMLObject) targetValueCount, PMMLAttributes.TARGETVALUECOUNT_COUNT);
            }

            probabilityMap.multiply(targetCategory, count);
        }
    }

    /** Returns the effective BayesInput list, resolved once per evaluator from the shared cache. */
    @SuppressWarnings("unchecked")
    protected List<BayesInput> getBayesInputs() {
        if (this.bayesInputs == null) {
            this.bayesInputs = (List<BayesInput>) getValue(bayesInputCache);
        }
        return this.bayesInputs;
    }

    /** Returns the per-field count sums, resolved once per evaluator from the shared cache. */
    @SuppressWarnings("unchecked")
    protected Map<FieldName, Map<Object, Number>> getFieldCountSums() {
        if (this.fieldCountSums == null) {
            this.fieldCountSums = (Map<FieldName, Map<Object, Number>>) getValue(fieldCountSumCache);
        }
        return this.fieldCountSums;
    }

    /**
     * For every BayesInput, sums the {@code TargetValueCount} counts per target category across all
     * of its {@code PairCounts}. Inputs without PairCounts map to an empty map.
     *
     * <p>Public only so the cache loader can reach it; not intended as API.
     *
     * @throws MissingElementException if a {@code PairCounts} lacks {@code TargetValueCounts}
     * @throws MissingAttributeException if a {@code TargetValueCount} lacks {@code value} or {@code count}
     */
    @SuppressWarnings("unchecked")
    public static Map<FieldName, Map<Object, Number>> calculateFieldCountSums(NaiveBayesModel naiveBayesModel) {
        Map<FieldName, Map<Object, Number>> result = new LinkedHashMap<>();

        List<BayesInput> bayesInputs = (List<BayesInput>) CacheUtil.getValue(naiveBayesModel, bayesInputCache);
        for (BayesInput bayesInput : bayesInputs) {
            Map<Object, Number> countSums = new LinkedHashMap<>();

            for (PairCounts pairCounts : bayesInput.getPairCounts()) {
                TargetValueCounts targetValueCounts = pairCounts.getTargetValueCounts();
                if (targetValueCounts == null) {
                    throw new MissingElementException((PMMLObject) pairCounts, PMMLElements.PAIRCOUNTS_TARGETVALUECOUNTS);
                }

                for (TargetValueCount targetValueCount : targetValueCounts) {
                    Object targetCategory = targetValueCount.getValue();
                    if (targetCategory == null) {
                        throw new MissingAttributeException((PMMLObject) targetValueCount, PMMLAttributes.TARGETVALUECOUNT_VALUE);
                    }

                    Number count = targetValueCount.getCount();
                    if (count == null) {
                        throw new MissingAttributeException((PMMLObject) targetValueCount, PMMLAttributes.TARGETVALUECOUNT_COUNT);
                    }

                    Number sum = countSums.get(targetCategory);
                    countSums.put(targetCategory, sum == null ? count : Functions.ADD.evaluate(sum, count));
                }
            }

            result.put(bayesInput.getField(), countSums);
        }

        return result;
    }

    /**
     * Collects the model's BayesInput elements, including any found in {@code Extension} content.
     * If there are no extensions, the model's own list is returned directly (not copied).
     *
     * <p>Public only so the cache loader can reach it; not intended as API.
     */
    public static List<BayesInput> parseBayesInputs(NaiveBayesModel naiveBayesModel) {
        BayesInputs bayesInputs = naiveBayesModel.getBayesInputs();
        if (!bayesInputs.hasExtensions()) {
            return bayesInputs.getBayesInputs();
        }

        List<BayesInput> result = new ArrayList<>(bayesInputs.getBayesInputs());
        for (Extension extension : bayesInputs.getExtensions()) {
            for (Object content : extension.getContent()) {
                if (content instanceof BayesInput) {
                    result.add((BayesInput) content);
                }
            }
        }
        return result;
    }

    /** Returns the continuous statistics of {@code bayesInput}, or {@code null} if it is discrete. */
    private static TargetValueStats getTargetValueStats(BayesInput bayesInput) {
        return bayesInput.getTargetValueStats();
    }

    /**
     * Finds the {@code TargetValueCounts} of the {@code PairCounts} whose value matches
     * {@code fieldValue}; returns {@code null} if none matches. {@code fieldValue} must not be missing.
     */
    @SuppressWarnings("rawtypes")
    private static TargetValueCounts getTargetValueCounts(BayesInput bayesInput, FieldValue fieldValue) {
        // Indexed lookup when the BayesInput implementation provides one
        if (bayesInput instanceof MapHolder) {
            return (TargetValueCounts) ((MapHolder) bayesInput).get(fieldValue.getDataType(), fieldValue.getValue());
        }

        for (PairCounts pairCounts : bayesInput.getPairCounts()) {
            Object value = pairCounts.getValue();
            if (value == null) {
                throw new MissingAttributeException((PMMLObject) pairCounts, PMMLAttributes.PAIRCOUNTS_VALUE);
            }

            if (fieldValue.equalsValue(value)) {
                TargetValueCounts targetValueCounts = pairCounts.getTargetValueCounts();
                if (targetValueCounts == null) {
                    throw new MissingElementException((PMMLObject) pairCounts, PMMLElements.PAIRCOUNTS_TARGETVALUECOUNTS);
                }
                return targetValueCounts;
            }
        }

        return null;
    }

    /** Returns the prior counts of {@code bayesOutput}. */
    private static TargetValueCounts getTargetValueCounts(BayesOutput bayesOutput) {
        return bayesOutput.getTargetValueCounts();
    }
}
