package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ReduceJoin;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * Optimizer that walks a ranking expression tree and replaces a tensor {@link Reduce} whose single
 * argument is a tensor {@link Join} with one combined {@link ReduceJoin} node.
 */
public class TensorOptimizer extends Optimizer {

    /**
     * Rewrites the root of the given expression with its optimized form, if this optimizer is
     * enabled. Does nothing when {@link #isEnabled()} returns false.
     *
     * @param rankingExpression the expression whose root is replaced by the optimized tree
     * @param contextIndex the context index (not consulted by this optimizer)
     * @param optimizationReport receives a "Replaced reduce->join" metric increment per rewrite and
     *     a completion note
     */
    @Override
    public void optimize(RankingExpression rankingExpression, ContextIndex contextIndex, OptimizationReport optimizationReport) {
        if (!isEnabled()) {
            return;
        }
        rankingExpression.setRoot(optimizeNode(rankingExpression.getRoot(), optimizationReport));
        optimizationReport.note("Tensor expression optimization done");
    }

    /**
     * Optimizes a node, then (if the result is composite) each of its children recursively.
     * The node itself is rewritten before its children are visited.
     */
    private ExpressionNode optimizeNode(ExpressionNode node, OptimizationReport report) {
        ExpressionNode optimized = optimizeReduceJoin(node, report);
        if (!(optimized instanceof CompositeNode)) {
            return optimized;
        }
        return optimizeChildren((CompositeNode) optimized, report);
    }

    /** Returns the result of setting the node's children to their optimized versions, in order. */
    private ExpressionNode optimizeChildren(CompositeNode node, OptimizationReport report) {
        List<ExpressionNode> children = node.children();
        List<ExpressionNode> optimizedChildren = new ArrayList<>(children.size());
        for (ExpressionNode child : children) {
            optimizedChildren.add(optimizeNode(child, report));
        }
        return node.setChildren(optimizedChildren);
    }

    /**
     * Returns a {@link ReduceJoin} node if the given node is a tensor reduce whose only child is a
     * tensor join; otherwise returns the node unchanged.
     */
    private ExpressionNode optimizeReduceJoin(ExpressionNode node, OptimizationReport report) {
        if (!(node instanceof TensorFunctionNode)) {
            return node;
        }
        TensorFunctionNode reduceNode = (TensorFunctionNode) node;

        // Hold the function untyped and test it before casting; assigning it directly to a
        // Reduce/Join local (as the decompiled code did) does not compile and defeats the check.
        Object reduce = reduceNode.function();
        if (!(reduce instanceof Reduce)) {
            return node;
        }

        List<ExpressionNode> children = reduceNode.children();
        if (children.size() != 1) {
            return node;
        }
        ExpressionNode child = children.get(0);
        if (!(child instanceof TensorFunctionNode)) {
            return node;
        }

        Object join = ((TensorFunctionNode) child).function();
        if (!(join instanceof Join)) {
            return node;
        }

        report.incMetric("Replaced reduce->join", 1);
        return new TensorFunctionNode(new ReduceJoin((Reduce) reduce, (Join) join));
    }
}
