package io.trino.execution.scheduler.faulttolerant;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.annotation.NotThreadSafe;
import io.trino.metadata.InternalNode;
import io.trino.metadata.Split;
import io.trino.spi.Node;
import io.trino.spi.connector.ConnectorBucketNodeMap;
import io.trino.spi.connector.ConnectorPartitioningHandle;
import io.trino.sql.planner.MergePartitioningHandle;
import io.trino.sql.planner.NodePartitioningManager;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.SystemPartitioningHandle;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.ToIntFunction;
import java.util.stream.IntStream;

/**
 * Creates {@link FaultTolerantPartitioningScheme}s for {@link PartitioningHandle}s and caches one
 * scheme per handle.
 *
 * <p>Instances are not thread-safe: the cache is an unsynchronized {@link HashMap}.
 */
@NotThreadSafe
public class FaultTolerantPartitioningSchemeFactory {
    private final NodePartitioningManager nodePartitioningManager;
    private final Session session;
    /** Partition count used whenever a caller does not request a specific one. */
    private final int maxPartitionCount;
    /** Per-handle cache; each entry is the scheme built on the first lookup of that handle. */
    private final Map<PartitioningHandle, FaultTolerantPartitioningScheme> cache = new HashMap<>();

    /**
     * @param nodePartitioningManager used to look up connector bucket-to-node maps and
     *     split-to-bucket functions
     * @param session session passed to every {@code nodePartitioningManager} lookup
     * @param maxPartitionCount partition count used when no explicit count is requested
     * @throws NullPointerException if {@code nodePartitioningManager} or {@code session} is null
     */
    public FaultTolerantPartitioningSchemeFactory(NodePartitioningManager nodePartitioningManager, Session session, int maxPartitionCount) {
        this.nodePartitioningManager = Objects.requireNonNull(nodePartitioningManager, "nodePartitioningManager is null");
        this.session = Objects.requireNonNull(session, "session is null");
        this.maxPartitionCount = maxPartitionCount;
    }

    /**
     * Returns the partitioning scheme for {@code handle}.
     *
     * <p>On the first lookup of a handle the scheme is built and cached. On later lookups the
     * cached scheme is returned as-is, or, if {@code partitionCount} is present, re-derived via
     * {@code withPartitionCount}.
     *
     * @param handle the partitioning handle to build a scheme for
     * @param partitionCount explicitly requested partition count, or empty to use the default
     * @return the (possibly cached) scheme
     */
    public FaultTolerantPartitioningScheme get(PartitioningHandle handle, Optional<Integer> partitionCount) {
        FaultTolerantPartitioningScheme cached = cache.get(handle);
        if (cached == null) {
            // computeIfAbsent is deliberately avoided: create() calls get() recursively for merge
            // handles, and that nested put would modify the map inside computeIfAbsent.
            FaultTolerantPartitioningScheme created = create(handle, partitionCount);
            cache.put(handle, created);
            return created;
        }
        // NOTE(review): the cached entry keeps whatever partition count the first lookup used, so a
        // later lookup with an empty partitionCount receives that count rather than
        // maxPartitionCount — confirm this order dependence is intended.
        return partitionCount.isPresent() ? cached.withPartitionCount(partitionCount.get()) : cached;
    }

    /** Builds a new (uncached) scheme for {@code handle}. */
    private FaultTolerantPartitioningScheme create(PartitioningHandle handle, Optional<Integer> partitionCount) {
        ConnectorPartitioningHandle connectorHandle = handle.getConnectorHandle();
        if (connectorHandle instanceof MergePartitioningHandle) {
            // Let the merge handle assemble its scheme, resolving nested handles through this
            // factory so they share the cache and the requested partition count.
            MergePartitioningHandle mergeHandle = (MergePartitioningHandle) connectorHandle;
            return mergeHandle.getFaultTolerantPartitioningScheme(nestedHandle -> get(nestedHandle, partitionCount));
        }

        int resolvedPartitionCount = partitionCount.orElse(maxPartitionCount);

        if (handle.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION)
                || handle.equals(SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION)) {
            return createSystemSchema(resolvedPartitionCount);
        }
        if (!handle.getCatalogHandle().isPresent()) {
            // Any other handle without a catalog gets a single partition and no bucket mapping.
            return new FaultTolerantPartitioningScheme(1, Optional.empty(), Optional.empty(), Optional.empty());
        }

