package com.microsoft.azure.synapse.ml.cntk;

import com.microsoft.CNTK.CNTKExtensions$;
import com.microsoft.CNTK.CNTKLib;
import com.microsoft.CNTK.DataType;
import com.microsoft.CNTK.DeviceDescriptor;
import com.microsoft.CNTK.DoubleVectorVector;
import com.microsoft.CNTK.FloatVectorVector;
import com.microsoft.CNTK.Function;
import com.microsoft.CNTK.ParameterCloningMethod;
import com.microsoft.CNTK.SerializableFunction;
import com.microsoft.CNTK.UnorderedMapVariableValuePtr;
import com.microsoft.CNTK.Value;
import com.microsoft.CNTK.Variable;
import com.microsoft.CNTK.VariableVector;
import java.io.Serializable;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.Nil$;
import scala.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.util.Either;
import scala.util.Left;
import scala.util.Right;

/* compiled from: CNTKModel.scala */
/* loaded from: input_file:com/microsoft/azure/synapse/ml/cntk/CNTKModelUtils$.class */
/**
 * Decompiled Java form of the Scala singleton {@code object CNTKModelUtils} (compiled from
 * CNTKModel.scala). It holds the per-partition logic that feeds Spark rows through a CNTK function
 * and appends the function's outputs to each row.
 *
 * <p>NOTE(review): this is decompiler output, not hand-written Java. Some local declarations do not
 * type-check as written, for example:
 * <ul>
 *   <li>a {@code Left} local assigned from an {@code (Either)} cast;
 *   <li>a {@code Left} local assigned {@code package$.MODULE$.Right().apply(...)}.
 * </ul>
 * The {@code $anonfun$...} methods are Scala compiler synthetics. Treat CNTKModel.scala as the
 * source of truth for any real change.
 */
public final class CNTKModelUtils$ implements Serializable {
    /** The singleton instance; assigned by the private constructor when the class is initialized. */
    public static CNTKModelUtils$ MODULE$;

    static {
        new CNTKModelUtils$();
    }

