package io.trino.execution.scheduler.faulttolerant;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.SetMultimap;
import com.google.common.collect.Sets;
import com.google.common.primitives.ImmutableIntArray;
import io.trino.execution.scheduler.faulttolerant.SplitAssigner;
import io.trino.metadata.Split;
import io.trino.sql.planner.plan.PlanNodeId;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.assertj.core.api.Assertions;

/* loaded from: input_file:io/trino/execution/scheduler/faulttolerant/SplitAssignerTester.class */
/**
 * Test helper that replays {@link SplitAssigner.AssignmentResult}s into an in-memory model of task
 * partitions, verifies that each result respects the assignment protocol (no partitions after
 * "no more partitions", no updates to unknown or sealed partitions, no splits after "no more
 * splits"), and builds the final list of {@link TaskDescriptor}s once every known partition is
 * sealed and no more partitions are expected.
 *
 * <p>All mutable state is guarded by the instance monitor.
 */
class SplitAssignerTester {
    private boolean noMoreTaskPartitions;
    private final Map<Integer, NodeRequirements> nodeRequirements = new HashMap<>();
    private final Map<Integer, SplitsMapping> splits = new HashMap<>();
    private final SetMultimap<Integer, PlanNodeId> noMoreSplits = HashMultimap.create();
    private final Set<Integer> sealedTaskPartitions = new HashSet<>();
    private Optional<List<TaskDescriptor>> taskDescriptors = Optional.empty();

    /**
     * Returns the task descriptors built once all partitions are sealed and no more partitions are
     * expected, or {@link Optional#empty()} if that point has not been reached yet.
     */
    public synchronized Optional<List<TaskDescriptor>> getTaskDescriptors() {
        return taskDescriptors;
    }

    /** Returns the number of task partitions added so far. */
    public synchronized int getTaskPartitionCount() {
        return nodeRequirements.size();
    }

    /**
     * Returns the node requirements recorded when the given task partition was added.
     *
     * @throws IllegalArgumentException if no such task partition has been added
     */
    public synchronized NodeRequirements getNodeRequirements(int taskPartitionId) {
        NodeRequirements requirements = nodeRequirements.get(taskPartitionId);
        Preconditions.checkArgument(requirements != null, "task partition not found: %s", taskPartitionId);
        return requirements;
    }

    /**
     * Returns the ids of all splits assigned to the given task partition for the given plan node,
     * or an empty set if none were assigned.
     */
    public synchronized Set<Integer> getSplitIds(int taskPartitionId, PlanNodeId planNodeId) {
        return splits.getOrDefault(taskPartitionId, SplitsMapping.EMPTY)
                .getSplitsFlat(planNodeId)
                .stream()
                .map(split -> ((TestingConnectorSplit) split.getConnectorSplit()).getId())
                .collect(ImmutableSet.toImmutableSet());
    }

    /**
     * Returns the ids of the splits assigned to the given task partition for the given plan node,
     * keyed by the source partition they were assigned under.
     */
    public synchronized ListMultimap<Integer, Integer> getSplitIdsBySourcePartition(int taskPartitionId, PlanNodeId planNodeId) {
        SplitsMapping mapping = splits.getOrDefault(taskPartitionId, SplitsMapping.EMPTY);
        ImmutableListMultimap.Builder<Integer, Integer> result = ImmutableListMultimap.builder();
        mapping.getSplits(planNodeId).forEach((sourcePartitionId, split) ->
                result.put(sourcePartitionId, TestingConnectorSplit.getSplitId(split)));
        return result.build();
    }

    /** Returns whether "no more splits" was reported for the given task partition and plan node. */
    public synchronized boolean isNoMoreSplits(int taskPartitionId, PlanNodeId planNodeId) {
        return noMoreSplits.containsEntry(taskPartitionId, planNodeId);
    }

    /** Returns whether the given task partition has been sealed. */
    public synchronized boolean isSealed(int taskPartitionId) {
        return sealedTaskPartitions.contains(taskPartitionId);
    }

    /** Returns whether the assigner has reported that no more task partitions will be added. */
    public synchronized boolean isNoMoreTaskPartitions() {
        return noMoreTaskPartitions;
    }

    /**
     * Asserts how the given splits were distributed for a plan node.
     *
     * <p>Task partition ids are iterated as {@code 0 .. getTaskPartitionCount() - 1}, so this
     * assumes partitions are numbered contiguously from zero.
     *
     * @param replicated if {@code true}, every task partition must contain every split; otherwise
     *     every split must appear in at least one task partition (duplicates are not detected)
     */
    public void checkContainsSplits(PlanNodeId planNodeId, Collection<Split> expectedSplits, boolean replicated) {
        // Collectors.toSet() makes no mutability guarantee; this set is mutated via removeAll below.
        Set<Integer> expectedSplitIds = expectedSplits.stream()
                .map(TestingConnectorSplit::getSplitId)
                .collect(Collectors.toCollection(HashSet::new));
        for (int taskPartitionId = 0; taskPartitionId < getTaskPartitionCount(); taskPartitionId++) {
            Set<Integer> taskPartitionSplitIds = getSplitIds(taskPartitionId, planNodeId);
            if (replicated) {
                Assertions.assertThat(taskPartitionSplitIds).containsAll(expectedSplitIds);
            } else {
                expectedSplitIds.removeAll(taskPartitionSplitIds);
            }
        }
        if (!replicated) {
            Assertions.assertThat(expectedSplitIds).isEmpty();
        }
    }

