package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.ExpressionSymbolInliner;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SymbolReference;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.class */
/**
 * Optimizer rule that moves a {@link ProjectNode} below the {@link ExchangeNode} it reads from,
 * so the projection is evaluated inside each exchange source instead of after the exchange.
 *
 * <p>The rule does not match projections whose assignments are all plain symbol references.
 *
 * <p>For every exchange source, a new {@link ProjectNode} is placed on top of that source which
 * <ul>
 *   <li>passes through, as identity assignments, the source symbols backing the exchange's
 *       partitioning columns, hash column and ordering symbols, since the rebuilt exchange
 *       keeps the same partitioning and ordering; and</li>
 *   <li>evaluates every other assignment of the original projection, with exchange output
 *       symbols rewritten to that source's input symbols, under a newly allocated symbol.</li>
 * </ul>
 * The rebuilt exchange's outputs are then narrowed to the original projection's output symbols
 * with {@code Util.restrictOutputs}.
 */
public class PushProjectionThroughExchange implements Rule<ProjectNode> {
    /** Captures the exchange that is the direct source of the matched projection. */
    private static final Capture<ExchangeNode> CHILD = Capture.newCapture();

    /** Matches a non symbol-to-symbol projection whose source is an exchange. */
    private static final Pattern<ProjectNode> PATTERN = Patterns.project()
            .matching(projectNode -> !isSymbolToSymbolProjection(projectNode))
            .with(Patterns.source().matching(Patterns.exchange().capturedAs(CHILD)));

    @Override // io.trino.sql.planner.iterative.Rule
    public Pattern<ProjectNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites {@code Project -> Exchange -> sources} into
     * {@code Exchange -> Project (one per source) -> sources}, restricted to the projection's outputs.
     *
     * @param projectNode the matched projection
     * @param captures holds the exchange captured by {@link #PATTERN}
     * @param context supplies the symbol and plan-node-id allocators
     * @return the rewritten plan
     * @throws NullPointerException if an exchange output that must be retained has no
     *     corresponding input symbol for some source
     */
    @Override // io.trino.sql.planner.iterative.Rule
    public Rule.Result apply(ProjectNode projectNode, Captures captures, Rule.Context context) {
        ExchangeNode exchangeNode = captures.get(CHILD);

        // Computed once: the same ordered list drives every source's inputs and the output layout,
        // which keeps column positions aligned between them.
        List<Symbol> requiredOutputs = requiredExchangeOutputs(exchangeNode);
        Set<Symbol> requiredOutputSet = ImmutableSet.copyOf(requiredOutputs);

        ImmutableList.Builder<ProjectNode> newSources = ImmutableList.builder();
        ImmutableList.Builder<List<Symbol>> newInputs = ImmutableList.builder();
        for (int sourceIndex = 0; sourceIndex < exchangeNode.getSources().size(); sourceIndex++) {
            Map<Symbol, Symbol> outputToInput = mapExchangeOutputToInput(exchangeNode, sourceIndex);
            Assignments.Builder projections = Assignments.builder();
            ImmutableList.Builder<Symbol> inputs = ImmutableList.builder();

            // Pass the retained symbols through unchanged.
            for (Symbol requiredOutput : requiredOutputs) {
                Symbol inputSymbol = outputToInput.get(requiredOutput);
                Objects.requireNonNull(inputSymbol,
                        () -> "exchange output " + requiredOutput + " has no input symbol for source");
                projections.putIdentity(inputSymbol);
                inputs.add(inputSymbol);
            }

            // Rewrites references to exchange outputs into references to this source's inputs.
            Map<Symbol, SymbolReference> translation = outputToInput.entrySet().stream()
                    .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference()));

            for (Map.Entry<Symbol, Expression> assignment : projectNode.getAssignments().entrySet()) {
                // Already passed through above; projecting it again would duplicate the column.
                if (requiredOutputSet.contains(assignment.getKey())) {
                    continue;
                }
                Expression translated = ExpressionSymbolInliner.inlineSymbols(translation, assignment.getValue());
                Symbol newSymbol = context.getSymbolAllocator().newSymbol(
                        translated, context.getSymbolAllocator().getTypes().get(assignment.getKey()));
                projections.put(newSymbol, translated);
                inputs.add(newSymbol);
            }

            newSources.add(new ProjectNode(
                    context.getIdAllocator().getNextId(), exchangeNode.getSources().get(sourceIndex), projections.build()));
            newInputs.add(inputs.build());
        }