    /**
     * Evaluates a CNTK function once for a set of dense inputs and copies the requested outputs
     * back into JVM-side vectors.
     *
     * <p>Each input entry is wrapped in a dense native CNTK {@code Value} shaped like its variable:
     * <ul>
     *   <li>{@code Left} payloads go through {@code Value.createDenseFloat};
     *   <li>{@code Right} payloads go through {@code Value.createDenseDouble}.
     * </ul>
     * After evaluation, each requested output is copied according to the output variable's data
     * type:
     * <ul>
     *   <li>Float outputs go into a new {@code FloatVectorVector}, returned as {@code Left};
     *   <li>Double outputs go into a new {@code DoubleVectorVector}, returned as {@code Right}.
     * </ul>
     *
     * <p>Ownership:
     * <ul>
     *   <li>The native {@code Value}s created or fetched here are deleted before a normal return.
     *   <li>The input vectors inside {@code map} are not deleted by this method.
     *   <li>The returned vectors belong to the caller. {@code applyModel} passes each one to
     *       {@code ConversionUtils$.deleteGVV} after converting it.
     * </ul>
     *
     * <p>NOTE(review): there is no try/finally. If {@code evaluate} or a copy throws, two sets of
     * native values are never deleted: the input {@code Value}s in {@code map2}, and any output
     * {@code Value}s not yet copied. This is a possible native-memory leak on the failure path and
     * is worth fixing in the Scala source.
     *
     * @param serializableFunction function to evaluate; unwrapped via {@code CNTKExtensions.fromSerializable}
     * @param map input variable to its data (float data as {@code Left}, double data as {@code Right})
     * @param list output variables to fetch
     * @param deviceDescriptor device used both to create the input values and to run evaluation
     * @return one vector per entry of {@code list}, in the same order as {@code list}
     * @throws MatchError if an output variable's data type is neither {@code DataType.Float} nor
     *     {@code DataType.Double}
     */
    public List<Either<FloatVectorVector, DoubleVectorVector>> applyCNTKFunction(SerializableFunction serializableFunction, Map<Variable, Either<FloatVectorVector, DoubleVectorVector>> map, List<Variable> list, DeviceDescriptor deviceDescriptor) {
        // Wrap each input vector in a dense native Value shaped like its variable (Left = float, Right = double).
        Map map2 = (Map) map.map(tuple2 -> {
            Tuple2 $minus$greater$extension;
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            Variable variable = (Variable) tuple2._1();
            Left left = (Either) tuple2._2();
            if (left instanceof Left) {
                $minus$greater$extension = Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(variable), Value.createDenseFloat(variable.getShape(), (FloatVectorVector) left.value(), deviceDescriptor));
            } else {
                if (!(left instanceof Right)) {
                    throw new MatchError(left);
                }
                $minus$greater$extension = Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(variable), Value.createDenseDouble(variable.getShape(), (DoubleVectorVector) ((Right) left).value(), deviceDescriptor));
            }
            return $minus$greater$extension;
        }, Map$.MODULE$.canBuildFrom());
        // Native input map handed to evaluate().
        UnorderedMapVariableValuePtr unorderedMapVariableValuePtr = new UnorderedMapVariableValuePtr();
        map2.foreach(tuple22 -> {
            $anonfun$applyCNTKFunction$2(unorderedMapVariableValuePtr, tuple22);
            return BoxedUnit.UNIT;
        });
        // Output map: each requested output is registered with a null Value.
        // Presumably CNTK allocates the output Values during evaluate() - TODO confirm.
        UnorderedMapVariableValuePtr unorderedMapVariableValuePtr2 = new UnorderedMapVariableValuePtr();
        list.foreach(variable -> {
            unorderedMapVariableValuePtr2.add(variable, null);
            return BoxedUnit.UNIT;
        });
        CNTKExtensions$.MODULE$.fromSerializable(serializableFunction).evaluate(unorderedMapVariableValuePtr, unorderedMapVariableValuePtr2, deviceDescriptor);
        // Copy each output into a fresh JVM vector chosen by the variable's data type,
        // then delete the native output Value. List.map keeps the order of `list`.
        List<Either<FloatVectorVector, DoubleVectorVector>> list2 = (List) list.map(variable2 -> {
            Left apply;
            DataType dataType = variable2.getDataType();
            DataType dataType2 = DataType.Float;
            // Decompiled null-safe form of "dataType is not DataType.Float" from the Scala match.
            if (dataType2 != null ? !dataType2.equals(dataType) : dataType != null) {
                DataType dataType3 = DataType.Double;
                if (dataType3 != null ? !dataType3.equals(dataType) : dataType != null) {
                    throw new MatchError(dataType);
                }
                DoubleVectorVector doubleVectorVector = new DoubleVectorVector();
                Value value = unorderedMapVariableValuePtr2.getitem(variable2);
                value.copyVariableValueToDouble(variable2, doubleVectorVector);
                value.delete();
                apply = package$.MODULE$.Right().apply(doubleVectorVector);
            } else {
                FloatVectorVector floatVectorVector = new FloatVectorVector();
                Value value2 = unorderedMapVariableValuePtr2.getitem(variable2);
                value2.copyVariableValueToFloat(variable2, floatVectorVector);
                value2.delete();
                apply = package$.MODULE$.Left().apply(floatVectorVector);
            }
            return apply;
        }, List$.MODULE$.canBuildFrom());
        // Delete the native input Values created at the top of this method.
        map2.values().foreach(value -> {
            value.delete();
            return BoxedUnit.UNIT;
        });
        return list2;
    }

    /**
     * Builds, for each CNTK input variable, a function that extracts that variable's data from a
     * row.
     *
     * <p>Each extractor reads the mapped column with {@code row.getAs(index)}. It wraps the value as
     * {@code Left} for Float variables and {@code Right} for Double variables, the same
     * Left = float / Right = double convention used elsewhere in this class.
     *
     * <p>The column's contents are not validated here. The column presumably holds a nested
     * numeric sequence matching the variable's type - TODO confirm where the input schema is
     * checked.
     *
     * @param map column index (boxed int) to the CNTK input variable fed from that column
     * @return input variable to its row extractor
     * @throws MatchError while building the map, if a variable's data type is neither
     *     {@code DataType.Float} nor {@code DataType.Double}
     */
    private Map<Variable, Function1<Row, Either<Seq<Seq<Object>>, Seq<Seq<Object>>>>> makeInputExtractors(Map<Object, Variable> map) {
        return (Map) map.map(tuple2 -> {
            Function1 function1;
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            // Column index, unboxed via the Scala-specialized tuple accessor.
            int _1$mcI$sp = tuple2._1$mcI$sp();
            Variable variable = (Variable) tuple2._2();
            Predef$ArrowAssoc$ predef$ArrowAssoc$ = Predef$ArrowAssoc$.MODULE$;
            Object ArrowAssoc = Predef$.MODULE$.ArrowAssoc(variable);
            DataType dataType = variable.getDataType();
            DataType dataType2 = DataType.Float;
            if (dataType2 != null ? !dataType2.equals(dataType) : dataType != null) {
                DataType dataType3 = DataType.Double;
                if (dataType3 != null ? !dataType3.equals(dataType) : dataType != null) {
                    throw new MatchError(dataType);
                }
                function1 = row -> {
                    return package$.MODULE$.Right().apply(row.getAs(_1$mcI$sp));
                };
            } else {
                function1 = row2 -> {
                    return package$.MODULE$.Left().apply(row2.getAs(_1$mcI$sp));
                };
            }
            return predef$ArrowAssoc$.$minus$greater$extension(ArrowAssoc, function1);
        }, Map$.MODULE$.canBuildFrom());
    }

    /**
     * Runs a broadcast CNTK model over one partition of rows and appends the model's outputs to
     * each row.
     *
     * <p>An empty partition returns an empty iterator immediately, without touching the device or
     * the model. Otherwise:
     * <ol>
     *   <li>The broadcast function is cloned with {@code ParameterCloningMethod.Share}.
     *   <li>Input and output variables are resolved by name on the clone.
     *   <li>The requested outputs are combined into one function with {@code CNTKLib.Combine}.
     * </ol>
     * The returned iterator is lazy: each row is evaluated only when it is pulled. The result row
     * contains the input row's fields followed by one converted value per output variable.
     *
     * <p>NOTE(review): output column order and names.
     * <ul>
     *   <li>The appended outputs follow the iteration order of {@code map4.keys()}, a Scala
     *       immutable {@code Map} keyed by output {@code Variable}. They do not follow the order of
     *       {@code map2}, and {@code map2}'s keys are not used in this method. The caller's output
     *       schema must agree with that order - verify.
     *   <li>If two entries of {@code map2} resolve to equal output {@code Variable}s, they
     *       presumably collapse into a single appended value.
     * </ul>
     *
     * <p>NOTE(review): three native-backed objects are never deleted in this method: the per-input
     * buffers in {@code map5}, the cloned function {@code clone}, and the combined function
     * {@code Combine}. Confirm whether their native memory is released elsewhere.
     *
     * @param map CNTK input variable name to the index of the row column that feeds it
     * @param broadcast broadcast copy of the serializable CNTK model
     * @param map2 output mapping whose values are CNTK output variable names; its keys are
     *     presumably output column names - TODO confirm with the caller
     * @param iterator the partition's rows
     * @return a lazy iterator over the partition's rows with the model outputs appended
     */
    public Iterator<Row> applyModel(Map<String, Object> map, Broadcast<SerializableFunction> broadcast, Map<String, String> map2, Iterator<Row> iterator) {
        // Skip empty partitions before any device or model work.
        if (!iterator.hasNext()) {
            return package$.MODULE$.Iterator().apply(Nil$.MODULE$);
        }
        DeviceDescriptor useDefaultDevice = DeviceDescriptor.useDefaultDevice();
        // Per-partition clone of the broadcast model. ParameterCloningMethod.Share presumably
        // reuses the broadcast model's parameters instead of copying them - confirm in the CNTK docs.
        Function clone = CNTKExtensions$.MODULE$.fromSerializable((SerializableFunction) broadcast.value()).clone(ParameterCloningMethod.Share);
        // Column index -> CNTK input variable, looked up by name on the clone.
        Map<Object, Variable> map3 = (Map) map.map(tuple2 -> {
            if (tuple2 == null) {
                throw new MatchError(tuple2);
            }
            String str = (String) tuple2._1();
            return Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(BoxesRunTime.boxToInteger(tuple2._2$mcI$sp())), CNTKExtensions$.MODULE$.toSerializable(clone).getInputVar(str));
        }, Map$.MODULE$.canBuildFrom());
        // CNTK output variable -> key of map2. Only this map's keys are used below.
        Map map4 = (Map) map2.map(tuple22 -> {
            if (tuple22 == null) {
                throw new MatchError(tuple22);
            }
            return Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(CNTKExtensions$.MODULE$.toSerializable(clone).getOutputVar((String) tuple22._2())), (String) tuple22._1());
        }, Map$.MODULE$.canBuildFrom());
        Map<Variable, Function1<Row, Either<Seq<Seq<Object>>, Seq<Seq<Object>>>>> makeInputExtractors = makeInputExtractors(map3);
        // One buffer per input variable, typed by the variable's data type and created once per
        // partition. It is passed to ConversionUtils.toGVV for every row.
        Map map5 = (Map) map3.map(tuple23 -> {
            Left apply;
            if (tuple23 == null) {
                throw new MatchError(tuple23);
            }
            Variable variable = (Variable) tuple23._2();
            Predef$ArrowAssoc$ predef$ArrowAssoc$ = Predef$ArrowAssoc$.MODULE$;
            Object ArrowAssoc = Predef$.MODULE$.ArrowAssoc(variable);
            DataType dataType = variable.getDataType();
            DataType dataType2 = DataType.Float;
            if (dataType2 != null ? !dataType2.equals(dataType) : dataType != null) {
                DataType dataType3 = DataType.Double;
                if (dataType3 != null ? !dataType3.equals(dataType) : dataType != null) {
                    throw new MatchError(dataType);
                }
                apply = package$.MODULE$.Right().apply(new DoubleVectorVector());
            } else {
                apply = package$.MODULE$.Left().apply(new FloatVectorVector());
            }
            return predef$ArrowAssoc$.$minus$greater$extension(ArrowAssoc, apply);
        }, Map$.MODULE$.canBuildFrom());
        // Per-row feed dictionary: input variable -> toGVV(extracted column value, that variable's buffer).
        Function1 function1 = row -> {
            return (Map) makeInputExtractors.map(tuple24 -> {
                if (tuple24 == null) {
                    throw new MatchError(tuple24);
                }
                Variable variable = (Variable) tuple24._1();
                return Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc(variable), ConversionUtils$.MODULE$.toGVV((Either) ((Function1) tuple24._2()).apply(row), (Either) map5.apply(variable)));
            }, Map$.MODULE$.canBuildFrom());
        };
        // Output variables in map4's key order. This order fixes the order of the appended columns.
        List list = map4.keys().toList();
        VariableVector variableVector = new VariableVector();
        list.foreach(variable -> {
            variableVector.add(variable);
            return BoxedUnit.UNIT;
        });
        // Combine all requested outputs so one applyCNTKFunction call per row fetches them all.
        Function Combine = CNTKLib.Combine(variableVector);
        return iterator.map(row2 -> {
            // NOTE(review): toSerializable(Combine) is re-invoked for every row. It could be hoisted
            // out of this lambda if the wrapping is not free - confirm its cost in CNTKExtensions.
            List<Either<FloatVectorVector, DoubleVectorVector>> applyCNTKFunction = MODULE$.applyCNTKFunction(CNTKExtensions$.MODULE$.toSerializable(Combine), (Map) function1.apply(row2), list, useDefaultDevice);
            // Result row = input row fields followed by one convertGVV value per output variable.
            Row fromSeq = Row$.MODULE$.fromSeq((Seq) row2.toSeq().$plus$plus(Row$.MODULE$.apply((Seq) applyCNTKFunction.map(either -> {
                return ConversionUtils$.MODULE$.convertGVV(either);
            }, List$.MODULE$.canBuildFrom())).toSeq(), Seq$.MODULE$.canBuildFrom()));
            // Hand each output vector to deleteGVV (presumably freeing it) once it has been converted.
            applyCNTKFunction.foreach(either2 -> {
                $anonfun$applyModel$9(either2);
                return BoxedUnit.UNIT;
            });
            return fromSeq;
        });
    }

    /**
     * Serialization hook. It returns the existing singleton so that deserializing a copy does not
     * create a second instance.
     */
    private Object readResolve() {
        return MODULE$;
    }

    /**
     * Scala-compiler synthetic used by {@code applyCNTKFunction}. It adds one
     * (input variable, native value) pair to the native input map.
     *
     * @throws MatchError if the pair is null
     */
    public static final /* synthetic */ void $anonfun$applyCNTKFunction$2(UnorderedMapVariableValuePtr unorderedMapVariableValuePtr, Tuple2 tuple2) {
        if (tuple2 == null) {
            throw new MatchError(tuple2);
        }
        unorderedMapVariableValuePtr.add((Variable) tuple2._1(), (Value) tuple2._2());
        BoxedUnit boxedUnit = BoxedUnit.UNIT;
    }

    /**
     * Scala-compiler synthetic used by {@code applyModel}. It passes one output vector to
     * {@code ConversionUtils$.deleteGVV}.
     */
    public static final /* synthetic */ void $anonfun$applyModel$9(Either either) {
        ConversionUtils$.MODULE$.deleteGVV(either);
    }

    /** Private constructor. It publishes {@code this} as {@code MODULE$} and is invoked once from the static initializer. */
    private CNTKModelUtils$() {
        MODULE$ = this;
    }
}
