package ai.vespa.rankingexpression.importer.lightgbm;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.json.Jackson;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;

/**
 * Parses a LightGBM model stored as JSON (top-level keys {@code objective}, {@code feature_names},
 * {@code tree_info}, and optionally {@code feature_infos} / {@code pandas_categorical}) and converts
 * its trees into a single ranking expression: the sum of all trees, wrapped in a link function
 * chosen from the objective.
 */
class LightGBMParser {

    private final String objective;
    private final List<LightGBMNode> nodes;
    private final List<String> featureNames;

    /** Categorical feature index -> list of category values, indexed by LightGBM category index. */
    private final Map<Integer, List<String>> categoryValues;

    /**
     * Categorical feature index -> whether every category value of that feature is a JSON int.
     * Integer-valued categories are emitted unquoted; all others are emitted as quoted strings.
     */
    private final Map<Integer, Boolean> categoricalIntegerFeatures = new HashMap<>();

    /**
     * Reads and parses the LightGBM JSON model at the given path.
     *
     * @param str path of the JSON model file
     * @throws IOException if the file cannot be read or parsed as JSON
     */
    public LightGBMParser(String str) throws IOException {
        ObjectMapper mapper = Jackson.createMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        JsonNode root = mapper.readTree(new File(str));
        // path() returns a MissingNode instead of null, so the "regression" default applies when the key is absent.
        this.objective = root.path("objective").asText("regression");
        this.featureNames = parseFeatureNames(root);
        this.nodes = parseTrees(mapper, root);
        // Must run after featureNames and nodes are assigned: it reads both.
        this.categoryValues = parseCategoryValues(root);
    }

    /** Returns the entries of the {@code feature_names} array, in order. */
    private List<String> parseFeatureNames(JsonNode root) {
        List<String> names = new ArrayList<>();
        for (JsonNode name : root.get("feature_names")) {
            names.add(name.textValue());
        }
        return names;
    }

    /** Deserializes the {@code tree_structure} of every entry in {@code tree_info} into a node tree. */
    private List<LightGBMNode> parseTrees(ObjectMapper mapper, JsonNode root) throws JsonProcessingException {
        List<LightGBMNode> trees = new ArrayList<>();
        for (JsonNode treeInfo : root.get("tree_info")) {
            trees.add(mapper.treeToValue(treeInfo.get("tree_structure"), LightGBMNode.class));
        }
        return trees;
    }

    /**
     * Maps each categorical feature index to its list of category values, taken from
     * {@code pandas_categorical}, and records in {@link #categoricalIntegerFeatures} whether each
     * such feature is integer-valued.
     *
     * <p>The JSON does not say which {@code pandas_categorical} entry belongs to which feature. The
     * entries are paired with the categorical feature indexes in ascending index order. Pairing
     * stops when either side runs out. A missing or null {@code pandas_categorical} yields an empty map.
     */
    private Map<Integer, List<String>> parseCategoryValues(JsonNode root) {
        Map<Integer, List<String>> values = new HashMap<>();
        Iterator<JsonNode> categoryLists = root.path("pandas_categorical").iterator();
        Iterator<Integer> featureIndexes = findCategoricalFeatureIndexes(root).iterator();
        while (categoryLists.hasNext() && featureIndexes.hasNext()) {
            JsonNode categories = categoryLists.next();
            int featureIndex = featureIndexes.next();
            categoricalIntegerFeatures.put(featureIndex, allInts(categories));
            List<String> names = new ArrayList<>();
            for (JsonNode category : categories) {
                names.add(category.asText());
            }
            values.put(featureIndex, names);
        }
        return values;
    }

    /**
     * Returns the indexes of categorical features in ascending order.
     *
     * <p>If {@code feature_infos} is present as an object, a feature is categorical when its entry
     * has a non-empty {@code values} array. Otherwise the trees are scanned for {@code ==} splits.
     * The tree scan only finds features that are actually split on.
     */
    private Set<Integer> findCategoricalFeatureIndexes(JsonNode root) {
        Set<Integer> categorical = new TreeSet<>();
        JsonNode featureInfos = root.get("feature_infos");
        if (featureInfos == null || !featureInfos.isObject()) {
            nodes.forEach(node -> findCategoricalFeatures(node, categorical));
            return categorical;
        }
        for (int i = 0; i < featureNames.size(); i++) {
            JsonNode info = featureInfos.get(featureNames.get(i));
            JsonNode infoValues = (info == null) ? null : info.get("values");
            if (infoValues != null && infoValues.isArray() && !infoValues.isEmpty()) {
                categorical.add(i);
            }
        }
        return categorical;
    }

