package org.apache.solr.ltr.model;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.FeatureException;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;

/**
 * A learning-to-rank scoring model made of an ensemble of weighted regression trees: the score of
 * a feature vector is the sum, over all trees, of {@code weight * value-of-the-leaf-reached}.
 *
 * <p>Trees and nodes are built from parameter maps. Each map is handed, together with a fresh
 * {@link RegressionTree} or {@link RegressionTreeNode}, to {@code SolrPluginUtils.invokeSetters}
 * (presumably applying each entry through the matching public {@code setXxx} method). An internal
 * node routes a feature value to {@code left} when it is {@code <=} the node's threshold and to
 * {@code right} otherwise. A node whose feature is not in this model's feature list ends that
 * tree's traversal with a leaf value of {@code 0}.
 *
 * <p>{@code isNullSameAsZero} (default {@code true}) controls NaN handling:
 *
 * <ul>
 *   <li>{@code true}: NaN gets no special treatment. It is normalized like any other value and,
 *       because it never compares {@code <=} a threshold, it is routed right.
 *   <li>{@code false}: NaN values are not normalized and are routed by the node's {@code missing}
 *       setting ({@code "left"} goes left; anything else, including no setting, goes right).
 * </ul>
 */
public class MultipleAdditiveTreesModel extends LTRScoringModel {
    /** Maps each feature name to its index in the feature list given to the constructor. */
    private final Map<String, Integer> fname2index;
    /** The ensemble; {@code null} until {@link #setTrees(Object)} is called. */
    private List<RegressionTree> trees;
    /** See the class documentation for the NaN semantics selected by this flag. */
    private boolean isNullSameAsZero;

    /**
     * One weighted tree of the ensemble. Both a weight and a root node must be set for
     * {@link #validate()} to pass.
     */
    public class RegressionTree {
        /** Multiplier applied to the value of the leaf reached by a feature vector. */
        private Float weight;
        /** Root node, built from a parameter map by {@link #setRoot(Object)}. */
        private RegressionTreeNode root;

        public RegressionTree() {
        }

        /** Sets this tree's weight. */
        public void setWeight(float f) {
            this.weight = Float.valueOf(f);
        }

        /**
         * Sets this tree's weight from its textual form.
         *
         * @throws NumberFormatException if {@code str} is not a parsable float
         */
        public void setWeight(String str) {
            this.weight = Float.valueOf(str);
        }

        /**
         * Builds this tree's root node from a parameter map; {@code null} yields a leaf with
         * value {@code 0}.
         *
         * @throws ClassCastException if {@code obj} is neither {@code null} nor a {@link Map}
         */
        public void setRoot(Object obj) {
            this.root = createRegressionTreeNode(asParamMap(obj));
        }

        /**
         * Scores a feature vector with this tree.
         *
         * @param featureVector feature values, indexed like the model's feature list
         * @return {@code weight} times the value of the leaf the vector is routed to
         */
        public float score(float[] featureVector) {
            final float leafValue =
                    isNullSameAsZero
                            ? scoreNode(featureVector, root)
                            : scoreNodeWithNullSupport(featureVector, root);
            return weight.floatValue() * leafValue;
        }

        /** Describes, step by step, the path {@code featureVector} takes through this tree. */
        public String explain(float[] featureVector) {
            return explainNode(featureVector, root);
        }

        @Override
        public String toString() {
            return "(weight=" + weight + ",root=" + root + ")";
        }

        /**
         * Checks that this tree has a weight and a root, and that every node is well formed.
         *
         * @throws ModelException describing the first problem found
         */
        public void validate() throws ModelException {
            if (weight == null) {
                throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a weight");
            }
            if (root == null) {
                throw new ModelException("MultipleAdditiveTreesModel tree doesn't contain a tree");
            }
            validateNode(root);
        }
    }

