package io.trino.sql.planner.iterative.rule.test;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.MoreCollectors;
import io.trino.Session;
import io.trino.cost.CachingCostProvider;
import io.trino.cost.CachingStatsProvider;
import io.trino.cost.CachingTableStatsProvider;
import io.trino.cost.CostCalculator;
import io.trino.cost.CostProvider;
import io.trino.cost.StatsAndCosts;
import io.trino.cost.StatsCalculator;
import io.trino.cost.StatsProvider;
import io.trino.execution.warnings.WarningCollector;
import io.trino.matching.Capture;
import io.trino.matching.Match;
import io.trino.sql.planner.Plan;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.assertions.PlanAssert;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Memo;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.planprinter.PlanPrinter;
import io.trino.testing.LocalQueryRunner;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;
import org.testng.Assert;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/test/RuleAssert.class */
/**
 * Test helper that applies one optimizer {@link Rule} to a plan and asserts on the outcome.
 *
 * <p>The plan is placed in a fresh {@link Memo}, the rule's pattern is matched against the memo's
 * root node, and the rule is applied only when it is enabled for the session and the pattern
 * matched. The two terminal assertions, {@link #doesNotFire()} and
 * {@link #matches(PlanMatchPattern)}, always clean up the query on the metadata and request an
 * asynchronous abort of the session's transaction before returning or propagating a failure.
 */
public class RuleAssert {
    private final Rule<?> rule;
    private final LocalQueryRunner queryRunner;
    private final StatsCalculator statsCalculator;
    private final Session session;
    private final PlanNode plan;
    private final TypeProvider types;
    private final PlanNodeIdAllocator idAllocator;

    /**
     * Outcome of a single rule application together with the collaborators needed to inspect it.
     *
     * @param lookup resolver that was handed to the rule
     * @param statsProvider stats provider that was handed to the rule
     * @param types types obtained from the symbol allocator used during the application
     * @param result what the rule returned, or an empty result when the rule was not applied
     */
    public record RuleApplication(Lookup lookup, StatsProvider statsProvider, TypeProvider types, Rule.Result result) {
        /**
         * Rejects {@code null} components.
         *
         * @throws NullPointerException if any component is {@code null}
         */
        public RuleApplication {
            Objects.requireNonNull(lookup, "lookup is null");
            Objects.requireNonNull(statsProvider, "statsProvider is null");
            Objects.requireNonNull(types, "types is null");
            Objects.requireNonNull(result, "result is null");
        }

        /** Returns {@code true} when the rule produced a non-empty result. */
        private boolean wasRuleApplied() {
            return !result.isEmpty();
        }

        /**
         * Returns the plan produced by the rule.
         *
         * @throws IllegalStateException if the result carries no transformed plan
         */
        public PlanNode getTransformedPlan() {
            return result.getTransformedPlan()
                    .orElseThrow(() -> new IllegalStateException("Rule did not produce transformed plan"));
        }
    }

    /**
     * Creates an assertion for {@code rule} applied to {@code planNode}.
     *
     * @param rule rule under test
     * @param localQueryRunner source of metadata, function manager, cost calculators and the transaction manager
     * @param statsCalculator calculator used for stats handed to the rule and for plan formatting
     * @param session session the rule runs in
     * @param planNodeIdAllocator id allocator shared with the memo and the rule context
     * @param planNode plan the rule is applied to
     * @param typeProvider types of the symbols used by {@code planNode}
     * @throws NullPointerException if any argument is {@code null}
     */
    public RuleAssert(Rule<?> rule, LocalQueryRunner localQueryRunner, StatsCalculator statsCalculator, Session session, PlanNodeIdAllocator planNodeIdAllocator, PlanNode planNode, TypeProvider typeProvider) {
        this.rule = Objects.requireNonNull(rule, "rule is null");
        this.queryRunner = Objects.requireNonNull(localQueryRunner, "queryRunner is null");
        this.statsCalculator = Objects.requireNonNull(statsCalculator, "statsCalculator is null");
        this.session = Objects.requireNonNull(session, "session is null");
        // Result intentionally discarded: the call is made up front so that a session without a
        // required transaction id is rejected here rather than during cleanup.
        // NOTE(review): presumably throws when the session is not in a transaction - confirm.
        session.getRequiredTransactionId();
        this.idAllocator = Objects.requireNonNull(planNodeIdAllocator, "idAllocator is null");
        this.plan = Objects.requireNonNull(planNode, "plan is null");
        this.types = Objects.requireNonNull(typeProvider, "types is null");
    }

