package org.apache.flink.python.chain;

import java.lang.reflect.Field;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.flink.api.common.operators.SlotSharingGroup;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.python.PythonOptions;
import org.apache.flink.python.util.PythonConfigUtil;
import org.apache.flink.shaded.guava32.com.google.common.collect.Lists;
import org.apache.flink.shaded.guava32.com.google.common.collect.Queues;
import org.apache.flink.shaded.guava32.com.google.common.collect.Sets;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
import org.apache.flink.streaming.api.operators.ChainingStrategy;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
import org.apache.flink.streaming.api.operators.SourceOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.api.operators.python.DataStreamPythonFunctionOperator;
import org.apache.flink.streaming.api.operators.python.embedded.AbstractEmbeddedDataStreamPythonFunctionOperator;
import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonCoProcessOperator;
import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonKeyedCoProcessOperator;
import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonKeyedProcessOperator;
import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonProcessOperator;
import org.apache.flink.streaming.api.operators.python.embedded.EmbeddedPythonWindowOperator;
import org.apache.flink.streaming.api.operators.python.process.AbstractExternalDataStreamPythonFunctionOperator;
import org.apache.flink.streaming.api.operators.python.process.ExternalPythonCoProcessOperator;
import org.apache.flink.streaming.api.operators.python.process.ExternalPythonKeyedCoProcessOperator;
import org.apache.flink.streaming.api.operators.python.process.ExternalPythonKeyedProcessOperator;
import org.apache.flink.streaming.api.operators.python.process.ExternalPythonProcessOperator;
import org.apache.flink.streaming.api.transformations.AbstractBroadcastStateTransformation;
import org.apache.flink.streaming.api.transformations.AbstractMultipleInputTransformation;
import org.apache.flink.streaming.api.transformations.LegacySinkTransformation;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.streaming.api.transformations.PartitionTransformation;
import org.apache.flink.streaming.api.transformations.PhysicalTransformation;
import org.apache.flink.streaming.api.transformations.ReduceTransformation;
import org.apache.flink.streaming.api.transformations.SideOutputTransformation;
import org.apache.flink.streaming.api.transformations.SinkTransformation;
import org.apache.flink.streaming.api.transformations.TimestampsAndWatermarksTransformation;
import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
import org.apache.flink.streaming.api.transformations.UnionTransformation;
import org.apache.flink.streaming.api.transformations.python.PythonBroadcastStateTransformation;
import org.apache.flink.streaming.api.transformations.python.PythonKeyedBroadcastStateTransformation;
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;

/**
 * Merges adjacent Python DataStream operators into single chained transformations.
 *
 * <p>A one-input Python transformation is merged with its nearest upstream Python operator when
 * the two are connected directly or only through forward partitions, and they agree on:
 *
 * <ul>
 *   <li>parallelism and max parallelism;
 *   <li>slot sharing group;
 *   <li>chaining strategy;
 *   <li>execution environment (external process vs. embedded).
 * </ul>
 *
 * <p>The merged transformation wraps one Python operator. Its function info has the upstream
 * function attached as the input of the downstream function chain.
 *
 * <p>The optimization runs only when {@link PythonOptions#PYTHON_OPERATOR_CHAINING_ENABLED} is
 * set. It rewires the existing transformation graph in place via reflection.
 */
public class PythonOperatorChainingOptimizer {

    /** Name of the field of {@link StreamExecutionEnvironment} holding its transformations. */
    private static final String TRANSFORMATIONS_FIELD = "transformations";

    /** Outcome of trying to chain one transformation with its upstream Python operator. */
    public static class ChainInfo {
        /** The transformation that represents the visited node after the chaining attempt. */
        public final Transformation<?> newTransformation;

        /** Transformations replaced by {@link #newTransformation}; empty when nothing was chained. */
        public final Collection<Transformation<?>> oldTransformations;

        private ChainInfo(
                Transformation<?> newTransformation,
                Collection<Transformation<?>> oldTransformations) {
            this.newTransformation = newTransformation;
            this.oldTransformations = oldTransformations;
        }

        /** Describes a transformation that was left unchanged. */
        public static ChainInfo of(Transformation<?> transformation) {
            return new ChainInfo(transformation, Collections.emptyList());
        }