    /**
     * A node of a {@link RegressionTree}. A node without a {@code feature} is a leaf carrying a
     * {@code value}. Any other node splits on {@code feature} at {@code threshold} into
     * {@code left} and {@code right} children.
     */
    public class RegressionTreeNode {
        /**
         * Added to every configured threshold, so feature values up to this much above the
         * configured threshold (subject to float precision) are still routed left; presumably
         * this absorbs float rounding in exported models. {@link #toString()} subtracts it again.
         */
        private static final float NODE_SPLIT_SLACK = 1.0E-6f;
        private String feature;
        /** Configured threshold plus {@link #NODE_SPLIT_SLACK}; required on split nodes. */
        private Float threshold;
        private RegressionTreeNode left;
        private RegressionTreeNode right;
        /** Direction for NaN values when null support is on; only {@code "left"} is special. */
        private String missing;
        /** Leaf value; meaningful only when {@link #isLeaf()} is {@code true}. */
        private float value = 0.0f;
        /** Index of {@code feature} in the model's feature list, or {@code -1} if absent. */
        private int featureIndex = -1;

        public RegressionTreeNode() {
        }

        /** Sets the leaf value. */
        public void setValue(float f) {
            this.value = f;
        }

        /**
         * Sets the leaf value from its textual form.
         *
         * @throws NumberFormatException if {@code str} is not a parsable float
         */
        public void setValue(String str) {
            this.value = Float.parseFloat(str);
        }

        /** Sets where NaN values go when null support is on ({@code "left"}, else right). */
        public void setMissing(String str) {
            this.missing = str;
        }

        /** Sets the split feature and resolves its index in the model's feature list. */
        public void setFeature(String str) {
            this.feature = str;
            final Integer index = fname2index.get(str);
            this.featureIndex = index == null ? -1 : index.intValue();
        }

        /** Sets the split threshold (stored with {@link #NODE_SPLIT_SLACK} added). */
        public void setThreshold(float f) {
            this.threshold = Float.valueOf(f + NODE_SPLIT_SLACK);
        }

        /**
         * Sets the split threshold from its textual form (stored with
         * {@link #NODE_SPLIT_SLACK} added).
         *
         * @throws NumberFormatException if {@code str} is not a parsable float
         */
        public void setThreshold(String str) {
            this.threshold = Float.valueOf(Float.parseFloat(str) + NODE_SPLIT_SLACK);
        }

        /** Builds the left child from a parameter map. */
        public void setLeft(Object obj) {
            this.left = createRegressionTreeNode(asParamMap(obj));
        }

        /** Builds the right child from a parameter map. */
        public void setRight(Object obj) {
            this.right = createRegressionTreeNode(asParamMap(obj));
        }

        /** A node is a leaf exactly when no split feature has been set. */
        public boolean isLeaf() {
            return feature == null;
        }

        @Override
        public String toString() {
            final StringBuilder sb = new StringBuilder();
            if (isLeaf()) {
                sb.append(value);
            } else {
                sb.append("(feature=").append(feature);
                // Report the threshold as configured, without the internal slack.
                sb.append(",threshold=").append(threshold.floatValue() - NODE_SPLIT_SLACK);
                if (missing != null) {
                    sb.append(",missing=").append(missing);
                }
                sb.append(",left=").append(left);
                sb.append(",right=").append(right);
                sb.append(')');
            }
            return sb.toString();
        }
    }

    /**
     * Creates a tree and applies the entries of {@code map} (if any) via its setters.
     */
    private RegressionTree createRegressionTree(Map<String, Object> map) {
        final RegressionTree tree = new RegressionTree();
        if (map != null) {
            SolrPluginUtils.invokeSetters(tree, map.entrySet());
        }
        return tree;
    }

    /**
     * Creates a node and applies the entries of {@code map} (if any) via its setters.
     */
    private RegressionTreeNode createRegressionTreeNode(Map<String, Object> map) {
        final RegressionTreeNode node = new RegressionTreeNode();
        if (map != null) {
            SolrPluginUtils.invokeSetters(node, map.entrySet());
        }
        return node;
    }

    /**
     * Views a parameter object as a string-keyed parameter map. This is the single place where
     * the unchecked cast of untyped configuration is made.
     *
     * @throws ClassCastException if {@code obj} is neither {@code null} nor a {@link Map}
     */
    @SuppressWarnings("unchecked")
    private static Map<String, Object> asParamMap(Object obj) {
        return (Map<String, Object>) obj;
    }

    /** Selects the NaN semantics described in the class documentation. */
    public void setIsNullSameAsZero(boolean z) {
        this.isNullSameAsZero = z;
    }