    /**
     * Asserts that the rule produces an empty result for the plan.
     *
     * <p>Fails through {@link Assert#fail(String)} with a printed plan when the rule was applied.
     * Query cleanup and transaction abort are performed in all cases.
     */
    public void doesNotFire() {
        try {
            RuleApplication application = applyRule();
            if (application.wasRuleApplied()) {
                String printedPlan = PlanPrinter.textLogicalPlan(
                        plan,
                        application.types(),
                        queryRunner.getMetadata(),
                        queryRunner.getFunctionManager(),
                        StatsAndCosts.empty(),
                        session,
                        2,
                        false);
                Assert.fail(String.format("Expected %s to not fire for:\n%s", rule, printedPlan));
            }
        } finally {
            cleanupQueryAndAbortTransaction();
        }
    }

    /**
     * Asserts that the rule fires and that the transformed plan matches {@code planMatchPattern}.
     *
     * <p>Checks, in order: the rule was applied; the transformed plan is not the very same
     * instance as the input plan; the output symbols of both plans are equal as sets; the
     * transformed plan satisfies the pattern. Query cleanup and transaction abort are performed in
     * all cases.
     *
     * @param planMatchPattern expected shape of the transformed plan
     */
    public void matches(PlanMatchPattern planMatchPattern) {
        try {
            RuleApplication application = applyRule();
            TypeProvider appliedTypes = application.types();

            if (!application.wasRuleApplied()) {
                Assert.fail(String.format("%s did not fire for:\n%s", rule, formatPlan(plan, appliedTypes)));
            }

            PlanNode transformed = application.getTransformedPlan();
            // Identity comparison on purpose: a rule that fires must return a new plan instance.
            if (transformed == plan) {
                Assert.fail(String.format("%s: rule fired but return the original plan:\n%s\n", rule, formatPlan(plan, appliedTypes)));
            }

            // Output symbols are compared as sets, so ordering differences are tolerated.
            boolean sameOutputs = ImmutableSet.copyOf(plan.getOutputSymbols())
                    .equals(ImmutableSet.copyOf(transformed.getOutputSymbols()));
            if (!sameOutputs) {
                Assert.fail(String.format(
                        "%s: output schema of transformed and original plans are not equivalent\n\texpected: %s\n\tactual:   %s\n",
                        rule,
                        plan.getOutputSymbols(),
                        transformed.getOutputSymbols()));
            }

            PlanAssert.assertPlan(
                    session,
                    queryRunner.getMetadata(),
                    queryRunner.getFunctionManager(),
                    application.statsProvider(),
                    new Plan(transformed, appliedTypes, StatsAndCosts.empty()),
                    application.lookup(),
                    planMatchPattern);
        } finally {
            cleanupQueryAndAbortTransaction();
        }
    }

    /** Cleans up the query on the metadata and requests an asynchronous abort of the session's transaction. */
    private void cleanupQueryAndAbortTransaction() {
        queryRunner.getMetadata().cleanupQuery(session);
        queryRunner.getTransactionManager().asyncAbort(session.getRequiredTransactionId());
    }