        /** Describes {@code oldTransformations} being replaced by {@code newTransformation}. */
        public static ChainInfo of(
                Transformation<?> newTransformation,
                Collection<Transformation<?>> oldTransformations) {
            return new ChainInfo(newTransformation, oldTransformations);
        }
    }

    /**
     * Chains Python operators among all transformations registered in {@code env} when chaining
     * is enabled, then stores the optimized list back into the environment.
     *
     * @param env the environment whose transformation list is rewritten in place
     * @throws Exception if the environment's transformation list cannot be accessed reflectively
     */
    public static void apply(StreamExecutionEnvironment env) throws Exception {
        if (!isChainingEnabled(env)) {
            return;
        }
        final Field field = transformationsField();
        field.set(env, optimize(getTransformations(field, env)));
    }

    /**
     * Chains the Python operators upstream of {@code transformation} when chaining is enabled.
     * The optimized list is stored back into the environment.
     *
     * <p>NOTE(review): the stored list only contains transformations reachable upstream from
     * {@code transformation}; other registered transformations are not carried over. Confirm
     * this is intended.
     *
     * @param env the environment whose transformation list is rewritten in place
     * @param transformation the transformation whose upstream graph is optimized
     * @return the transformation that replaces {@code transformation} after chaining (possibly
     *     itself)
     * @throws Exception if the environment's transformation list cannot be accessed reflectively
     */
    public static Transformation<?> apply(
            StreamExecutionEnvironment env, Transformation<?> transformation) throws Exception {
        if (!isChainingEnabled(env)) {
            return transformation;
        }
        final Field field = transformationsField();
        final Tuple2<List<Transformation<?>>, Transformation<?>> optimized =
                optimize(getTransformations(field, env), transformation);
        field.set(env, optimized.f0);
        return optimized.f1;
    }

    /**
     * Chains Python operators across every transformation reachable from {@code transformations}.
     *
     * @param transformations the transformations to optimize; their upstream graph is rewired in
     *     place
     * @return the optimized transformations, in visiting order
     */
    public static List<Transformation<?>> optimize(List<Transformation<?>> transformations) {
        return chainReachable(transformations, transformations, null).f0;
    }

    /**
     * Chains Python operators upstream of {@code targetTransformation}.
     *
     * @param transformations all known transformations; used to determine each node's consumers
     * @param targetTransformation where the upstream walk starts
     * @return the optimized transformations reachable from the target, and the transformation
     *     that now stands in for the target
     */
    public static Tuple2<List<Transformation<?>>, Transformation<?>> optimize(
            List<Transformation<?>> transformations, Transformation<?> targetTransformation) {
        return chainReachable(
                transformations,
                Collections.<Transformation<?>>singletonList(targetTransformation),
                targetTransformation);
    }

    /** Returns whether {@link PythonOptions#PYTHON_OPERATOR_CHAINING_ENABLED} is set. */
    private static boolean isChainingEnabled(StreamExecutionEnvironment env) {
        return env.getConfiguration().get(PythonOptions.PYTHON_OPERATOR_CHAINING_ENABLED);
    }

    /** Returns the accessible reflective handle to the environment's transformation list. */
    private static Field transformationsField() throws NoSuchFieldException {
        final Field field = StreamExecutionEnvironment.class.getDeclaredField(TRANSFORMATIONS_FIELD);
        field.setAccessible(true);
        return field;
    }

    /** Reads the environment's transformation list through {@code field}. */
    @SuppressWarnings("unchecked")
    private static List<Transformation<?>> getTransformations(
            Field field, StreamExecutionEnvironment env) throws IllegalAccessException {
        return (List<Transformation<?>>) field.get(env);
    }