        Optional<ConnectorBucketNodeMap> bucketNodeMap = nodePartitioningManager.getConnectorBucketNodeMap(session, handle);
        if (bucketNodeMap.isEmpty()) {
            // The connector supplies no bucket-to-node map; fall back to system-style partitioning.
            return createSystemSchema(resolvedPartitionCount);
        }
        return createConnectorSpecificSchema(
                resolvedPartitionCount,
                bucketNodeMap.get(),
                nodePartitioningManager.getSplitToBucket(session, handle));
    }

    /** Identity bucket-to-partition mapping over {@code partitionCount} partitions. */
    private static FaultTolerantPartitioningScheme createSystemSchema(int partitionCount) {
        int[] bucketToPartition = IntStream.range(0, partitionCount).toArray();
        return new FaultTolerantPartitioningScheme(partitionCount, Optional.of(bucketToPartition), Optional.empty(), Optional.empty());
    }

    /** Chooses between a node-pinned and a free bucket assignment based on the connector map. */
    private static FaultTolerantPartitioningScheme createConnectorSpecificSchema(
            int partitionCount,
            ConnectorBucketNodeMap bucketNodeMap,
            ToIntFunction<Split> splitToBucket) {
        if (bucketNodeMap.hasFixedMapping()) {
            return createFixedConnectorSpecificSchema(bucketNodeMap.getFixedMapping(), splitToBucket);
        }
        return createArbitraryConnectorSpecificSchema(partitionCount, bucketNodeMap.getBucketCount(), splitToBucket);
    }

    /**
     * Builds a scheme with one partition per distinct node in {@code fixedMapping}.
     *
     * <p>Buckets assigned to the same node are mapped to the same partition, and partitions are
     * numbered in order of each node's first appearance. The resulting partition-to-node list is
     * attached to the scheme.
     *
     * @param fixedMapping bucket index to node; elements are cast to {@link InternalNode}
     * @param splitToBucket function mapping a split to its bucket
     */
    private static FaultTolerantPartitioningScheme createFixedConnectorSpecificSchema(List<Node> fixedMapping, ToIntFunction<Split> splitToBucket) {
        int bucketCount = fixedMapping.size();
        int[] bucketToPartition = new int[bucketCount];
        Map<InternalNode, Integer> nodeToPartition = new HashMap<>();
        List<InternalNode> partitionToNode = new ArrayList<>();
        for (int bucket = 0; bucket < bucketCount; bucket++) {
            InternalNode node = (InternalNode) fixedMapping.get(bucket);
            Integer partition = nodeToPartition.get(node);
            if (partition == null) {
                // First time this node is seen: it gets the next partition id.
                partition = partitionToNode.size();
                nodeToPartition.put(node, partition);
                partitionToNode.add(node);
            }
            bucketToPartition[bucket] = partition;
        }
        return new FaultTolerantPartitioningScheme(
                partitionToNode.size(),
                Optional.of(bucketToPartition),
                Optional.of(splitToBucket),
                Optional.of(ImmutableList.copyOf(partitionToNode)));
    }

    /**
     * Distributes {@code bucketCount} buckets round-robin ({@code bucket % partitionCount}) over
     * {@code partitionCount} partitions, with no node pinning.
     */
    private static FaultTolerantPartitioningScheme createArbitraryConnectorSpecificSchema(int partitionCount, int bucketCount, ToIntFunction<Split> splitToBucket) {
        int[] bucketToPartition = new int[bucketCount];
        for (int bucket = 0; bucket < bucketCount; bucket++) {
            bucketToPartition[bucket] = bucket % partitionCount;
        }
        return new FaultTolerantPartitioningScheme(partitionCount, Optional.of(bucketToPartition), Optional.of(splitToBucket), Optional.empty());
    }
}
