package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
import com.yahoo.searchlib.rankingexpression.rule.Operator;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.class */
/**
 * Optimizer that replaces sums of GBDT trees with a single {@link GBDTForestNode}.
 *
 * <p>The expression tree is visited top-down. At each node, the optimizer checks whether the
 * node is either a {@link GBDTNode}, or an {@link OperationNode} whose operators are all
 * {@link Operator#plus} and whose children recursively satisfy the same condition. If so, the
 * whole node is replaced by one {@link GBDTForestNode}. A lone {@link GBDTNode} is also turned
 * into a single-tree forest.
 *
 * <p>The forest node is built from a flat array. For each tree, in left-to-right child order,
 * the array holds the tree's value count followed by the tree's values.
 *
 * <p>No per-call state is kept in fields of this class.
 */
public class GBDTForestOptimizer extends Optimizer {

    @Override // com.yahoo.searchlib.rankingexpression.evaluation.Optimizer
    public void optimize(RankingExpression rankingExpression, ContextIndex contextIndex, OptimizationReport optimizationReport) {
        if ( ! isEnabled()) return;
        rankingExpression.setRoot(findAndOptimize(rankingExpression.getRoot(), optimizationReport));
        optimizationReport.note("GBDT forest optimization done");
    }

    /**
     * Recursively descends from the given node and replaces every forest-shaped subtree found.
     *
     * <p>The node itself is tried first. If the resulting node is a {@link CompositeNode}, its
     * children are then processed in order.
     *
     * @param node the node to optimize
     * @param report receives metrics for every forest created
     * @return the optimized node, which may be the same instance as {@code node}
     */
    private ExpressionNode findAndOptimize(ExpressionNode node, OptimizationReport report) {
        ExpressionNode newNode = toForestIfPossible(node, report);
        if ( ! (newNode instanceof CompositeNode)) return newNode;

        CompositeNode composite = (CompositeNode) newNode;
        List<ExpressionNode> newChildren = new ArrayList<>();
        for (ExpressionNode child : composite.children())
            newChildren.add(findAndOptimize(child, report));
        return composite.setChildren(newChildren);
    }

    /**
     * Replaces the given node with a forest node if it is a (possibly nested) sum of GBDT trees.
     *
     * @param node the candidate node
     * @param report receives the "Number of forests" and "GBDT trees optimized to forests"
     *        metrics when a forest is created
     * @return a new {@link GBDTForestNode}, or {@code node} unchanged if it is not forest-shaped
     */
    private ExpressionNode toForestIfPossible(ExpressionNode node, OptimizationReport report) {
        List<double[]> trees = new ArrayList<>();
        if ( ! collectTrees(node, trees)) return node;

        GBDTForestNode forestNode = new GBDTForestNode(toForestArray(trees));
        report.incMetric("Number of forests", 1);
        report.incMetric("GBDT trees optimized to forests", trees.size());
        return forestNode;
    }

    /**
     * Appends the value array of each GBDT tree reachable through plus-only operations.
     *
     * @param node the node to inspect
     * @param trees receives each tree's values, in left-to-right order; its contents are
     *        meaningless when this returns false
     * @return true if {@code node} consists only of GBDT trees joined by {@code +}
     */
    private boolean collectTrees(ExpressionNode node, List<double[]> trees) {
        if (node instanceof GBDTNode) {
            trees.add(((GBDTNode) node).values());
            return true;
        }
        if ( ! (node instanceof OperationNode)) return false;

        OperationNode operation = (OperationNode) node;
        for (Operator operator : operation.operators())
            if (operator != Operator.plus) return false;
        for (ExpressionNode child : operation.children())
            if ( ! collectTrees(child, trees)) return false;
        return true;
    }

    /**
     * Flattens the trees into the forest array layout.
     *
     * <p>The layout is, for each tree, its value count followed by its values.
     */
    private static double[] toForestArray(List<double[]> trees) {
        int size = 0;
        for (double[] values : trees)
            size += 1 + values.length; // one slot for the count, then the values
        double[] forest = new double[size];
        int pos = 0;
        for (double[] values : trees) {
            forest[pos++] = values.length;
            System.arraycopy(values, 0, forest, pos, values.length);
            pos += values.length;
        }
        return forest;
    }
}