    /**
     * Breadth-first walk upstream from {@code startPoints}, chaining each visited node with its
     * upstream Python operator where possible.
     *
     * @param transformations all known transformations; used to build the consumer map
     * @param startPoints the nodes the walk starts from
     * @param target a node whose replacement should be tracked, or {@code null}
     * @return the optimized transformations in visiting order, and the replacement of
     *     {@code target} ({@code null} when {@code target} is {@code null})
     */
    private static Tuple2<List<Transformation<?>>, Transformation<?>> chainReachable(
            List<Transformation<?>> transformations,
            Collection<Transformation<?>> startPoints,
            Transformation<?> target) {
        final Map<Transformation<?>, Set<Transformation<?>>> outputMap =
                buildOutputMap(transformations);
        final LinkedHashSet<Transformation<?>> optimized = new LinkedHashSet<>();
        final Set<Transformation<?>> visited = Sets.newIdentityHashSet();
        final ArrayDeque<Transformation<?>> queue = new ArrayDeque<>(startPoints);
        Transformation<?> currentTarget = target;

        while (!queue.isEmpty()) {
            final Transformation<?> current = queue.poll();
            if (!visited.add(current)) {
                continue;
            }

            final ChainInfo chainInfo = chainWithInputIfPossible(current, outputMap);
            optimized.add(chainInfo.newTransformation);
            optimized.removeAll(chainInfo.oldTransformations);
            visited.addAll(chainInfo.oldTransformations);

            // Re-enqueue the result so a freshly chained node can be chained again with its own
            // input; an unchanged node is already in `visited` and is skipped.
            queue.add(chainInfo.newTransformation);
            queue.addAll(chainInfo.newTransformation.getInputs());

            if (current == currentTarget) {
                currentTarget = chainInfo.newTransformation;
            }
        }
        final List<Transformation<?>> result = new ArrayList<>(optimized);
        return Tuple2.of(result, currentTarget);
    }

    /**
     * Maps every transformation reachable upstream from {@code transformations} to the set of
     * transformations that list it as a direct input.
     */
    private static Map<Transformation<?>, Set<Transformation<?>>> buildOutputMap(
            List<Transformation<?>> transformations) {
        final Map<Transformation<?>, Set<Transformation<?>>> outputMap = new HashMap<>();
        final ArrayDeque<Transformation<?>> queue = new ArrayDeque<>(transformations);
        final Set<Transformation<?>> visited = Sets.newIdentityHashSet();

        while (!queue.isEmpty()) {
            final Transformation<?> current = queue.poll();
            if (!visited.add(current)) {
                continue;
            }
            final List<Transformation<?>> inputs = current.getInputs();
            for (Transformation<?> input : inputs) {
                outputMap.computeIfAbsent(input, ignored -> Sets.newHashSet()).add(current);
            }
            queue.addAll(inputs);
        }
        return outputMap;
    }

    /**
     * Tries to merge a one-input Python transformation with its nearest upstream Python operator.
     * Forward partitions in between may be skipped.
     *
     * <p>On success, every consumer of {@code transformation} is rewired to the chained
     * transformation, and {@code outputMap} gains an entry for the chained transformation.
     *
     * @return the chained transformation plus the two merged ones, or {@code transformation}
     *     itself when chaining is not possible
     */
    private static ChainInfo chainWithInputIfPossible(
            Transformation<?> transformation,
            Map<Transformation<?>, Set<Transformation<?>>> outputMap) {
        if (!(transformation instanceof OneInputTransformation)
                || !PythonConfigUtil.isPythonDataStreamOperator(transformation)) {
            return ChainInfo.of(transformation);
        }

        // Walk upstream through forward partitions until the nearest Python operator.
        Transformation<?> input = transformation.getInputs().get(0);
        while (!PythonConfigUtil.isPythonDataStreamOperator(input)) {
            if (!isForwardPartitionWithSingleOutput(input, outputMap)) {
                return ChainInfo.of(transformation);
            }
            input = input.getInputs().get(0);
        }

        if (!isChainable(input, transformation, outputMap)) {
            return ChainInfo.of(transformation);
        }

        final Transformation<?> chained = createChainedTransformation(input, transformation);
        final Set<Transformation<?>> outputs = outputMap.get(transformation);
        if (outputs != null) {
            for (Transformation<?> output : outputs) {
                replaceInput(output, transformation, chained);
            }
            outputMap.put(chained, outputs);
        }
        return ChainInfo.of(chained, Arrays.asList(input, transformation));
    }