        // Position i of the output layout corresponds to position i of every source's inputs:
        // retained symbols first, then each projected symbol in assignment order.
        ImmutableList.Builder<Symbol> outputLayout = ImmutableList.builder();
        outputLayout.addAll(requiredOutputs);
        for (Map.Entry<Symbol, Expression> assignment : projectNode.getAssignments().entrySet()) {
            if (!requiredOutputSet.contains(assignment.getKey())) {
                outputLayout.add(assignment.getKey());
            }
        }

        PartitioningScheme oldScheme = exchangeNode.getPartitioningScheme();
        PartitioningScheme newScheme = new PartitioningScheme(
                oldScheme.getPartitioning(),
                outputLayout.build(),
                oldScheme.getHashColumn(),
                oldScheme.isReplicateNullsAndAny(),
                oldScheme.getBucketToPartition(),
                oldScheme.getPartitionCount());

        ExchangeNode newExchange = new ExchangeNode(
                exchangeNode.getId(),
                exchangeNode.getType(),
                exchangeNode.getScope(),
                newScheme,
                // copyOf widens the ImmutableList<ProjectNode> to the element type the constructor expects
                ImmutableList.copyOf(newSources.build()),
                newInputs.build(),
                exchangeNode.getOrderingScheme());

        return Rule.Result.ofPlanNode(Util.restrictOutputs(
                        context.getIdAllocator(), newExchange, ImmutableSet.copyOf(projectNode.getOutputSymbols()))
                .orElse(newExchange));
    }

    /**
     * Returns the exchange output symbols the rewritten exchange must still carry, in this order.
     * The order is: the partitioning columns, then the hash column (if present), then the ordering
     * symbols (if an ordering scheme is present). Ordering symbols that are already partitioning
     * columns are skipped so they are not listed twice.
     */
    private static List<Symbol> requiredExchangeOutputs(ExchangeNode exchangeNode) {
        Set<Symbol> partitioningColumns = exchangeNode.getPartitioningScheme().getPartitioning().getColumns();
        ImmutableList.Builder<Symbol> required = ImmutableList.builder();
        required.addAll(partitioningColumns);
        exchangeNode.getPartitioningScheme().getHashColumn().ifPresent(required::add);
        exchangeNode.getOrderingScheme().ifPresent(orderingScheme -> orderingScheme.getOrderBy().stream()
                .filter(symbol -> !partitioningColumns.contains(symbol))
                .forEach(required::add));
        return required.build();
    }

    /** Returns true when every assignment of the projection is a bare {@link SymbolReference}. */
    private static boolean isSymbolToSymbolProjection(ProjectNode projectNode) {
        return projectNode.getAssignments().getExpressions().stream()
                .allMatch(SymbolReference.class::isInstance);
    }

    /**
     * Maps each exchange output symbol to the input symbol at the same position for one source.
     *
     * @param exchangeNode the exchange whose outputs and inputs are paired positionally
     * @param sourceIndex index into {@code exchangeNode.getInputs()}
     * @return an immutable output-to-input mapping
     * @throws IllegalArgumentException if the exchange lists the same output symbol twice
     *     (raised by {@code ImmutableMap.Builder.buildOrThrow})
     */
    private static Map<Symbol, Symbol> mapExchangeOutputToInput(ExchangeNode exchangeNode, int sourceIndex) {
        List<Symbol> outputs = exchangeNode.getOutputSymbols();
        List<Symbol> sourceInputs = exchangeNode.getInputs().get(sourceIndex);
        ImmutableMap.Builder<Symbol, Symbol> mapping = ImmutableMap.builder();
        for (int position = 0; position < outputs.size(); position++) {
            mapping.put(outputs.get(position), sourceInputs.get(position));
        }
        return mapping.buildOrThrow();
    }
}
