package org.datavec.local.transforms;

import com.codepoetics.protonpack.StreamUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.BiPredicate;
import java.util.stream.Collectors;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.DataAction;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.join.Join;
import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.rank.CalculateSortedRank;
import org.datavec.api.transform.reduce.IAssociativeReducer;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.transform.sequence.ConvertToSequence;
import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.FloatWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.comparator.WritableComparator;
import org.datavec.arrow.ArrowConverter;
import org.datavec.local.transforms.functions.EmptyRecordFunction;
import org.datavec.local.transforms.join.ExecuteJoinFromCoGroupFlatMapFunction;
import org.datavec.local.transforms.join.ExtractKeysFunction;
import org.datavec.local.transforms.misc.ColumnAsKeyPairFunction;
import org.datavec.local.transforms.rank.UnzipForCalculateSortedRankFunction;
import org.datavec.local.transforms.reduce.MapToPairForReducerFunction;
import org.datavec.local.transforms.sequence.ConvertToSequenceLengthOne;
import org.datavec.local.transforms.sequence.LocalGroupToSequenceFunction;
import org.datavec.local.transforms.sequence.LocalMapToPairByMultipleColumnsFunction;
import org.datavec.local.transforms.sequence.LocalSequenceFilterFunction;
import org.datavec.local.transforms.sequence.LocalSequenceTransformFunction;
import org.datavec.local.transforms.transform.LocalTransformFunction;
import org.datavec.local.transforms.transform.SequenceSplitFunction;
import org.datavec.local.transforms.transform.filter.LocalFilterFunction;
import org.nd4j.common.function.FunctionalUtils;
import org.nd4j.common.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/datavec/local/transforms/LocalTransformExecutor.class */
public class LocalTransformExecutor {
    /**
     * Name of the JVM system property read by {@link #isTryCatch()}.
     * NOTE(review): the key mentions "spark" although this is the local executor — presumably kept so
     * one flag controls both executors; confirm before renaming.
     */
    public static final String LOG_ERROR_PROPERTY = "org.datavec.spark.transform.logerrors";
    /** Class logger; used in this class to warn about records dropped for having the wrong column count. */
    private static final Logger log = LoggerFactory.getLogger(LocalTransformExecutor.class);
    /**
     * Arrow allocator shared by every Arrow conversion performed in this class. It is constructed with
     * {@code Long.MAX_VALUE} as its argument and is not closed anywhere in this class.
     */
    private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);

    /* renamed from: org.datavec.local.transforms.LocalTransformExecutor$3, reason: invalid class name */
    /* loaded from: input_file:org/datavec/local/transforms/LocalTransformExecutor$3.class */
    static /* synthetic */ class AnonymousClass3 {
        /**
         * Compiler-style switch table: indexed by {@code ColumnType.ordinal()}, holding the case number
         * (1..6) used by the {@code switch} statements in the enclosing class. Entries for column types
         * that are not registered below stay at 0, which matches no case.
         */
        static final /* synthetic */ int[] $SwitchMap$org$datavec$api$transform$ColumnType;

        static {
            // Fill a local table first and publish it once it is complete.
            final int[] caseByOrdinal = new int[ColumnType.values().length];
            // Each registration is guarded separately so that a constant missing from the ColumnType
            // class present at runtime only disables its own case.
            try {
                caseByOrdinal[ColumnType.Double.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave the slot at 0
            }
            try {
                caseByOrdinal[ColumnType.Float.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave the slot at 0
            }
            try {
                caseByOrdinal[ColumnType.Integer.ordinal()] = 3;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave the slot at 0
            }
            try {
                caseByOrdinal[ColumnType.Long.ordinal()] = 4;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave the slot at 0
            }
            try {
                caseByOrdinal[ColumnType.String.ordinal()] = 5;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave the slot at 0
            }
            try {
                caseByOrdinal[ColumnType.Time.ordinal()] = 6;
            } catch (NoSuchFieldError ignored) {
                // constant absent at runtime: leave the slot at 0
            }
            $SwitchMap$org$datavec$api$transform$ColumnType = caseByOrdinal;
        }
    }

    /**
     * Runs a transform process over non-sequence records and returns non-sequence records.
     *
     * <p>Records whose size differs from the number of columns in the process's initial schema are
     * dropped before execution; a warning with the number of dropped records is logged when that
     * happens.
     *
     * @param list             input records, one inner list per record
     * @param transformProcess process to apply; its final schema must not be a {@link SequenceSchema}
     * @return the transformed records
     * @throws IllegalStateException if the final schema of the process is a sequence schema
     */
    public static List<List<Writable>> execute(List<List<Writable>> list, TransformProcess transformProcess) {
        final boolean yieldsSequences = transformProcess.getFinalSchema() instanceof SequenceSchema;
        if (yieldsSequences) {
            throw new IllegalStateException("Cannot return sequence data with this method");
        }

        // Keep only the records that have exactly as many values as the initial schema has columns.
        final List<List<Writable>> wellFormed = list.parallelStream()
                .filter(record -> record.size() == transformProcess.getInitialSchema().numColumns())
                .collect(Collectors.toList());

        final int dropped = list.size() - wellFormed.size();
        if (dropped != 0) {
            log.warn("Filtered out " + dropped + " values");
        }

        final Pair<List<List<Writable>>, List<List<List<Writable>>>> outcome =
                execute(wellFormed, null, transformProcess);
        return outcome.getFirst();
    }

    /**
     * Runs a transform process over non-sequence records and returns the resulting sequences.
     *
     * @param list             input records, one inner list per record
     * @param transformProcess process to apply; its final schema must be a {@link SequenceSchema}
     * @return the sequences produced by the process
     * @throws IllegalStateException if the final schema of the process is not a sequence schema
     */
    public static List<List<List<Writable>>> executeToSequence(List<List<Writable>> list, TransformProcess transformProcess) {
        if (!(transformProcess.getFinalSchema() instanceof SequenceSchema)) {
            throw new IllegalStateException("Cannot return non-sequence data with this method");
        }
        final Pair<List<List<Writable>>, List<List<List<Writable>>>> outcome =
                execute(list, null, transformProcess);
        return outcome.getSecond();
    }

    /**
     * Runs a transform process over sequence data and returns non-sequence records.
     *
     * @param list             input sequences
     * @param transformProcess process to apply; its final schema must not be a {@link SequenceSchema}
     * @return the non-sequence records produced by the process
     * @throws IllegalStateException if the final schema of the process is a sequence schema
     */
    public static List<List<Writable>> executeSequenceToSeparate(List<List<List<Writable>>> list, TransformProcess transformProcess) {
        final Schema outputSchema = transformProcess.getFinalSchema();
        if (outputSchema instanceof SequenceSchema) {
            throw new IllegalStateException("Cannot return sequence data with this method");
        }
        final Pair<List<List<Writable>>, List<List<List<Writable>>>> outcome =
                execute(null, list, transformProcess);
        return outcome.getFirst();
    }

    /**
     * Runs a transform process over sequence data and returns sequence data.
     *
     * <p>When the final schema is a {@link SequenceSchema} the whole input is executed as sequence
     * data. Otherwise every input sequence is executed on its own as a list of plain records, and the
     * per-sequence results are returned in input order.
     *
     * @param list             input sequences
     * @param transformProcess process to apply
     * @return the transformed sequences
     */
    public static List<List<List<Writable>>> executeSequenceToSequence(List<List<List<Writable>>> list, TransformProcess transformProcess) {
        if (!(transformProcess.getFinalSchema() instanceof SequenceSchema)) {
            // Non-sequence output schema: treat the steps of each sequence as independent records.
            final List<List<List<Writable>>> perSequence = new ArrayList<>(list.size());
            for (List<List<Writable>> sequence : list) {
                perSequence.add(execute(sequence, null, transformProcess).getFirst());
            }
            return perSequence;
        }
        return execute(null, list, transformProcess).getSecond();
    }

    /**
     * Converts records of writables into records of strings by calling {@code toString()} on every
     * value.
     *
     * <p>Fix: the previous implementation collected the converted rows into one list but returned a
     * different list that was never populated, so the result was always empty. The converted rows are
     * now returned.
     *
     * @param list   records to convert, one inner list per record
     * @param schema not used by this method
     * @return a new list with one list of strings per input record, in input order
     */
    public static List<List<String>> convertWritableInputToString(List<List<Writable>> list, Schema schema) {
        final List<List<String>> converted = new ArrayList<>(list.size());
        for (List<Writable> record : list) {
            final List<String> row = new ArrayList<>(record.size());
            for (Writable value : record) {
                row.add(value.toString());
            }
            converted.add(row);
        }
        return converted;
    }

    /**
     * Parses records of strings into records of writables, using the column types of {@code schema}.
     *
     * <p>Fix: the previous implementation collected the parsed records into one list but returned a
     * different list that was never populated, so the result was always empty. The parsed records are
     * now returned.
     *
     * <p>Type mapping: Double -> {@link DoubleWritable}, Float -> {@link FloatWritable},
     * Integer -> {@link IntWritable}, Long and Time -> {@link LongWritable}, String -> {@link Text}.
     * As before, a value whose column has any other type is skipped, so the resulting record is
     * shorter than the input row in that case.
     *
     * @param list   rows to parse; the value at position {@code i} is parsed with the type of schema column {@code i}
     * @param schema supplies the type of each column
     * @return a new list with one record of writables per input row, in input order
     * @throws NumberFormatException if a numeric column holds text that cannot be parsed
     */
    public static List<List<Writable>> convertStringInput(List<List<String>> list, Schema schema) {
        final List<List<Writable>> converted = new ArrayList<>(list.size());
        for (List<String> row : list) {
            final List<Writable> record = new ArrayList<>(row.size());
            for (int column = 0; column < row.size(); column++) {
                final String raw = row.get(column);
                switch (schema.getType(column)) {
                    case Double:
                        record.add(new DoubleWritable(Double.parseDouble(raw)));
                        break;
                    case Float:
                        record.add(new FloatWritable(Float.parseFloat(raw)));
                        break;
                    case Integer:
                        record.add(new IntWritable(Integer.parseInt(raw)));
                        break;
                    case Long:
                    case Time:
                        // Time columns are parsed as long values, exactly like Long columns.
                        record.add(new LongWritable(Long.parseLong(raw)));
                        break;
                    case String:
                        record.add(new Text(raw));
                        break;
                    default:
                        // Other column types are not converted; the value is skipped (unchanged behavior).
                        break;
                }
            }
            converted.add(record);
        }
        return converted;
    }

    /**
     * Converts sequences of writables into sequences of strings by calling {@code toString()} on every
     * value.
     *
     * @param list   sequences to convert (sequence -> time step -> values)
     * @param schema not used by this method
     * @return a new nested list with the same structure as the input, holding the string form of each value
     */
    public static List<List<List<String>>> convertWritableInputToStringTimeSeries(List<List<List<Writable>>> list, Schema schema) {
        final List<List<List<String>>> sequencesOut = new ArrayList<>();
        for (List<List<Writable>> sequence : list) {
            final List<List<String>> stepsOut = new ArrayList<>();
            for (List<Writable> step : sequence) {
                final List<String> valuesOut = new ArrayList<>();
                for (Writable value : step) {
                    valuesOut.add(value.toString());
                }
                stepsOut.add(valuesOut);
            }
            sequencesOut.add(stepsOut);
        }
        return sequencesOut;
    }

    /**
     * Parses sequences of strings into sequences of writables, using the column types of
     * {@code schema}.
     *
     * <p>Type mapping: Double -> {@link DoubleWritable}, Float -> {@link FloatWritable},
     * Integer -> {@link IntWritable}, Long and Time -> {@link LongWritable}, String -> {@link Text}.
     * A value whose column has any other type is skipped, so that time step ends up shorter than its
     * input.
     *
     * @param list   sequences to parse (sequence -> time step -> values); the value at position
     *               {@code i} of a step is parsed with the type of schema column {@code i}
     * @param schema supplies the type of each column
     * @return a new nested list of writables mirroring the input structure
     */
    public static List<List<List<Writable>>> convertStringInputTimeSeries(List<List<List<String>>> list, Schema schema) {
        final List<List<List<Writable>>> sequencesOut = new ArrayList<>();
        for (List<List<String>> sequence : list) {
            final List<List<Writable>> stepsOut = new ArrayList<>();
            for (List<String> step : sequence) {
                final List<Writable> valuesOut = new ArrayList<>();
                final int width = step.size();
                for (int column = 0; column < width; column++) {
                    switch (schema.getType(column)) {
                        case Double:
                            valuesOut.add(new DoubleWritable(Double.parseDouble(step.get(column))));
                            break;
                        case Float:
                            valuesOut.add(new FloatWritable(Float.parseFloat(step.get(column))));
                            break;
                        case Integer:
                            valuesOut.add(new IntWritable(Integer.parseInt(step.get(column))));
                            break;
                        case Long:
                        case Time:
                            // Time columns are parsed as long values, exactly like Long columns.
                            valuesOut.add(new LongWritable(Long.parseLong(step.get(column))));
                            break;
                        case String:
                            valuesOut.add(new Text(step.get(column)));
                            break;
                        default:
                            // Other column types are not converted; the value is skipped.
                            break;
                    }
                }
                stepsOut.add(valuesOut);
            }
            sequencesOut.add(stepsOut);
        }
        return sequencesOut;
    }

    /**
     * Reports whether the system property named by {@link #LOG_ERROR_PROPERTY} is set to
     * {@code "true"} (case-insensitive). The property is re-read on every call.
     *
     * @return {@code true} only if the property exists and equals "true" ignoring case
     */
    public static boolean isTryCatch() {
        final String configured = System.getProperty(LOG_ERROR_PROPERTY);
        return Boolean.parseBoolean(configured);
    }

    /**
     * Core executor: applies every action of {@code transformProcess}, in order, to either plain
     * records or sequences, switching between the two representations when an action converts one
     * into the other.
     *
     * <p>Callers in this class always pass exactly one of {@code list} / {@code list2} as non-null.
     * The non-null input is validated by looking at its FIRST element only, via {@code get(0)}, so an
     * empty input list fails at that call.
     *
     * @param list             non-sequence input records, or {@code null} when the input is sequence data
     * @param list2            sequence input, read when {@code list} is {@code null}
     * @param transformProcess supplies the action list and the initial/final schemas
     * @return a pair holding the non-sequence result in its first slot or the sequence result in its
     *         second slot; the other slot is {@code null}
     * @throws IllegalStateException if the column count of the first input element does not match the
     *                               initial schema, or if an action needs a representation (records vs.
     *                               sequences) that is not the current one
     * @throws RuntimeException      if an action carries none of the operations handled below
     */
    private static Pair<List<List<Writable>>, List<List<List<Writable>>>> execute(List<List<Writable>> list, List<List<List<Writable>>> list2, TransformProcess transformProcess) {
        // list3 = current non-sequence records, list4 = current sequences. The branches below test
        // list3 != null to decide which representation is active.
        List<List<Writable>> list3 = list;
        List<List<List<Writable>>> list4 = list2;
        List<DataAction> actionList = transformProcess.getActionList();
        if (list != null) {
            // Only the first record is checked against the initial schema.
            List<Writable> list5 = list.get(0);
            if (list5.size() != transformProcess.getInitialSchema().numColumns()) {
                throw new IllegalStateException("Input data number of columns (" + list5.size() + ") does not match the number of columns for the transform process (" + transformProcess.getInitialSchema().numColumns() + ")");
            }
        } else {
            // Only the first step of the first sequence is checked; an empty first sequence skips the check.
            List<List<Writable>> list6 = list2.get(0);
            if (list6.size() > 0 && list6.get(0).size() != transformProcess.getInitialSchema().numColumns()) {
                throw new IllegalStateException("Input sequence data number of columns (" + list6.get(0).size() + ") does not match the number of columns for the transform process (" + transformProcess.getInitialSchema().numColumns() + ")");
            }
        }
        for (DataAction dataAction : actionList) {
            if (dataAction.getTransform() != null) {
                // --- Transform: map every record (or every sequence) through the transform.
                Transform transform = dataAction.getTransform();
                if (list3 != null) {
                    LocalTransformFunction localTransformFunction = new LocalTransformFunction(transform);
                    // When isTryCatch() is true the mapped records are additionally filtered through
                    // EmptyRecordFunction. NOTE(review): presumably this drops records that the
                    // transform function emitted as empty after an error — confirm in EmptyRecordFunction.
                    list3 = isTryCatch() ? (List) list3.stream().map(list7 -> {
                        return (List) localTransformFunction.apply(list7);
                    }).filter(list8 -> {
                        return new EmptyRecordFunction().apply((List<Writable>) list8).booleanValue();
                    }).collect(Collectors.toList()) : (List) list3.stream().map(list9 -> {
                        return (List) localTransformFunction.apply(list9);
                    }).collect(Collectors.toList());
                } else {
                    LocalSequenceTransformFunction localSequenceTransformFunction = new LocalSequenceTransformFunction(transform);
                    // Same structure as above, for sequences, with SequenceEmptyRecordFunction as the filter.
                    list4 = isTryCatch() ? (List) list4.stream().map(list10 -> {
                        return (List) localSequenceTransformFunction.apply(list10);
                    }).filter(list11 -> {
                        return new SequenceEmptyRecordFunction().apply((List<List<Writable>>) list11).booleanValue();
                    }).collect(Collectors.toList()) : (List) list4.stream().map(list12 -> {
                        return (List) localSequenceTransformFunction.apply(list12);
                    }).collect(Collectors.toList());
                }
            } else if (dataAction.getFilter() != null) {
                // --- Filter: keep the elements for which the wrapping function returns true.
                Filter filter = dataAction.getFilter();
                if (list3 != null) {
                    LocalFilterFunction localFilterFunction = new LocalFilterFunction(filter);
                    list3 = (List) list3.stream().filter(list13 -> {
                        return localFilterFunction.apply((List<Writable>) list13).booleanValue();
                    }).collect(Collectors.toList());
                } else {
                    LocalSequenceFilterFunction localSequenceFilterFunction = new LocalSequenceFilterFunction(filter);
                    list4 = (List) list4.stream().filter(list14 -> {
                        return localSequenceFilterFunction.apply((List<List<Writable>>) list14).booleanValue();
                    }).collect(Collectors.toList());
                }
            } else if (dataAction.getConvertToSequence() != null) {
                // --- ConvertToSequence: records -> sequences. Unlike the branches further down there is
                // no null check on list3 here, so running this while sequences are active fails on
                // list3.stream().
                ConvertToSequence convertToSequence = dataAction.getConvertToSequence();
                if (convertToSequence.isSingleStepSequencesMode()) {
                    // Every record is converted on its own by ConvertToSequenceLengthOne.
                    ConvertToSequenceLengthOne convertToSequenceLengthOne = new ConvertToSequenceLengthOne();
                    list4 = (List) list3.stream().map(list15 -> {
                        return convertToSequenceLengthOne.apply((List<Writable>) list15);
                    }).collect(Collectors.toList());
                    list3 = null;
                } else {
                    // Pair each record with a key built from the key columns, group by that key, then
                    // turn each group into one sequence using the configured comparator.
                    LocalMapToPairByMultipleColumnsFunction localMapToPairByMultipleColumnsFunction = new LocalMapToPairByMultipleColumnsFunction(convertToSequence.getInputSchema().getIndexOfColumns(convertToSequence.getKeyColumns()));
                    Map groupByKey = FunctionalUtils.groupByKey((List) list3.stream().map(list16 -> {
                        return localMapToPairByMultipleColumnsFunction.apply((List<Writable>) list16);
                    }).collect(Collectors.toList()));
                    LocalGroupToSequenceFunction localGroupToSequenceFunction = new LocalGroupToSequenceFunction(convertToSequence.getComparator());
                    list4 = (List) groupByKey.entrySet().stream().map(entry -> {
                        return (List) entry.getValue();
                    }).map(list17 -> {
                        return localGroupToSequenceFunction.apply((List<List<Writable>>) list17);
                    }).collect(Collectors.toList());
                    list3 = null;
                }
            } else if (dataAction.getConvertFromSequence() != null) {
                // --- ConvertFromSequence: flatten all sequences into one list of records.
                if (list4 == null) {
                    throw new IllegalStateException("Cannot execute ConvertFromSequence operation: current sequence is null");
                }
                list3 = (List) list4.stream().flatMap(list18 -> {
                    return list18.stream();
                }).collect(Collectors.toList());
                list4 = null;
            } else if (dataAction.getSequenceSplit() != null) {
                // --- SequenceSplit: each sequence is replaced by the sequences the split function returns for it.
                SequenceSplit sequenceSplit = dataAction.getSequenceSplit();
                if (list4 == null) {
                    throw new IllegalStateException("Error during execution of SequenceSplit: currentSequence is null");
                }
                SequenceSplitFunction sequenceSplitFunction = new SequenceSplitFunction(sequenceSplit);
                list4 = (List) list4.stream().flatMap(list19 -> {
                    return sequenceSplitFunction.call(list19).stream();
                }).collect(Collectors.toList());
            } else if (dataAction.getReducer() != null) {
                // --- Reducer: group records by key and aggregate each group into a single record.
                IAssociativeReducer reducer = dataAction.getReducer();
                if (list3 == null) {
                    throw new IllegalStateException("Error during execution of reduction: current writables are null. Trying to execute a reduce operation on a sequence?");
                }
                MapToPairForReducerFunction mapToPairForReducerFunction = new MapToPairForReducerFunction(reducer);
                List list20 = (List) list3.stream().map(list21 -> {
                    return mapToPairForReducerFunction.apply((List<Writable>) list21);
                }).collect(Collectors.toList());
                // key -> aggregation op for that key, filled by the forEach below.
                HashMap hashMap = new HashMap();
                // groupByKey collects the records per key; StreamUtils.aggregate then bundles adjacent
                // entries whose keys are equal before they are consumed.
                ((List) StreamUtils.aggregate(FunctionalUtils.groupByKey(list20).entrySet().stream(), new BiPredicate<Map.Entry<String, List<List<Writable>>>, Map.Entry<String, List<List<Writable>>>>() { // from class: org.datavec.local.transforms.LocalTransformExecutor.1
                    @Override // java.util.function.BiPredicate
                    public boolean test(Map.Entry<String, List<List<Writable>>> entry2, Map.Entry<String, List<List<Writable>>> entry3) {
                        return entry2.getKey().equals(entry3.getKey());
                    }
                }).collect(Collectors.toList())).stream().forEach(list22 -> {
                    Iterator it = list22.iterator();
                    while (it.hasNext()) {
                        Map.Entry entry2 = (Map.Entry) it.next();
                        // Only the first entry seen for a key is aggregated; later entries with the
                        // same key are ignored.
                        if (!hashMap.containsKey(entry2.getKey())) {
                            IAggregableReduceOp aggregableReducer = reducer.aggregableReducer();
                            hashMap.put(entry2.getKey(), aggregableReducer);
                            Iterator it2 = ((List) entry2.getValue()).iterator();
                            while (it2.hasNext()) {
                                aggregableReducer.accept((List) it2.next());
                            }
                        }
                    }
                });
                // One output record per key, in the iteration order of the HashMap.
                list3 = (List) hashMap.entrySet().stream().map(entry2 -> {
                    return (List) ((IAggregableReduceOp) entry2.getValue()).get();
                }).collect(Collectors.toList());
            } else {
                // --- CalculateSortedRank is the last supported action; anything else is rejected.
                if (dataAction.getCalculateSortedRank() == null) {
                    throw new RuntimeException("Unknown/not implemented action: " + dataAction);
                }
                CalculateSortedRank calculateSortedRank = dataAction.getCalculateSortedRank();
                if (list3 == null) {
                    throw new IllegalStateException("Error during execution of CalculateSortedRank: current writables are null. Trying to execute a CalculateSortedRank operation on a sequence? (not currently supported)");
                }
                final WritableComparator comparator = calculateSortedRank.getComparator();
                int indexOfColumn = calculateSortedRank.getInputSchema().getIndexOfColumn(calculateSortedRank.getSortOnColumn());
                final boolean isAscending = calculateSortedRank.isAscending();
                // Pipeline: pair each record with its sort-column value -> sort by that value ->
                // zip with the 0-based position in the sorted order -> hand (record, position) to
                // UnzipForCalculateSortedRankFunction to build the output record.
                list3 = (List) ((List) StreamUtils.zipWithIndex(((List) ((List) list3.stream().map(list23 -> {
                    return new ColumnAsKeyPairFunction(indexOfColumn).apply((List<Writable>) list23);
                }).collect(Collectors.toList())).stream().sorted(new Comparator<Pair<Writable, List<Writable>>>() { // from class: org.datavec.local.transforms.LocalTransformExecutor.2
                    @Override // java.util.Comparator
                    public int compare(Pair<Writable, List<Writable>> pair, Pair<Writable, List<Writable>> pair2) {
                        int compare = comparator.compare(pair.getFirst(), pair2.getFirst());
                        // Descending order is obtained by negating the comparator result.
                        return isAscending ? compare : -compare;
                    }
                }).collect(Collectors.toList())).stream()).collect(Collectors.toList())).stream().map(indexed -> {
                    return new UnzipForCalculateSortedRankFunction().apply(Pair.of(indexed.getValue(), Long.valueOf(indexed.getIndex())));
                }).collect(Collectors.toList());
            }
        }
        if (list4 == null) {
            // Non-sequence result: the records are converted to Arrow columns and back to writables
            // using the final schema before being returned.
            return new Pair<>(ArrowConverter.toArrowWritables(ArrowConverter.toArrowColumns(bufferAllocator, transformProcess.getFinalSchema(), list3), transformProcess.getFinalSchema()), (Object) null);
        }
        // Sequence result: z stays true only if every sequence has the same number of steps as the first.
        boolean z = true;
        Integer num = null;
        for (List<List<Writable>> list24 : list4) {
            if (num == null) {
                num = Integer.valueOf(list24.size());
            } else if (list24.size() != num.intValue()) {
                z = false;
            }
        }
        if (z) {
            // Equal-length sequences take the Arrow round trip; the last argument is
            // (steps in the first sequence) * (values in its first step).
            // NOTE(review): z is also true when list4 is empty, and get(0).get(0) additionally needs a
            // non-empty first sequence, so both of those cases throw on this line.
            list4 = ArrowConverter.toArrowWritablesTimeSeries(ArrowConverter.toArrowColumnsTimeSeries(bufferAllocator, transformProcess.getFinalSchema(), list4), transformProcess.getFinalSchema(), list4.get(0).size() * list4.get(0).get(0).size());
        }
        // Sequences of differing lengths are returned as they are, without the Arrow conversion.
        return Pair.of((Object) null, list4);
    }

    /**
     * Joins two sets of records as described by {@code join}.
     *
     * <p>Each side is keyed on its join columns, the two keyed sides are co-grouped by key, and every
     * co-group is expanded into output records by {@link ExecuteJoinFromCoGroupFlatMapFunction}. The
     * result is converted to Arrow columns and back to writables using the join's output schema.
     *
     * @param join  join definition (join columns, left/right schemas, output schema)
     * @param list  records of the left side
     * @param list2 records of the right side
     * @return the joined records
     */
    public static List<List<Writable>> executeJoin(Join join, List<List<Writable>> list, List<List<Writable>> list2) {
        // Left side: resolve the join column names to indexes, then key every record on them.
        final String[] leftKeyNames = join.getJoinColumnsLeft();
        final int[] leftKeyIndexes = new int[leftKeyNames.length];
        for (int k = 0; k < leftKeyIndexes.length; k++) {
            leftKeyIndexes[k] = join.getLeftSchema().getIndexOfColumn(leftKeyNames[k]);
        }
        final ExtractKeysFunction leftKeyExtractor = new ExtractKeysFunction(leftKeyIndexes);
        // Records whose size equals the number of join columns are skipped.
        // NOTE(review): the intent of this size test is not evident from this file — confirm.
        final List<Pair<List<Writable>, List<Writable>>> leftKeyed = list.stream()
                .filter(record -> record.size() != leftKeyNames.length)
                .map(record -> leftKeyExtractor.apply(record))
                .collect(Collectors.toList());

        // Right side: same treatment with the right join columns and the right schema.
        final String[] rightKeyNames = join.getJoinColumnsRight();
        final int[] rightKeyIndexes = new int[rightKeyNames.length];
        for (int k = 0; k < rightKeyIndexes.length; k++) {
            rightKeyIndexes[k] = join.getRightSchema().getIndexOfColumn(rightKeyNames[k]);
        }
        final ExtractKeysFunction rightKeyExtractor = new ExtractKeysFunction(rightKeyIndexes);
        final List<Pair<List<Writable>, List<Writable>>> rightKeyed = list2.stream()
                .filter(record -> record.size() != rightKeyNames.length)
                .map(record -> rightKeyExtractor.apply(record))
                .collect(Collectors.toList());

        // key -> (left records with that key, right records with that key)
        final Map<List<Writable>, Pair<List<List<Writable>>, List<List<Writable>>>> coGrouped =
                FunctionalUtils.cogroup(leftKeyed, rightKeyed);

        final ExecuteJoinFromCoGroupFlatMapFunction joinExpander = new ExecuteJoinFromCoGroupFlatMapFunction(join);
        final List<List<Writable>> joined = coGrouped.entrySet().stream()
                .flatMap(group -> joinExpander.call(Pair.of(group.getKey(), group.getValue())).stream())
                .collect(Collectors.toList());

        final Schema joinedSchema = join.getOutputSchema();
        return ArrowConverter.toArrowWritables(
                ArrowConverter.toArrowColumns(bufferAllocator, joinedSchema, joined), joinedSchema);
    }
}