    /**
     * Returns whether {@code transformation} is a forward partition read by exactly one consumer.
     *
     * <p>A forward partition with several consumers must not be bypassed. Otherwise each consumer
     * would see the upstream operator as single-output and absorb it into its own chain,
     * duplicating that operator (and its uid).
     */
    private static boolean isForwardPartitionWithSingleOutput(
            Transformation<?> transformation,
            Map<Transformation<?>, Set<Transformation<?>>> outputMap) {
        if (!(transformation instanceof PartitionTransformation)
                || !(((PartitionTransformation<?>) transformation).getPartitioner()
                        instanceof ForwardPartitioner)) {
            return false;
        }
        final Set<Transformation<?>> outputs = outputMap.get(transformation);
        return outputs != null && outputs.size() == 1;
    }

    /**
     * Returns the Python operator wrapped by {@code transformation}.
     *
     * <p>Assumes the operator factory is a {@link SimpleOperatorFactory} wrapping a {@link
     * DataStreamPythonFunctionOperator}. Presumably this holds for every transformation accepted
     * by {@code PythonConfigUtil.isPythonDataStreamOperator}, except the Python broadcast
     * transformations; verify against {@code PythonConfigUtil}.
     */
    private static DataStreamPythonFunctionOperator<?> getPythonOperator(
            Transformation<?> transformation) {
        return (DataStreamPythonFunctionOperator<?>)
                ((SimpleOperatorFactory<?>) PythonConfigUtil.getOperatorFactory(transformation))
                        .getOperator();
    }

    /**
     * Builds a transformation that runs {@code upTransform}'s Python function followed by
     * {@code downTransform}'s inside a single operator.
     *
     * <p>The result reads from {@code upTransform}'s inputs and produces {@code downTransform}'s
     * output type.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static Transformation<?> createChainedTransformation(
            Transformation<?> upTransform, Transformation<?> downTransform) {
        final DataStreamPythonFunctionOperator<?> upOperator = getPythonOperator(upTransform);
        final DataStreamPythonFunctionOperator<?> downOperator = getPythonOperator(downTransform);
        assert arePythonOperatorsInSameExecutionEnvironment(upOperator, downOperator);

        final DataStreamPythonFunctionInfo upFunctionInfo =
                upOperator.getPythonFunctionInfo().copy();
        final DataStreamPythonFunctionInfo downFunctionInfo =
                downOperator.getPythonFunctionInfo().copy();

        // Attach the upstream function as the input of the deepest function info (the one with no
        // inputs) in the downstream chain.
        DataStreamPythonFunctionInfo deepest = downFunctionInfo;
        while (deepest.getInputs().length != 0) {
            deepest = (DataStreamPythonFunctionInfo) deepest.getInputs()[0];
        }
        deepest.setInputs(new DataStreamPythonFunctionInfo[] {upFunctionInfo});

        final DataStreamPythonFunctionOperator<?> chainedOperator =
                upOperator.copy(downFunctionInfo, downOperator.getProducedType());
        chainedOperator.addSideOutputTags(downOperator.getSideOutputTags());

        final String name = upTransform.getName() + ", " + downTransform.getName();
        final PhysicalTransformation<?> chained;
        if (upOperator instanceof OneInputStreamOperator) {
            final OneInputTransformation upOneInput = (OneInputTransformation) upTransform;
            final OneInputTransformation oneInput =
                    new OneInputTransformation(
                            upTransform.getInputs().get(0),
                            name,
                            (OneInputStreamOperator) chainedOperator,
                            downTransform.getOutputType(),
                            upTransform.getParallelism(),
                            false);
            oneInput.setStateKeySelector(upOneInput.getStateKeySelector());
            oneInput.setStateKeyType(upOneInput.getStateKeyType());
            chained = oneInput;
        } else {
            final TwoInputTransformation upTwoInput = (TwoInputTransformation) upTransform;
            final TwoInputTransformation twoInput =
                    new TwoInputTransformation(
                            upTransform.getInputs().get(0),
                            upTransform.getInputs().get(1),
                            name,
                            (TwoInputStreamOperator) chainedOperator,
                            downTransform.getOutputType(),
                            upTransform.getParallelism(),
                            false);
            twoInput.setStateKeySelectors(
                    upTwoInput.getStateKeySelector1(), upTwoInput.getStateKeySelector2());
            twoInput.setStateKeyType(upTwoInput.getStateKeyType());
            chained = twoInput;
        }

        copyChainedProperties(upTransform, downTransform, chained);
        return chained;
    }

    /**
     * Copies uid, memory, buffering, parallelism, placement, resource and description settings of
     * the two merged transformations onto {@code chained}.
     *
     * <p>Most settings are taken from {@code upTransform}. Managed-memory use cases and
     * resources are combined from both, and the smaller buffer timeout wins.
     */
    private static void copyChainedProperties(
            Transformation<?> upTransform,
            Transformation<?> downTransform,
            PhysicalTransformation<?> chained) {
        chained.setUid(upTransform.getUid());
        if (upTransform.getUserProvidedNodeHash() != null) {
            chained.setUidHash(upTransform.getUserProvidedNodeHash());
        }

        for (ManagedMemoryUseCase useCase : upTransform.getManagedMemorySlotScopeUseCases()) {
            chained.declareManagedMemoryUseCaseAtSlotScope(useCase);
        }
        for (ManagedMemoryUseCase useCase : downTransform.getManagedMemorySlotScopeUseCases()) {
            chained.declareManagedMemoryUseCaseAtSlotScope(useCase);
        }

        for (Map.Entry<ManagedMemoryUseCase, Integer> weight :
                upTransform.getManagedMemoryOperatorScopeUseCaseWeights().entrySet()) {
            chained.declareManagedMemoryUseCaseAtOperatorScope(weight.getKey(), weight.getValue());
        }
        // Operator-scope weights of both operators are summed per use case.
        for (Map.Entry<ManagedMemoryUseCase, Integer> weight :
                downTransform.getManagedMemoryOperatorScopeUseCaseWeights().entrySet()) {
            final int existing =
                    chained.getManagedMemoryOperatorScopeUseCaseWeights()
                            .getOrDefault(weight.getKey(), 0);
            chained.declareManagedMemoryUseCaseAtOperatorScope(
                    weight.getKey(), weight.getValue() + existing);
        }

        chained.setBufferTimeout(
                Math.min(upTransform.getBufferTimeout(), downTransform.getBufferTimeout()));
        if (upTransform.getMaxParallelism() > 0) {
            chained.setMaxParallelism(upTransform.getMaxParallelism());
        }
        chained.setChainingStrategy(
                PythonConfigUtil.getOperatorFactory(upTransform).getChainingStrategy());
        chained.setCoLocationGroupKey(upTransform.getCoLocationGroupKey());
        chained.setResources(
                upTransform.getMinResources().merge(downTransform.getMinResources()),
                upTransform.getPreferredResources().merge(downTransform.getPreferredResources()));
        if (upTransform.getSlotSharingGroup().isPresent()) {
            final SlotSharingGroup group = upTransform.getSlotSharingGroup().get();
            chained.setSlotSharingGroup(group);
        }

        final String description =
                mergeDescriptions(upTransform.getDescription(), downTransform.getDescription());
        if (description != null) {
            chained.setDescription(description);
        }
    }

