/*
 * Decompiled with CFR 0.152.
 */
package io.openlineage.spark3.agent.lifecycle.plan.column;

import io.openlineage.spark.agent.lifecycle.plan.column.ColumnLevelLineageBuilder;
import io.openlineage.spark.agent.util.ScalaConversionUtils;
import io.openlineage.spark.api.OpenLineageContext;
import io.openlineage.spark3.agent.lifecycle.plan.column.CustomCollectorsUtils;
import io.openlineage.spark3.agent.lifecycle.plan.column.JdbcColumnLineageCollector;
import io.openlineage.spark3.agent.lifecycle.plan.column.visitors.ExpressionDependencyVisitor;
import io.openlineage.spark3.agent.lifecycle.plan.column.visitors.IcebergMergeIntoDependencyVisitor;
import io.openlineage.spark3.agent.lifecycle.plan.column.visitors.UnionDependencyVisitor;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import org.apache.commons.lang3.reflect.MethodUtils;
import org.apache.spark.sql.catalyst.expressions.ExprId;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.Project;
import org.apache.spark.sql.execution.datasources.LogicalRelation;
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.collection.Seq;
import scala.runtime.BoxedUnit;

/**
 * Walks a Spark {@link LogicalPlan} and records column-level dependencies between expression ids
 * in a {@link ColumnLevelLineageBuilder}.
 *
 * <p>Dependencies come from four sources:
 * <ul>
 *   <li>the registered {@link ExpressionDependencyVisitor}s;</li>
 *   <li>custom collectors (via {@link CustomCollectorsUtils});</li>
 *   <li>the output expressions of {@link Project} and {@link Aggregate} nodes;</li>
 *   <li>JDBC relations (via {@link JdbcColumnLineageCollector}).</li>
 * </ul>
 */
public class ExpressionDependencyCollector {
    private static final Logger log = LoggerFactory.getLogger(ExpressionDependencyCollector.class);

    /** Node-level visitors; each is applied to every plan node for which it {@code isDefinedAt}. */
    private static final List<ExpressionDependencyVisitor> expressionDependencyVisitors =
        Arrays.asList(new UnionDependencyVisitor(), new IcebergMergeIntoDependencyVisitor());

    /**
     * Whether {@link AggregateExpression} on the classpath exposes an accessible {@code resultId()}
     * method. Resolved once at class initialisation, because the answer is fixed for a loaded class.
     * When it is absent, a reflective {@code resultIds()} call is used instead.
     */
    private static final boolean AGGREGATE_HAS_RESULT_ID =
        MethodUtils.getAccessibleMethod(AggregateExpression.class, "resultId") != null;

    /**
     * Visits every node of {@code plan} and records the expression dependencies it contributes.
     *
     * @param context OpenLineage context, forwarded to custom collectors
     * @param plan    logical plan whose nodes are visited (via {@code plan.foreach})
     * @param builder receives the discovered dependencies
     */
    static void collect(OpenLineageContext context, LogicalPlan plan, ColumnLevelLineageBuilder builder) {
        plan.foreach(node -> {
            collectFromNode(context, node, builder);
            // plan.foreach takes a Scala Function1 returning Unit.
            return BoxedUnit.UNIT;
        });
    }

    /** Records the dependencies contributed by a single plan node. */
    private static void collectFromNode(OpenLineageContext context, LogicalPlan node, ColumnLevelLineageBuilder builder) {
        for (ExpressionDependencyVisitor visitor : expressionDependencyVisitors) {
            if (visitor.isDefinedAt(node)) {
                visitor.apply(node, builder);
            }
        }
        CustomCollectorsUtils.collectExpressionDependencies(context, node, builder);

        if (node instanceof Project) {
            traverseNamedExpressions(((Project) node).projectList(), builder);
        } else if (node instanceof Aggregate) {
            traverseNamedExpressions(((Aggregate) node).aggregateExpressions(), builder);
        } else if (node instanceof LogicalRelation
                && ((LogicalRelation) node).relation() instanceof JDBCRelation) {
            JdbcColumnLineageCollector.extractExpressionsFromJDBC(node, builder);
        }
    }

    /** Traverses each output expression, using its own {@code exprId} as the ancestor id. */
    private static void traverseNamedExpressions(Seq<NamedExpression> expressions, ColumnLevelLineageBuilder builder) {
        for (NamedExpression expr : ScalaConversionUtils.fromSeq(expressions)) {
            // NamedExpression is a Scala trait (a Java interface), so the Expression cast is required.
            traverseExpression((Expression) expr, expr.exprId(), builder);
        }
    }

    /**
     * Recursively walks {@code expr} and its children and calls
     * {@code builder.addDependency(ancestorId, id)} for the following ids:
     * <ul>
     *   <li>every {@link NamedExpression} id encountered, except {@code ancestorId} itself;</li>
     *   <li>the result id(s) of every {@link AggregateExpression} encountered.</li>
     * </ul>
     *
     * @param expr       expression to walk
     * @param ancestorId id that the discovered ids are recorded against
     * @param builder    receives the dependencies
     */
    public static void traverseExpression(Expression expr, ExprId ancestorId, ColumnLevelLineageBuilder builder) {
        if (expr instanceof NamedExpression) {
            ExprId id = ((NamedExpression) expr).exprId();
            // Skip the self-edge when expr is the ancestor itself.
            if (!id.equals(ancestorId)) {
                builder.addDependency(ancestorId, id);
            }
        }
        if (expr.children() != null) {
            for (Expression child : ScalaConversionUtils.fromSeq(expr.children())) {
                traverseExpression(child, ancestorId, builder);
            }
        }
        if (expr instanceof AggregateExpression) {
            addAggregateResultDependencies((AggregateExpression) expr, ancestorId, builder);
        }
    }

    /**
     * Records the result id(s) of an aggregate against {@code ancestorId}.
     *
     * <p>It uses {@code resultId()} when that method is accessible. Otherwise it falls back to a
     * reflective {@code resultIds()} call. Failures in the fallback are logged rather than
     * propagated, so lineage extraction stays best-effort.
     */
    private static void addAggregateResultDependencies(AggregateExpression aggregate, ExprId ancestorId, ColumnLevelLineageBuilder builder) {
        if (AGGREGATE_HAS_RESULT_ID) {
            builder.addDependency(ancestorId, aggregate.resultId());
            return;
        }
        try {
            Seq<?> resultIds = (Seq<?>) MethodUtils.invokeMethod(aggregate, "resultIds");
            for (Object resultId : ScalaConversionUtils.fromSeq(resultIds)) {
                builder.addDependency(ancestorId, (ExprId) resultId);
            }
        } catch (Exception e) {
            log.warn("Failed extracting resultIds from AggregateExpression", e);
        }
    }
}