    /** Builds a memo and rule context for the configured plan and applies the configured rule to the memo root. */
    private RuleApplication applyRule() {
        SymbolAllocator symbolAllocator = new SymbolAllocator(types.allTypes());
        Memo memo = new Memo(idAllocator, plan);
        // Group references are resolved through the memo, one node per reference.
        Lookup lookup = Lookup.from(groupReference -> Stream.of(memo.resolve(groupReference)));
        PlanNode memoRoot = memo.getNode(memo.getRootGroup());
        Rule.Context context = ruleContext(
                statsCalculator,
                queryRunner.getEstimatedExchangesCostCalculator(),
                symbolAllocator,
                memo,
                lookup,
                session);
        return applyRule(rule, memoRoot, context);
    }

    /**
     * Matches the rule's pattern against {@code planNode} and applies the rule when it is enabled
     * for the context's session and the pattern matched; otherwise yields an empty result.
     */
    private static <T> RuleApplication applyRule(Rule<T> rule, PlanNode planNode, Rule.Context context) {
        Capture<T> planNodeCapture = Capture.newCapture();
        // toOptional() rejects more than one match, so at most a single match is expected here.
        Optional<Match> match = rule.getPattern()
                .capturedAs(planNodeCapture)
                .match(planNode, context.getLookup())
                .collect(MoreCollectors.toOptional());

        Rule.Result result;
        if (!rule.isEnabled(context.getSession()) || match.isEmpty()) {
            result = Rule.Result.empty();
        } else {
            Match found = match.get();
            result = rule.apply(found.capture(planNodeCapture), found.captures(), context);
        }

        // Types are read only after the rule has run, so any symbols the rule allocated are
        // included even if getTypes() hands back a snapshot rather than a live view.
        return new RuleApplication(
                context.getLookup(),
                context.getStatsProvider(),
                context.getSymbolAllocator().getTypes(),
                result);
    }

    /** Renders {@code planNode} as text, annotated with stats and costs computed for this session. */
    private String formatPlan(PlanNode planNode, TypeProvider typeProvider) {
        StatsProvider statsProvider = new CachingStatsProvider(
                statsCalculator,
                session,
                typeProvider,
                new CachingTableStatsProvider(queryRunner.getMetadata(), session));
        CostProvider costProvider = new CachingCostProvider(queryRunner.getCostCalculator(), statsProvider, session, typeProvider);
        return PlanPrinter.textLogicalPlan(
                planNode,
                typeProvider,
                queryRunner.getMetadata(),
                queryRunner.getFunctionManager(),
                StatsAndCosts.create(planNode, statsProvider, costProvider),
                session,
                2,
                false);
    }

    /**
     * Creates the {@link Rule.Context} handed to the rule: memo-backed caching stats and cost
     * providers, a no-op timeout check and a no-op warning collector.
     */
    private Rule.Context ruleContext(StatsCalculator statsCalculator, CostCalculator costCalculator, SymbolAllocator symbolAllocator, Memo memo, Lookup lookup, Session session) {
        StatsProvider statsProvider = new CachingStatsProvider(
                statsCalculator,
                Optional.of(memo),
                lookup,
                session,
                symbolAllocator.getTypes(),
                new CachingTableStatsProvider(queryRunner.getMetadata(), session));
        CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, Optional.of(memo), session, symbolAllocator.getTypes());

        return new Rule.Context() {
            @Override
            public Lookup getLookup() {
                return lookup;
            }

            @Override
            public PlanNodeIdAllocator getIdAllocator() {
                return idAllocator;
            }

            @Override
            public SymbolAllocator getSymbolAllocator() {
                return symbolAllocator;
            }

            @Override
            public Session getSession() {
                return session;
            }

            @Override
            public StatsProvider getStatsProvider() {
                return statsProvider;
            }

            @Override
            public CostProvider getCostProvider() {
                return costProvider;
            }

            @Override
            public void checkTimeoutNotExhausted() {
                // Intentionally empty: no timeout is enforced in this context.
            }

            @Override
            public WarningCollector getWarningCollector() {
                return WarningCollector.NOOP;
            }
        };
    }
}