    /** Joins two optional descriptions with ", ", returning {@code null} if both are absent. */
    private static String mergeDescriptions(String upDescription, String downDescription) {
        if (upDescription != null && downDescription != null) {
            return upDescription + ", " + downDescription;
        }
        return upDescription != null ? upDescription : downDescription;
    }

    /**
     * Returns whether {@code upTransform} may be merged into {@code downTransform}.
     *
     * <p>Requires equal parallelism, max parallelism and slot sharing group, compatible
     * operators, and {@code upTransform} having exactly one known consumer.
     */
    private static boolean isChainable(
            Transformation<?> upTransform,
            Transformation<?> downTransform,
            Map<Transformation<?>, Set<Transformation<?>>> outputMap) {
        final Set<Transformation<?>> upOutputs = outputMap.get(upTransform);
        return upTransform.getParallelism() == downTransform.getParallelism()
                && upTransform.getMaxParallelism() == downTransform.getMaxParallelism()
                && upTransform.getSlotSharingGroup().equals(downTransform.getSlotSharingGroup())
                && areOperatorsChainable(upTransform, downTransform)
                && upOutputs != null
                && upOutputs.size() == 1;
    }

    /**
     * Returns whether the Python operators of the two transformations can run as one chained
     * operator.
     *
     * <p>The downstream operator must be a plain process operator. The upstream operator must be
     * one of the listed operator kinds from the same execution environment (external process or
     * embedded).
     */
    private static boolean areOperatorsChainable(
            Transformation<?> upTransform, Transformation<?> downTransform) {
        // Checked first: the chaining-strategy check below dereferences the operator factory.
        if (upTransform instanceof PythonBroadcastStateTransformation
                || upTransform instanceof PythonKeyedBroadcastStateTransformation) {
            return false;
        }
        if (!areOperatorsChainableByChainingStrategy(upTransform, downTransform)) {
            return false;
        }

        final DataStreamPythonFunctionOperator<?> upOperator = getPythonOperator(upTransform);
        final DataStreamPythonFunctionOperator<?> downOperator = getPythonOperator(downTransform);
        if (!arePythonOperatorsInSameExecutionEnvironment(upOperator, downOperator)) {
            return false;
        }

        final boolean externalChainable =
                downOperator instanceof ExternalPythonProcessOperator
                        && (upOperator instanceof ExternalPythonKeyedProcessOperator
                                || upOperator instanceof ExternalPythonKeyedCoProcessOperator
                                || upOperator instanceof ExternalPythonProcessOperator
                                || upOperator instanceof ExternalPythonCoProcessOperator);
        final boolean embeddedChainable =
                downOperator instanceof EmbeddedPythonProcessOperator
                        && (upOperator instanceof EmbeddedPythonKeyedProcessOperator
                                || upOperator instanceof EmbeddedPythonKeyedCoProcessOperator
                                || upOperator instanceof EmbeddedPythonProcessOperator
                                || upOperator instanceof EmbeddedPythonCoProcessOperator
                                || upOperator instanceof EmbeddedPythonWindowOperator);
        return externalChainable || embeddedChainable;
    }