    /**
     * Builds the ensemble from a list of per-tree parameter maps. The field is only replaced once
     * every tree has been built, so a failure leaves the previous value in place.
     *
     * @throws ClassCastException if {@code obj} is not a {@link List} of maps
     */
    public void setTrees(Object obj) {
        final List<?> treeParams = (List<?>) obj;
        final List<RegressionTree> built = new ArrayList<>(treeParams.size());
        for (Object treeParam : treeParams) {
            built.add(createRegressionTree(asParamMap(treeParam)));
        }
        this.trees = built;
    }

    /**
     * All arguments are passed unchanged to the {@link LTRScoringModel} constructor. In addition,
     * {@code features} is used to map each feature name to its position, which is the index that
     * tree nodes read from the feature vector.
     */
    public MultipleAdditiveTreesModel(
            String name,
            List<Feature> features,
            List<Normalizer> norms,
            String featureStoreName,
            List<Feature> allFeatures,
            Map<String, Object> params) {
        super(name, features, norms, featureStoreName, allFeatures, params);
        this.isNullSameAsZero = true;
        this.fname2index = new HashMap<>();
        for (int i = 0; i < features.size(); i++) {
            fname2index.put(features.get(i).getName(), Integer.valueOf(i));
        }
    }

    /**
     * Validates the base model, then requires a tree list and validates every tree.
     *
     * @throws ModelException if no trees were declared or any tree is malformed
     */
    @Override
    public void validate() throws ModelException {
        super.validate();
        if (trees == null) {
            throw new ModelException("no trees declared for model " + name);
        }
        for (RegressionTree tree : trees) {
            tree.validate();
        }
    }

    /** Normalizes in place, skipping NaN values unless {@code isNullSameAsZero} is set. */
    @Override
    public void normalizeFeaturesInPlace(float[] fArr) {
        normalizeFeaturesInPlace(fArr, isNullSameAsZero);
    }

    /**
     * Applies the i-th normalizer to the i-th feature value, in place.
     *
     * @param featureValues values to normalize; must have one entry per normalizer
     * @param nullSameAsZero if {@code false}, NaN entries are left untouched so that scoring can
     *     route them by each node's {@code missing} setting
     * @throws FeatureException if the number of values differs from the number of normalizers
     */
    protected void normalizeFeaturesInPlace(float[] featureValues, boolean nullSameAsZero) {
        if (featureValues.length != norms.size()) {
            throw new FeatureException("Must have normalizer for every feature");
        }
        for (int i = 0; i < featureValues.length; i++) {
            if (nullSameAsZero || !Float.isNaN(featureValues[i])) {
                featureValues[i] = norms.get(i).normalize(featureValues[i]);
            }
        }
    }

    /** Returns the sum of every tree's weighted score for {@code fArr}. */
    @Override
    public float score(float[] fArr) {
        float total = 0.0f;
        for (RegressionTree tree : trees) {
            total += tree.score(fArr);
        }
        return total;
    }

    /** Returns whether a NaN feature value at {@code node} is routed to the left child. */
    private static boolean missingGoesLeft(RegressionTreeNode node) {
        return Objects.equals(node.missing, "left");
    }

    /**
     * Walks from {@code node} to a leaf without any NaN special-casing. A NaN value (or a NaN
     * threshold) fails {@code <=} and goes right.
     *
     * @return the leaf value, or {@code 0} if a node's feature is not in the vector
     */
    private static float scoreNode(float[] featureVector, RegressionTreeNode node) {
        RegressionTreeNode current = node;
        while (!current.isLeaf()) {
            final int index = current.featureIndex;
            if (index < 0 || index >= featureVector.length) {
                return 0.0f;
            }
            current = featureVector[index] <= current.threshold ? current.left : current.right;
        }
        return current.value;
    }

    /**
     * Walks from {@code node} to a leaf, routing NaN feature values by each node's
     * {@code missing} setting.
     *
     * <p>Every non-NaN value takes exactly one branch, including against a NaN threshold, where
     * it goes right. This guarantees the walk always descends.
     *
     * @return the leaf value, or {@code 0} if a node's feature is not in the vector
     */
    private static float scoreNodeWithNullSupport(float[] featureVector, RegressionTreeNode node) {
        RegressionTreeNode current = node;
        while (!current.isLeaf()) {
            final int index = current.featureIndex;
            if (index < 0 || index >= featureVector.length) {
                return 0.0f;
            }
            final float featureValue = featureVector[index];
            if (Float.isNaN(featureValue)) {
                // null-safe: a node without a "missing" setting sends NaN right
                current = missingGoesLeft(current) ? current.left : current.right;
            } else {
                current = featureValue <= current.threshold ? current.left : current.right;
            }
        }
        return current.value;
    }