    /**
     * Asserts how the given splits, keyed by source partition, were distributed for a plan node.
     *
     * <p>Task partition ids are iterated as {@code 0 .. getTaskPartitionCount() - 1}, so this
     * assumes partitions are numbered contiguously from zero.
     *
     * @param replicated if {@code true}, every task partition must contain all split ids under
     *     source partition {@code 0}; otherwise every (source partition, split id) entry must be
     *     covered by the task partitions, counting multiplicity (each matching entry found in a
     *     task partition removes one expected occurrence). Extra entries in task partitions are
     *     ignored.
     */
    public void checkContainsSplits(PlanNodeId planNodeId, ListMultimap<Integer, Split> expectedSplits, boolean replicated) {
        ListMultimap<Integer, Integer> expectedSplitIds;
        if (replicated) {
            // Replicated splits are expected to be collapsed under source partition 0.
            expectedSplitIds = ArrayListMultimap.create();
            expectedSplitIds.putAll(0, buildSplitIds(expectedSplits).values());
        } else {
            expectedSplitIds = ArrayListMultimap.create(buildSplitIds(expectedSplits));
        }
        for (int taskPartitionId = 0; taskPartitionId < getTaskPartitionCount(); taskPartitionId++) {
            ListMultimap<Integer, Integer> taskPartitionSplitIds = getSplitIdsBySourcePartition(taskPartitionId, planNodeId);
            if (replicated) {
                org.assertj.guava.api.Assertions.assertThat(taskPartitionSplitIds).containsAllEntriesOf(expectedSplitIds);
            } else {
                // ListMultimap.remove(key, value) removes a single occurrence of the pair.
                taskPartitionSplitIds.forEach(expectedSplitIds::remove);
            }
        }
        if (!replicated) {
            org.assertj.guava.api.Assertions.assertThat(expectedSplitIds).isEmpty();
        }
    }

    /** Converts splits keyed by source partition into split ids keyed by the same source partition. */
    private ListMultimap<Integer, Integer> buildSplitIds(ListMultimap<Integer, Split> splitsBySourcePartition) {
        ImmutableListMultimap.Builder<Integer, Integer> result = ImmutableListMultimap.builder();
        splitsBySourcePartition.forEach((sourcePartitionId, split) ->
                result.put(sourcePartitionId, TestingConnectorSplit.getSplitId(split)));
        return result.build();
    }

    /**
     * Applies one assignment result to the model, verifying protocol invariants along the way, and
     * then rebuilds the task descriptors if the assignment is complete.
     *
     * @throws com.google.common.base.VerifyException if the result violates the protocol, e.g.
     *     adds a partition after "no more partitions", adds a duplicate partition, updates an
     *     unknown or sealed partition, or adds splits after "no more splits" for that plan node
     */
    public synchronized void update(SplitAssigner.AssignmentResult assignmentResult) {
        for (SplitAssigner.Partition partition : assignmentResult.partitionsAdded()) {
            Verify.verify(!noMoreTaskPartitions, "noMoreTaskPartitions is set");
            Verify.verify(nodeRequirements.put(partition.partitionId(), partition.nodeRequirements()) == null, "task partition already exist: %s", partition.partitionId());
        }
        for (SplitAssigner.PartitionUpdate partitionUpdate : assignmentResult.partitionUpdates()) {
            int partitionId = partitionUpdate.partitionId();
            Verify.verify(nodeRequirements.get(partitionId) != null, "task partition does not exist: %s", partitionId);
            Verify.verify(!sealedTaskPartitions.contains(partitionId), "task partition is sealed: %s", partitionId);
            PlanNodeId planNodeId = partitionUpdate.planNodeId();
            if (!partitionUpdate.splits().isEmpty()) {
                Verify.verify(!noMoreSplits.get(partitionId).contains(planNodeId), "noMoreSplits is set for task partition %s and plan node %s", partitionId, planNodeId);
                SplitsMapping addedSplits = SplitsMapping.builder().addSplits(planNodeId, partitionUpdate.splits()).build();
                splits.merge(partitionId, addedSplits, (existing, added) -> SplitsMapping.builder(existing).addMapping(added).build());
            }
            if (partitionUpdate.noMoreSplits()) {
                noMoreSplits.put(partitionId, planNodeId);
            }
        }
        assignmentResult.sealedPartitions().forEach(sealedTaskPartitions::add);
        if (assignmentResult.noMorePartitions()) {
            noMoreTaskPartitions = true;
        }
        checkFinished();
    }

    /**
     * Builds {@link #taskDescriptors} once no more partitions are expected and every known partition
     * is sealed. Verifies that no unknown partition was sealed and that "no more splits" was
     * reported for every plan node that received splits in each partition.
     */
    private synchronized void checkFinished() {
        if (!noMoreTaskPartitions || !sealedTaskPartitions.containsAll(nodeRequirements.keySet())) {
            return;
        }
        Verify.verify(sealedTaskPartitions.equals(nodeRequirements.keySet()), "unknown sealed partitions: %s", Sets.difference(sealedTaskPartitions, nodeRequirements.keySet()));
        ImmutableList.Builder<TaskDescriptor> descriptors = ImmutableList.builder();
        for (int taskPartitionId : sealedTaskPartitions) {
            SplitsMapping partitionSplits = splits.getOrDefault(taskPartitionId, SplitsMapping.EMPTY);
            Set<PlanNodeId> finishedPlanNodes = noMoreSplits.get(taskPartitionId);
            Verify.verify(finishedPlanNodes.containsAll(partitionSplits.getPlanNodeIds()), "no more split is missing for task partition %s: %s", taskPartitionId, Sets.difference(partitionSplits.getPlanNodeIds(), finishedPlanNodes));
            descriptors.add(new TaskDescriptor(taskPartitionId, partitionSplits, nodeRequirements.get(taskPartitionId)));
        }
        taskDescriptors = Optional.of(descriptors.build());
    }
}