    /** Returns true if every element of the given container node is a JSON int (vacuously true if empty). */
    private static boolean allInts(JsonNode categories) {
        for (JsonNode category : categories) {
            if (!category.isInt()) {
                return false;
            }
        }
        return true;
    }

    /** Recursively adds to {@code set} the split feature of every {@code ==} (categorical) split. */
    private void findCategoricalFeatures(LightGBMNode lightGBMNode, Set<Integer> set) {
        if (lightGBMNode == null || lightGBMNode.isLeaf()) {
            return;
        }
        if ("==".equals(lightGBMNode.getDecision_type())) {
            set.add(lightGBMNode.getSplit_feature());
        }
        findCategoricalFeatures(lightGBMNode.getLeft_child(), set);
        findCategoricalFeatures(lightGBMNode.getRight_child(), set);
    }

    /**
     * Returns the ranking expression for the whole model. This is the sum of all tree expressions,
     * one per line, passed through the objective's link function.
     */
    public String toRankingExpression() {
        String treeSum = nodes.stream()
                .map(this::nodeToRankingExpression)
                .collect(Collectors.joining(" + \n"));
        return applyObjective(treeSum);
    }

    /** Wraps the raw score in {@code sigmoid} or {@code exp} according to the objective; otherwise returns it unchanged. */
    private String applyObjective(String expression) {
        if (objective.startsWith("binary") || objective.equals("cross_entropy")) {
            return "sigmoid(" + expression + ")";
        }
        if (objective.equals("poisson") || objective.equals("gamma") || objective.equals("tweedie")) {
            return "exp(" + expression + ")";
        }
        return expression;
    }

    /** Recursively converts a (sub)tree into nested {@code if(condition, left, right)} expressions. */
    private String nodeToRankingExpression(LightGBMNode lightGBMNode) {
        if (lightGBMNode.isLeaf()) {
            return Double.toString(lightGBMNode.getLeaf_value());
        }
        String feature = featureNames.get(lightGBMNode.getSplit_feature());
        String condition;
        if ("==".equals(lightGBMNode.getDecision_type())) {
            String values = transformCategoryIndexesToValues(lightGBMNode);
            condition = lightGBMNode.isDefault_left()
                    ? "isNan(" + feature + ") || (" + feature + " in [ " + values + "])"
                    : feature + " in [" + values + "]";
        } else {
            double threshold = Double.parseDouble(lightGBMNode.getThreshold());
            // Comparisons with NaN are false, so !(x >= t) sends NaN to the left child when default_left is set.
            condition = lightGBMNode.isDefault_left()
                    ? "!(" + feature + " >= " + threshold + ")"
                    : feature + " < " + threshold;
        }
        return "if (" + condition + ", " + nodeToRankingExpression(lightGBMNode.getLeft_child()) + ", "
                + nodeToRankingExpression(lightGBMNode.getRight_child()) + ")";
    }

    /**
     * Converts the {@code ||}-separated category indexes in a categorical split's threshold into a
     * comma-separated list of values. Values are quoted unless the feature is integer-valued.
     */
    private String transformCategoryIndexesToValues(LightGBMNode lightGBMNode) {
        int feature = lightGBMNode.getSplit_feature();
        boolean unquoted = categoricalIntegerFeatures.getOrDefault(feature, false);
        return Arrays.stream(lightGBMNode.getThreshold().split("\\|\\|"))
                .map(index -> transformCategoryIndexToValue(feature, index))
                .map(value -> unquoted ? value : "\"" + value + "\"")
                .collect(Collectors.joining(","));
    }

    /** Maps a category index to its value for the given feature, or returns the index itself if no mapping is known. */
    private String transformCategoryIndexToValue(int i, String str) {
        List<String> values = categoryValues.get(i);
        return (values == null) ? str : values.get(Integer.parseInt(str));
    }
}