    /**
     * Checks every node reachable from {@code root}. Leaves must have no children; split nodes
     * need a threshold and both children.
     *
     * @throws ModelException describing the first malformed node encountered
     */
    private static void validateNode(RegressionTreeNode root) throws ModelException {
        final ArrayDeque<RegressionTreeNode> pending = new ArrayDeque<>();
        pending.push(root);
        while (!pending.isEmpty()) {
            final RegressionTreeNode node = pending.pop();
            if (node.isLeaf()) {
                if (node.left != null || node.right != null) {
                    throw new ModelException(
                            "MultipleAdditiveTreesModel tree node is leaf with left="
                                    + node.left
                                    + " and right="
                                    + node.right);
                }
                continue;
            }
            if (node.threshold == null) {
                throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold");
            }
            // Each child is null-checked before push: ArrayDeque rejects null elements.
            if (node.left == null) {
                throw new ModelException("MultipleAdditiveTreesModel tree node is missing left");
            }
            pending.push(node.left);
            if (node.right == null) {
                throw new ModelException("MultipleAdditiveTreesModel tree node is missing right");
            }
            pending.push(node.right);
        }
    }

    /**
     * Describes the path {@code featureVector} takes from {@code node} to a leaf. Routing matches
     * {@link #scoreNodeWithNullSupport}: NaN follows {@code missing}, and any other value goes left
     * when {@code <=} the (slack-adjusted) threshold, else right.
     */
    private static String explainNode(float[] featureVector, RegressionTreeNode node) {
        final StringBuilder sb = new StringBuilder();
        RegressionTreeNode current = node;
        while (!current.isLeaf()) {
            final int index = current.featureIndex;
            if (index < 0 || index >= featureVector.length) {
                sb.append("'").append(current.feature).append("' does not exist in FV, Return Zero");
                return sb.toString();
            }
            final float featureValue = featureVector[index];
            if (Float.isNaN(featureValue)) {
                if (missingGoesLeft(current)) {
                    sb.append("'").append(current.feature).append("': NaN, Go Left | ");
                    current = current.left;
                } else {
                    sb.append("'").append(current.feature).append("': NaN, Go Right | ");
                    current = current.right;
                }
            } else if (featureValue <= current.threshold) {
                sb.append("'").append(current.feature).append("':").append(featureValue)
                        .append(" <= ").append(current.threshold).append(", Go Left | ");
                current = current.left;
            } else {
                sb.append("'").append(current.feature).append("':").append(featureValue)
                        .append(" > ").append(current.threshold).append(", Go Right | ");
                current = current.right;
            }
        }
        sb.append("val: ").append(current.value);
        return sb.toString();
    }

    /**
     * Explains a score by listing each tree's weighted contribution and its decision path.
     *
     * @param finalScore value reported on the top-level explanation
     * @param featureExplanations per-feature explanations whose values form the feature vector
     */
    @Override
    public Explanation explain(
            LeafReaderContext context,
            int doc,
            float finalScore,
            List<Explanation> featureExplanations) {
        final float[] featureVector = new float[featureExplanations.size()];
        int featureIndex = 0;
        for (Explanation featureExplanation : featureExplanations) {
            featureVector[featureIndex++] = featureExplanation.getValue().floatValue();
        }
        final List<Explanation> details = new ArrayList<>(trees.size());
        int treeIndex = 0;
        for (RegressionTree tree : trees) {
            details.add(
                    Explanation.match(
                            Float.valueOf(tree.score(featureVector)),
                            "tree " + treeIndex + " | " + tree.explain(featureVector)));
            treeIndex++;
        }
        return Explanation.match(
                Float.valueOf(finalScore),
                toString() + " model applied to features, sum of:",
                details);
    }

    /** Renders the model name and every tree; prints an empty tree list if none were set. */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder(getClass().getSimpleName());
        sb.append("(name=").append(getName());
        sb.append(",trees=[");
        if (trees != null) {
            for (int i = 0; i < trees.size(); i++) {
                if (i > 0) {
                    sb.append(',');
                }
                sb.append(trees.get(i));
            }
        }
        sb.append("])");
        return sb.toString();
    }
}
