package org.apache.seatunnel.engine.server.dag.execution;

import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.apache.seatunnel.api.table.type.MultipleRowType;
import org.apache.seatunnel.api.table.type.SeaTunnelDataType;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.engine.common.config.server.CheckpointConfig;
import org.apache.seatunnel.engine.common.utils.IdGenerator;
import org.apache.seatunnel.engine.core.dag.actions.AbstractAction;
import org.apache.seatunnel.engine.core.dag.actions.Action;
import org.apache.seatunnel.engine.core.dag.actions.ShuffleAction;
import org.apache.seatunnel.engine.core.dag.actions.ShuffleConfig;
import org.apache.seatunnel.engine.core.dag.actions.ShuffleMultipleRowStrategy;
import org.apache.seatunnel.engine.core.dag.actions.SinkAction;
import org.apache.seatunnel.engine.core.dag.actions.SinkConfig;
import org.apache.seatunnel.engine.core.dag.actions.SourceAction;
import org.apache.seatunnel.engine.core.dag.actions.TransformAction;
import org.apache.seatunnel.engine.core.dag.actions.TransformChainAction;
import org.apache.seatunnel.engine.core.dag.actions.UnknownActionException;
import org.apache.seatunnel.engine.core.dag.logical.LogicalDag;
import org.apache.seatunnel.engine.core.dag.logical.LogicalEdge;
import org.apache.seatunnel.engine.core.dag.logical.LogicalVertex;
import org.apache.seatunnel.engine.core.job.JobImmutableInformation;
import org.apache.seatunnel.shade.com.typesafe.config.ConfigParseOptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/seatunnel/engine/server/dag/execution/ExecutionPlanGenerator.class */
public class ExecutionPlanGenerator {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) ExecutionPlanGenerator.class);
    private final LogicalDag logicalPlan;
    private final JobImmutableInformation jobImmutableInformation;
    private final CheckpointConfig checkpointConfig;
    private final IdGenerator idGenerator = new IdGenerator();

    /**
     * Creates a generator bound to one logical plan.
     *
     * @param logicalDag logical DAG to translate; must contain at least one edge
     * @param jobImmutableInformation job information used when building the plan
     * @param checkpointConfig checkpoint settings used when building the plan
     * @throws NullPointerException if any argument is {@code null}
     * @throws IllegalArgumentException if the logical DAG has no edges
     */
    public ExecutionPlanGenerator(@NonNull LogicalDag logicalDag, @NonNull JobImmutableInformation jobImmutableInformation, @NonNull CheckpointConfig checkpointConfig) {
        // Null guards first, in declaration order, each with its own message.
        java.util.Objects.requireNonNull(logicalDag, "logicalPlan is marked non-null but is null");
        java.util.Objects.requireNonNull(jobImmutableInformation, "jobImmutableInformation is marked non-null but is null");
        java.util.Objects.requireNonNull(checkpointConfig, "checkpointConfig is marked non-null but is null");

        // A plan without edges cannot be turned into pipelines.
        Preconditions.checkArgument(!logicalDag.getEdges().isEmpty(), "ExecutionPlan Builder must have LogicalPlan.");

        this.checkpointConfig = checkpointConfig;
        this.jobImmutableInformation = jobImmutableInformation;
        this.logicalPlan = logicalDag;
    }

    /**
     * Translates the logical plan into an execution plan in successive phases, logging the
     * result of every phase at debug level.
     *
     * @return the execution plan built from the generated pipelines and the job information
     */
    public ExecutionPlan generate() {
        log.debug("Generate execution plan using logical plan:");

        // Phase 1: logical edges -> execution edges.
        final Set<ExecutionEdge> executionEdges = generateExecutionEdges(this.logicalPlan.getEdges());
        log.debug("Phase 1: generate execution edge list {}", executionEdges);

        // Phase 2: optionally insert a shuffle vertex.
        final Set<ExecutionEdge> shuffleEdges = generateShuffleEdges(executionEdges);
        log.debug("Phase 2: generate shuffle edge list {}", shuffleEdges);

        // Phase 3: collapse chains of transforms.
        final Set<ExecutionEdge> transformChainEdges = generateTransformChainEdges(shuffleEdges);
        log.debug("Phase 3: generate transform chain edge list {}", transformChainEdges);

        // Phase 4: group vertices and edges into pipelines.
        final List<Pipeline> pipelines = generatePipelines(transformChainEdges);
        log.debug("Phase 4: generate pipeline list {}", pipelines);

        // Phase 5: wrap everything into the final plan.
        final ExecutionPlan plan = new ExecutionPlan(pipelines, this.jobImmutableInformation);
        log.debug("Phase 5: generate execution plan: {}", plan);
        return plan;
    }

    /**
     * Builds a new action of the same concrete kind as {@code action}, carrying a new id and
     * the given parallelism.
     *
     * <p>Name, jar URLs and the kind-specific payload (config, sink, source, transform or
     * transform list) are taken from the original action. A recreated {@link SinkAction}
     * receives a fresh empty list as its third constructor argument.
     *
     * @param action the action to copy
     * @param l id assigned to the new action
     * @param i parallelism set on the new action
     * @return the newly created action
     * @throws UnknownActionException if {@code action} is none of the supported action kinds
     */
    public static Action recreateAction(Action action, Long l, int i) {
        final AbstractAction recreated;
        if (action instanceof ShuffleAction) {
            final ShuffleAction shuffle = (ShuffleAction) action;
            recreated = new ShuffleAction(l.longValue(), action.getName(), shuffle.getConfig());
        } else if (action instanceof SinkAction) {
            final SinkAction sink = (SinkAction) action;
            recreated = new SinkAction(
                    l.longValue(),
                    action.getName(),
                    new ArrayList<>(),
                    sink.getSink(),
                    action.getJarUrls(),
                    (SinkConfig) action.getConfig());
        } else if (action instanceof SourceAction) {
            final SourceAction source = (SourceAction) action;
            recreated = new SourceAction(l.longValue(), action.getName(), source.getSource(), action.getJarUrls());
        } else if (action instanceof TransformAction) {
            final TransformAction transform = (TransformAction) action;
            recreated = new TransformAction(l.longValue(), action.getName(), transform.getTransform(), action.getJarUrls());
        } else if (action instanceof TransformChainAction) {
            final TransformChainAction chain = (TransformChainAction) action;
            recreated = new TransformChainAction(l.longValue(), action.getName(), action.getJarUrls(), chain.getTransforms());
        } else {
            throw new UnknownActionException(action);
        }
        recreated.setParallelism(i);
        return recreated;
    }

    /**
     * Converts every logical edge into an execution edge.
     *
     * <p>Edges are processed in ascending order of input vertex id, then target vertex id, so
     * that ids drawn from the id generator are handed out in a stable order. Each logical
     * vertex is converted at most once; edges sharing a logical vertex share the resulting
     * execution vertex.
     *
     * @param set logical edges of the plan
     * @return execution edges in processing order
     */
    private Set<ExecutionEdge> generateExecutionEdges(Set<LogicalEdge> set) {
        List<LogicalEdge> sortedEdges = new ArrayList<>(set);
        // Compare the numeric id values. Comparing the boxed Long objects with != / == is a
        // reference comparison and breaks the comparator contract for equal ids held in
        // distinct Long instances.
        sortedEdges.sort((first, second) -> {
            int byInput = Long.compare(
                    first.getInputVertexId().longValue(), second.getInputVertexId().longValue());
            if (byInput != 0) {
                return byInput;
            }
            return Long.compare(
                    first.getTargetVertexId().longValue(), second.getTargetVertexId().longValue());
        });

        Map<Long, ExecutionVertex> executionVertexByLogicalId = new HashMap<>();
        Set<ExecutionEdge> executionEdges = new LinkedHashSet<>();
        for (LogicalEdge logicalEdge : sortedEdges) {
            // Input vertex first, then target vertex: this fixes the order in which new ids are drawn.
            ExecutionVertex inputVertex =
                    toExecutionVertex(logicalEdge.getInputVertex(), executionVertexByLogicalId);
            ExecutionVertex targetVertex =
                    toExecutionVertex(logicalEdge.getTargetVertex(), executionVertexByLogicalId);
            executionEdges.add(new ExecutionEdge(inputVertex, targetVertex));
        }
        return executionEdges;
    }

    /**
     * Returns the execution vertex for a logical vertex, creating it (with a new id and a
     * recreated action) the first time that logical vertex id is seen.
     */
    private ExecutionVertex toExecutionVertex(LogicalVertex logicalVertex, Map<Long, ExecutionVertex> executionVertexByLogicalId) {
        return executionVertexByLogicalId.computeIfAbsent(logicalVertex.getVertexId(), ignored -> {
            long newVertexId = this.idGenerator.getNextId();
            Action newAction = recreateAction(
                    logicalVertex.getAction(), Long.valueOf(newVertexId), logicalVertex.getParallelism());
            return new ExecutionVertex(Long.valueOf(newVertexId), newAction, logicalVertex.getParallelism());
        });
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Inserts a shuffle vertex behind the source when the plan has exactly one source vertex
     * and that source produces {@link SqlType#MULTIPLE_ROW}.
     *
     * <p>In every other case the incoming edge set is returned unchanged. When the shuffle is
     * inserted, all direct targets of the source must be sink actions; each of them is set to
     * parallelism 1 and connected to the shuffle vertex instead of the source.
     *
     * @param set execution edges of the plan
     * @return {@code set} itself if no shuffle applies, otherwise a new edge set consisting of
     *     source -> shuffle and shuffle -> sink edges
     * @throws IllegalArgumentException if a direct target of the source is not a sink action
     */
    private Set<ExecutionEdge> generateShuffleEdges(Set<ExecutionEdge> set) {
        Map<Long, List<ExecutionVertex>> targetsByLeftVertexId = new LinkedHashMap<>();
        Set<ExecutionVertex> sourceVertices = new HashSet<>();
        for (ExecutionEdge edge : set) {
            ExecutionVertex left = edge.getLeftVertex();
            if (left.getAction() instanceof SourceAction) {
                sourceVertices.add(left);
            }
            targetsByLeftVertexId
                    .computeIfAbsent(left.getVertexId(), ignored -> new ArrayList<>())
                    .add(edge.getRightVertex());
        }

        // Shuffling is only applied to plans with a single source vertex.
        if (sourceVertices.size() != 1) {
            return set;
        }
        ExecutionVertex sourceVertex = sourceVertices.iterator().next();
        SourceAction sourceAction = (SourceAction) sourceVertex.getAction();
        SeaTunnelDataType producedType = sourceAction.getSource().getProducedType();
        if (!SqlType.MULTIPLE_ROW.equals(producedType.getSqlType())) {
            return set;
        }

        // Every direct target of the source has to be a sink.
        List<ExecutionVertex> sinkVertices = targetsByLeftVertexId.get(sourceVertex.getVertexId());
        Preconditions.checkArgument(
                sinkVertices.stream().allMatch(vertex -> vertex.getAction() instanceof SinkAction));

        // The casts mirror the builder's self-typed return values; the empty-queue TTL is
        // three times the checkpoint interval.
        ShuffleConfig shuffleConfig = ShuffleConfig.builder()
                .shuffleStrategy(
                        ((ShuffleMultipleRowStrategy.ShuffleMultipleRowStrategyBuilder)
                                        ((ShuffleMultipleRowStrategy.ShuffleMultipleRowStrategyBuilder)
                                                        ((ShuffleMultipleRowStrategy.ShuffleMultipleRowStrategyBuilder)
                                                                        ShuffleMultipleRowStrategy.builder()
                                                                                .jobId(this.jobImmutableInformation.getJobId()))
                                                                .inputPartitions(sourceAction.getParallelism()))
                                                .inputRowType((MultipleRowType) MultipleRowType.class.cast(producedType))
                                                .queueEmptyQueueTtl((int) (this.checkpointConfig.getCheckpointInterval() * 3)))
                                .build())
                .build();

        long shuffleVertexId = this.idGenerator.getNextId();
        String shuffleName = String.format("Shuffle [%s]", sourceAction.getName());
        ShuffleAction shuffleAction = new ShuffleAction(shuffleVertexId, shuffleName, shuffleConfig);
        shuffleAction.setParallelism(sourceAction.getParallelism());
        ExecutionVertex shuffleVertex =
                new ExecutionVertex(Long.valueOf(shuffleVertexId), shuffleAction, shuffleAction.getParallelism());

        Set<ExecutionEdge> shuffleEdges = new LinkedHashSet<>();
        shuffleEdges.add(new ExecutionEdge(sourceVertex, shuffleVertex));
        for (ExecutionVertex sinkVertex : sinkVertices) {
            // Sinks behind the shuffle run with parallelism 1, on both the vertex and its action.
            sinkVertex.setParallelism(1);
            sinkVertex.getAction().setParallelism(1);
            shuffleEdges.add(new ExecutionEdge(shuffleVertex, sinkVertex));
        }
        return shuffleEdges;
    }

    /**
     * Replaces chains of transform vertices by single transform-chain vertices and rewrites
     * the edges accordingly.
     *
     * <p>Starting from every source vertex, the graph is walked breadth-first along outgoing
     * edges; for each visited vertex {@link #fillChainedTransformExecutionVertex} decides
     * whether a chain starts there. Afterwards every remaining edge whose endpoint was mapped
     * to a chain vertex is re-created with that chain vertex in its place.
     *
     * <p>Note that {@code set} is passed on to the chaining helpers, which may remove edges
     * from it.
     *
     * @param set execution edges of the plan; may be modified by this call
     * @return the edges after chaining, in iteration order of {@code set}
     */
    private Set<ExecutionEdge> generateTransformChainEdges(Set<ExecutionEdge> set) {
        // right vertex id -> vertices feeding it; left vertex id -> vertices it feeds.
        Map<Long, List<ExecutionVertex>> inputVerticesById = new HashMap<>();
        Map<Long, List<ExecutionVertex>> targetVerticesById = new HashMap<>();
        Set<ExecutionVertex> sourceVertices = new HashSet<>();
        for (ExecutionEdge edge : set) {
            ExecutionVertex left = edge.getLeftVertex();
            ExecutionVertex right = edge.getRightVertex();
            if (left.getAction() instanceof SourceAction) {
                sourceVertices.add(left);
            }
            inputVerticesById.computeIfAbsent(right.getVertexId(), ignored -> new ArrayList<>()).add(left);
            targetVerticesById.computeIfAbsent(left.getVertexId(), ignored -> new ArrayList<>()).add(right);
        }

        // chain vertex id -> chain vertex; original vertex id -> chain vertex id.
        Map<Long, ExecutionVertex> chainVertexById = new HashMap<>();
        Map<Long, Long> chainIdByVertexId = new HashMap<>();
        Map<Long, List<ExecutionVertex>> readOnlyInputs = Collections.unmodifiableMap(inputVerticesById);
        Map<Long, List<ExecutionVertex>> readOnlyTargets = Collections.unmodifiableMap(targetVerticesById);
        for (ExecutionVertex sourceVertex : sourceVertices) {
            List<ExecutionVertex> pending = new ArrayList<>();
            pending.add(sourceVertex);
            // Index loop on purpose: the list grows while it is being walked.
            for (int index = 0; index < pending.size(); index++) {
                ExecutionVertex vertex = pending.get(index);
                fillChainedTransformExecutionVertex(
                        vertex, chainIdByVertexId, chainVertexById, set, readOnlyInputs, readOnlyTargets);
                List<ExecutionVertex> targets = targetVerticesById.get(vertex.getVertexId());
                if (targets != null) {
                    pending.addAll(targets);
                }
            }
        }

        Set<ExecutionEdge> chainedEdges = new LinkedHashSet<>();
        for (ExecutionEdge edge : set) {
            ExecutionVertex left = edge.getLeftVertex();
            ExecutionVertex right = edge.getRightVertex();
            boolean remapped = false;
            if (chainIdByVertexId.containsKey(left.getVertexId())) {
                remapped = true;
                left = chainVertexById.get(chainIdByVertexId.get(left.getVertexId()));
            }
            if (chainIdByVertexId.containsKey(right.getVertexId())) {
                remapped = true;
                right = chainVertexById.get(chainIdByVertexId.get(right.getVertexId()));
            }
            // Untouched edges are reused as they are; remapped ones are re-created.
            chainedEdges.add(remapped ? new ExecutionEdge(left, right) : edge);
        }
        return chainedEdges;
    }

    /**
     * Builds a transform-chain vertex for the chain of transform vertices starting at
     * {@code executionVertex}, if there is one, and records it in the given maps.
     *
     * <p>Nothing happens when the vertex is already mapped to a chain, or when
     * {@link #collectChainedVertices} finds no transform vertex to chain. Otherwise a
     * {@link TransformChainAction} is created from the transforms, jar URLs and names of all
     * collected vertices, and <em>every</em> collected vertex is mapped to the new chain
     * vertex. Mapping only the head would leave the remaining chain members unmapped, so a
     * later visit would build a second chain for them and the edges leaving the chain would
     * attach to the wrong vertex.
     *
     * @param executionVertex vertex at which a chain may start
     * @param map original vertex id -> chain vertex id; updated by this call
     * @param map2 chain vertex id -> chain vertex; updated by this call
     * @param set execution edges; edges inside a chain are removed while collecting
     * @param map3 vertex id -> vertices feeding that vertex
     * @param map4 vertex id -> vertices fed by that vertex
     */
    // Raw List/Set for transforms and jar URLs: their element types are not imported in this file.
    @SuppressWarnings({"rawtypes", "unchecked"})
    private void fillChainedTransformExecutionVertex(ExecutionVertex executionVertex, Map<Long, Long> map, Map<Long, ExecutionVertex> map2, Set<ExecutionEdge> set, Map<Long, List<ExecutionVertex>> map3, Map<Long, List<ExecutionVertex>> map4) {
        if (map.containsKey(executionVertex.getVertexId())) {
            return;
        }
        List<ExecutionVertex> chainedVertices = new ArrayList<>();
        collectChainedVertices(executionVertex, chainedVertices, set, map3, map4);
        if (chainedVertices.isEmpty()) {
            return;
        }

        long chainVertexId = this.idGenerator.getNextId();
        List transforms = new ArrayList(chainedVertices.size());
        List<String> names = new ArrayList<>(chainedVertices.size());
        Set jarUrls = new HashSet();
        for (ExecutionVertex chainedVertex : chainedVertices) {
            // Map all members of the chain (head included) to the new chain vertex.
            map.put(chainedVertex.getVertexId(), Long.valueOf(chainVertexId));
            // collectChainedVertices only collects vertices whose action is a TransformAction.
            TransformAction transformAction = (TransformAction) chainedVertex.getAction();
            transforms.add(transformAction.getTransform());
            jarUrls.addAll(transformAction.getJarUrls());
            names.add(transformAction.getName());
        }

        String chainName = String.format("TransformChain[%s]", String.join(ConfigParseOptions.PATH_TOKEN_SEPARATOR, names));
        TransformChainAction transformChainAction = new TransformChainAction(chainVertexId, chainName, jarUrls, transforms);
        // The chain inherits parallelism from the vertex it starts at.
        transformChainAction.setParallelism(executionVertex.getAction().getParallelism());
        ExecutionVertex chainVertex = new ExecutionVertex(Long.valueOf(chainVertexId), transformChainAction, executionVertex.getParallelism());
        map2.put(Long.valueOf(chainVertexId), chainVertex);
    }

    /**
     * Collects, starting at {@code executionVertex}, the run of transform vertices that can be
     * chained together, appending them to {@code list}.
     *
     * <p>The walk stops at the first vertex whose action is not a {@link TransformAction}. A
     * vertex following the first collected one is only accepted when it has exactly one input
     * vertex; the edge from the previously collected vertex to it is then removed from
     * {@code set}. The walk continues only while the current vertex has exactly one target.
     *
     * @param executionVertex vertex to start from
     * @param list receives the collected vertices in chain order
     * @param set execution edges; edges between chained vertices are removed
     * @param map vertex id -> vertices feeding that vertex
     * @param map2 vertex id -> vertices fed by that vertex
     */
    private void collectChainedVertices(ExecutionVertex executionVertex, List<ExecutionVertex> list, Set<ExecutionEdge> set, Map<Long, List<ExecutionVertex>> map, Map<Long, List<ExecutionVertex>> map2) {
        ExecutionVertex current = executionVertex;
        while (current.getAction() instanceof TransformAction) {
            final Long currentId = current.getVertexId();
            if (!list.isEmpty()) {
                // Only a vertex fed by a single input can join the chain.
                if (map.get(currentId).size() != 1) {
                    return;
                }
                ExecutionVertex previous = list.get(list.size() - 1);
                set.remove(new ExecutionEdge(previous, current));
            }
            list.add(current);

            // Follow the chain only through a single outgoing target.
            List<ExecutionVertex> targets = map2.get(currentId);
            if (targets.size() != 1) {
                return;
            }
            current = targets.get(0);
        }
    }

    /**
     * Groups the vertices and edges into pipelines and prefixes every action name with its
     * pipeline id.
     *
     * <p>Each action is renamed to {@code pipeline-<id> [<old name>]}. After renaming, the
     * number of distinct names must equal the number of actions visited.
     *
     * @param set execution edges of the plan
     * @return the pipelines produced by {@link PipelineGenerator}
     * @throws IllegalArgumentException if two visited actions end up with the same name
     */
    private List<Pipeline> generatePipelines(Set<ExecutionEdge> set) {
        // Collect every vertex touched by an edge, preserving first-seen order.
        Set<ExecutionVertex> vertices = new LinkedHashSet<>();
        for (ExecutionEdge edge : set) {
            vertices.add(edge.getLeftVertex());
            vertices.add(edge.getRightVertex());
        }

        PipelineGenerator pipelineGenerator = new PipelineGenerator(vertices, new ArrayList<>(set));
        List<Pipeline> pipelines = pipelineGenerator.generatePipelines();

        Set<String> distinctNames = new HashSet<>();
        long visitedActions = 0;
        for (Pipeline pipeline : pipelines) {
            Integer pipelineId = pipeline.getId();
            for (ExecutionVertex vertex : pipeline.getVertexes().values()) {
                Action action = vertex.getAction();
                String pipelineScopedName = String.format("pipeline-%s [%s]", pipelineId, action.getName());
                action.setName(pipelineScopedName);
                distinctNames.add(pipelineScopedName);
                visitedActions++;
            }
        }
        // Fewer distinct names than actions means at least two actions share a name.
        Preconditions.checkArgument(distinctNames.size() == visitedActions, "Action name is duplicated");
        return pipelines;
    }
}