    /** Returns whether both operators run externally, or both run embedded. */
    private static boolean arePythonOperatorsInSameExecutionEnvironment(
            DataStreamPythonFunctionOperator<?> upOperator,
            DataStreamPythonFunctionOperator<?> downOperator) {
        final boolean bothExternal =
                upOperator instanceof AbstractExternalDataStreamPythonFunctionOperator
                        && downOperator instanceof AbstractExternalDataStreamPythonFunctionOperator;
        final boolean bothEmbedded =
                upOperator instanceof AbstractEmbeddedDataStreamPythonFunctionOperator
                        && downOperator instanceof AbstractEmbeddedDataStreamPythonFunctionOperator;
        return bothExternal || bothEmbedded;
    }

    /**
     * Applies the chaining-strategy rules of both operator factories.
     *
     * <p>The upstream side allows chaining unless it is {@code NEVER}. The downstream side must
     * be {@code ALWAYS}, or {@code HEAD_WITH_SOURCES} when the upstream factory is a {@link
     * SourceOperatorFactory}.
     *
     * @throws RuntimeException for a strategy not handled here
     */
    private static boolean areOperatorsChainableByChainingStrategy(
            Transformation<?> upTransform, Transformation<?> downTransform) {
        final StreamOperatorFactory<?> upFactory = PythonConfigUtil.getOperatorFactory(upTransform);
        final StreamOperatorFactory<?> downFactory =
                PythonConfigUtil.getOperatorFactory(downTransform);

        boolean chainable;
        switch (upFactory.getChainingStrategy()) {
            case NEVER:
                chainable = false;
                break;
            case ALWAYS:
            case HEAD:
            case HEAD_WITH_SOURCES:
                chainable = true;
                break;
            default:
                throw new RuntimeException(
                        "Unknown chaining strategy: " + upFactory.getChainingStrategy());
        }

        switch (downFactory.getChainingStrategy()) {
            case NEVER:
            case HEAD:
                chainable = false;
                break;
            case ALWAYS:
                // keep the upstream decision
                break;
            case HEAD_WITH_SOURCES:
                chainable &= upFactory instanceof SourceOperatorFactory;
                break;
            default:
                throw new RuntimeException(
                        "Unknown chaining strategy: " + downFactory.getChainingStrategy());
        }
        return chainable;
    }

    /**
     * Rewires {@code transformation} so that every input slot referencing {@code oldInput}
     * references {@code newInput} instead.
     *
     * <p>This writes private fields reflectively. Field lookup walks the class hierarchy, so
     * fields declared in abstract base classes are found as well.
     *
     * @throws RuntimeException if the transformation type is unsupported, or the fields cannot
     *     be written
     */
    @SuppressWarnings("rawtypes")
    private static void replaceInput(
            Transformation<?> transformation,
            Transformation<?> oldInput,
            Transformation<?> newInput) {
        try {
            if (transformation instanceof OneInputTransformation
                    || transformation instanceof SideOutputTransformation
                    || transformation instanceof ReduceTransformation
                    || transformation instanceof LegacySinkTransformation
                    || transformation instanceof TimestampsAndWatermarksTransformation
                    || transformation instanceof PartitionTransformation) {
                setField(transformation, "input", newInput);
            } else if (transformation instanceof SinkTransformation) {
                setField(transformation, "input", newInput);
                // The sink also holds its input DataStream; keep that stream's transformation in
                // sync.
                final Field streamTransformation = findField(DataStream.class, "transformation");
                streamTransformation.setAccessible(true);
                streamTransformation.set(
                        ((SinkTransformation) transformation).getInputStream(), newInput);
            } else if (transformation instanceof TwoInputTransformation) {
                final TwoInputTransformation<?, ?, ?> twoInput =
                        (TwoInputTransformation<?, ?, ?>) transformation;
                replaceMatchingSlots(
                        transformation,
                        oldInput,
                        newInput,
                        "input1",
                        twoInput.getInput1(),
                        "input2",
                        twoInput.getInput2());
            } else if (transformation instanceof UnionTransformation
                    || transformation instanceof AbstractMultipleInputTransformation) {
                // Replace in place: the position of each input is significant for
                // multiple-input operators.
                final List<Transformation<?>> newInputs = new ArrayList<>();
                for (Transformation<?> input : transformation.getInputs()) {
                    newInputs.add(input == oldInput ? newInput : input);
                }
                setField(transformation, "inputs", newInputs);
            } else if (transformation instanceof AbstractBroadcastStateTransformation) {
                final AbstractBroadcastStateTransformation<?, ?, ?> broadcast =
                        (AbstractBroadcastStateTransformation<?, ?, ?>) transformation;
                replaceMatchingSlots(
                        transformation,
                        oldInput,
                        newInput,
                        "regularInput",
                        broadcast.getRegularInput(),
                        "broadcastInput",
                        broadcast.getBroadcastInput());
            } else {
                throw new RuntimeException("Unsupported transformation: " + transformation);
            }
        } catch (IllegalAccessException | NoSuchFieldException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Replaces whichever of the two named input slots currently holds {@code oldInput}. Both
     * slots are replaced if both hold it.
     */
    private static void replaceMatchingSlots(
            Transformation<?> transformation,
            Transformation<?> oldInput,
            Transformation<?> newInput,
            String firstField,
            Transformation<?> firstInput,
            String secondField,
            Transformation<?> secondInput)
            throws NoSuchFieldException, IllegalAccessException {
        final boolean firstMatches = firstInput == oldInput;
        final boolean secondMatches = secondInput == oldInput;
        if (!firstMatches && !secondMatches) {
            throw new RuntimeException(
                    oldInput + " is not an input of transformation: " + transformation);
        }
        if (firstMatches) {
            setField(transformation, firstField, newInput);
        }
        if (secondMatches) {
            setField(transformation, secondField, newInput);
        }
    }

    /** Writes {@code value} into the field {@code fieldName} of {@code target}. */
    private static void setField(Object target, String fieldName, Object value)
            throws NoSuchFieldException, IllegalAccessException {
        final Field field = findField(target.getClass(), fieldName);
        field.setAccessible(true);
        field.set(target, value);
    }

    /**
     * Finds a declared field by name, starting at {@code type} and walking up its superclasses.
     *
     * @throws NoSuchFieldException if no class in the hierarchy declares the field
     */
    private static Field findField(Class<?> type, String fieldName) throws NoSuchFieldException {
        for (Class<?> current = type; current != null; current = current.getSuperclass()) {
            try {
                return current.getDeclaredField(fieldName);
            } catch (NoSuchFieldException ignored) {
                // not declared here; continue with the superclass
            }
        }
        throw new NoSuchFieldException(fieldName + " in " + type.getName());
    }
}
