package org.tensorframes.dsl;

import com.typesafe.scalalogging.slf4j.LazyLogging;
import com.typesafe.scalalogging.slf4j.Logger;
import javax.annotation.Nullable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.NumericType;
import org.apache.spark.sql.types.StructField;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.TensorShapeProto;
import org.tensorframes.ColumnInformation$;
import org.tensorframes.Logging;
import org.tensorframes.Shape;
import org.tensorframes.Shape$;
import org.tensorframes.SparkTFColInfo;
import org.tensorframes.dsl.DefaultConversions;
import org.tensorframes.dsl.DslImpl;
import org.tensorframes.impl.DenseTensor;
import org.tensorframes.impl.DenseTensor$;
import org.tensorframes.impl.SupportedOperations$;
import scala.Function1;
import scala.Option$;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Iterable$;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.SeqLike;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Map;
import scala.collection.mutable.StringBuilder;
import scala.math.Numeric;
import scala.math.Numeric$IntIsIntegral$;
import scala.reflect.api.TypeTags;
import scala.runtime.BoxedUnit;
import scala.runtime.ObjectRef;

/* compiled from: DslImpl.scala */
/* loaded from: input_file:org/tensorframes/dsl/DslImpl$.class */
/**
 * JADX-decompiled Java view of the Scala singleton object {@code DslImpl} (module class
 * {@code DslImpl$}): helpers that create DSL {@code Node}s (constants, placeholders, reducers,
 * generic ops) and assemble them into a TensorFlow {@code GraphDef}.
 *
 * <p>The single instance is published through {@link #MODULE$}, which the private constructor
 * assigns when the static initializer runs. The class mixes in {@code Logging} and
 * {@code DefaultConversions}. Their method bodies are reached through the corresponding
 * {@code *.Cclass} helper classes (Scala trait implementation classes), to which the methods
 * below forward.
 *
 * <p>NOTE(review): this is decompiler output and is not valid Java source as written:
 * <ul>
 *   <li>it contains JADX placeholder local types ({@code ??});</li>
 *   <li>it assigns {@code static final} / {@code final} fields outside their initializers;</li>
 *   <li>{@link #broadcastShape} was not decompiled at all.</li>
 * </ul>
 * Treat DslImpl.scala as the source of truth; presumably this file exists for reference only.
 */
public final class DslImpl$ implements Logging, DefaultConversions {
    /** The singleton instance; the field initializer is null, and the private constructor sets it to {@code this}. */
    public static final DslImpl$ MODULE$ = null;
    /**
     * Value of {@code package$.MODULE$.Unknown()}, captured once in the constructor and exposed
     * through {@link #org$tensorframes$dsl$DslImpl$$U()}.
     * Presumably the "unknown dimension" marker; confirm in package.scala.
     */
    private final long org$tensorframes$dsl$DslImpl$$U;
    /** Lazily created logger; only meaningful once {@code bitmap$0} is true (see {@code logger$lzycompute}). */
    private final Logger logger;
    // Lazily created conversion objects from DefaultConversions; null until first requested.
    private volatile DefaultConversions$DoubleConversion$ DoubleConversion$module;
    private volatile DefaultConversions$FloatConversion$ FloatConversion$module;
    private volatile DefaultConversions$IntConversion$ IntConversion$module;
    /** Flag recording whether {@code logger} has been initialized. */
    private volatile boolean bitmap$0;

    // Class initialization runs the private constructor, which stores the instance in MODULE$.
    static {
        new DslImpl$();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v5 */
    /**
     * Slow path of {@link #DoubleConversion()}: under {@code synchronized(this)}, creates the
     * {@code DoubleConversion} object if it does not exist yet, then returns it.
     */
    private DefaultConversions$DoubleConversion$ DoubleConversion$lzycompute() {
        // "??" is a JADX placeholder (type inference failed, see warnings above); r0 is the monitor, i.e. this.
        ?? r0 = this;
        synchronized (r0) {
            if (this.DoubleConversion$module == null) {
                this.DoubleConversion$module = new DefaultConversions$DoubleConversion$(this);
            }
            // Decompiler residue of Scala's Unit result and register shuffling; no effect.
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
            r0 = r0;
            return this.DoubleConversion$module;
        }
    }

    /**
     * Returns the lazily created {@code DoubleConversion} object. The volatile field is read
     * first; the lock is taken only when it is still null.
     */
    @Override // org.tensorframes.dsl.DefaultConversions
    public DefaultConversions$DoubleConversion$ DoubleConversion() {
        return this.DoubleConversion$module == null ? DoubleConversion$lzycompute() : this.DoubleConversion$module;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v5 */
    /**
     * Slow path of {@link #FloatConversion()}: under {@code synchronized(this)}, creates the
     * {@code FloatConversion} object if it does not exist yet, then returns it.
     */
    private DefaultConversions$FloatConversion$ FloatConversion$lzycompute() {
        // "??" is a JADX placeholder (type inference failed); r0 is the monitor, i.e. this.
        ?? r0 = this;
        synchronized (r0) {
            if (this.FloatConversion$module == null) {
                this.FloatConversion$module = new DefaultConversions$FloatConversion$(this);
            }
            // Decompiler residue; no effect.
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
            r0 = r0;
            return this.FloatConversion$module;
        }
    }

    /**
     * Returns the lazily created {@code FloatConversion} object. The volatile field is read
     * first; the lock is taken only when it is still null.
     */
    @Override // org.tensorframes.dsl.DefaultConversions
    public DefaultConversions$FloatConversion$ FloatConversion() {
        return this.FloatConversion$module == null ? FloatConversion$lzycompute() : this.FloatConversion$module;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v5 */
    /**
     * Slow path of {@link #IntConversion()}: under {@code synchronized(this)}, creates the
     * {@code IntConversion} object if it does not exist yet, then returns it.
     */
    private DefaultConversions$IntConversion$ IntConversion$lzycompute() {
        // "??" is a JADX placeholder (type inference failed); r0 is the monitor, i.e. this.
        ?? r0 = this;
        synchronized (r0) {
            if (this.IntConversion$module == null) {
                this.IntConversion$module = new DefaultConversions$IntConversion$(this);
            }
            // Decompiler residue; no effect.
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
            r0 = r0;
            return this.IntConversion$module;
        }
    }

    /**
     * Returns the lazily created {@code IntConversion} object. The volatile field is read
     * first; the lock is taken only when it is still null. Used by {@link #build_reducer}.
     */
    @Override // org.tensorframes.dsl.DefaultConversions
    public DefaultConversions$IntConversion$ IntConversion() {
        return this.IntConversion$module == null ? IntConversion$lzycompute() : this.IntConversion$module;
    }

    /** Forwards to {@code DefaultConversions.Cclass.sequenceTensor1Conversion} with this instance. */
    @Override // org.tensorframes.dsl.DefaultConversions
    public <T> ConvertibleToDenseTensor<Seq<T>> sequenceTensor1Conversion(Numeric<T> numeric, TypeTags.TypeTag<T> typeTag, ConvertibleToDenseTensor<T> convertibleToDenseTensor) {
        return DefaultConversions.Cclass.sequenceTensor1Conversion(this, numeric, typeTag, convertibleToDenseTensor);
    }

    /** Forwards to {@code DefaultConversions.Cclass.sequenceTensor2Conversion} with this instance. */
    @Override // org.tensorframes.dsl.DefaultConversions
    public <T> ConvertibleToDenseTensor<Seq<Seq<T>>> sequenceTensor2Conversion(Numeric<T> numeric, TypeTags.TypeTag<T> typeTag, ConvertibleToDenseTensor<T> convertibleToDenseTensor) {
        return DefaultConversions.Cclass.sequenceTensor2Conversion(this, numeric, typeTag, convertibleToDenseTensor);
    }

    /** Forwards to {@code Logging.Cclass.logDebug}. */
    @Override // org.tensorframes.Logging
    public void logDebug(String str) {
        Logging.Cclass.logDebug(this, str);
    }

    /** Forwards to {@code Logging.Cclass.logInfo}. */
    @Override // org.tensorframes.Logging
    public void logInfo(String str) {
        Logging.Cclass.logInfo(this, str);
    }

    /** Forwards to {@code Logging.Cclass.logTrace}. */
    @Override // org.tensorframes.Logging
    public void logTrace(String str) {
        Logging.Cclass.logTrace(this, str);
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v5 */
    /**
     * Slow path of the logger accessor. Under {@code synchronized(this)}, it initializes
     * {@code logger} via {@code LazyLogging} if {@code bitmap$0} is still false, sets the flag,
     * and returns the logger.
     */
    private Logger logger$lzycompute() {
        // "??" is a JADX placeholder (type inference failed); r0 is the monitor, i.e. this.
        ?? r0 = this;
        synchronized (r0) {
            if (!this.bitmap$0) {
                this.logger = LazyLogging.class.logger(this);
                this.bitmap$0 = true;
            }
            // Decompiler residue; no effect.
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
            r0 = r0;
            return this.logger;
        }
    }

    /**
     * Returns the lazily initialized logger. JADX renamed this from {@code logger} (see
     * annotation below). The volatile flag is checked first; the lock is taken only on first use.
     */
    /* renamed from: logger, reason: merged with bridge method [inline-methods] */
    public Logger m128logger() {
        return this.bitmap$0 ? this.logger : logger$lzycompute();
    }

    /** Wraps {@code shape} in a {@code DslImpl.ShapeToAttr} (which offers {@code toAttr()}, as used by {@link #placeholder}). */
    public DslImpl.ShapeToAttr ShapeToAttr(Shape shape) {
        return new DslImpl.ShapeToAttr(shape);
    }

    /** Wraps a Spark SQL numeric type in a {@code DslImpl.SQLTypeToAttr}. */
    public DslImpl.SQLTypeToAttr SQLTypeToAttr(NumericType numericType) {
        return new DslImpl.SQLTypeToAttr(numericType);
    }

    /** Wraps a TensorFlow {@code DataType} in a {@code DslImpl.DataTypeToAttr}. */
    public DslImpl.DataTypeToAttr DataTypeToAttr(DataType dataType) {
        return new DslImpl.DataTypeToAttr(dataType);
    }

    /** Converts a tensorframes {@code Shape} to its protobuf form via {@code Shape.toProto()}. */
    public TensorShapeProto buildShape(Shape shape) {
        return shape.toProto();
    }

    /**
     * Builds an {@code AttrValue} whose {@code type} field is the TensorFlow dtype that
     * {@code ProtoConversions.getDType} maps the given Spark numeric type to.
     */
    public AttrValue org$tensorframes$dsl$DslImpl$$buildType(NumericType numericType) {
        return AttrValue.newBuilder().setType(ProtoConversions$.MODULE$.getDType(numericType)).build();
    }

    /**
     * Assembles a {@code GraphDef} from the given nodes.
     *
     * <p>Steps:
     * <ol>
     *   <li>Run two passes over {@code seq} (anonymous functions buildGraph$1 and buildGraph$2).
     *       Per the log messages they "freeze" the nodes; their bodies are not visible here.</li>
     *   <li>Trace-log the nodes, mapped through buildGraph$3 (presumably their names).</li>
     *   <li>Accumulate nodes into a mutable {@code Map} held in an {@code ObjectRef}
     *       (buildGraph$4); presumably this collects each node's closure by name. Confirm in
     *       DslImpl.scala.</li>
     *   <li>Flat-map the map's values through buildGraph$5, and add each result to the builder
     *       via buildGraph$6.</li>
     * </ol>
     *
     * @param seq the nodes to include
     * @return the built graph
     */
    public GraphDef buildGraph(Seq<Node> seq) {
        logTrace("buildGraph: freezing nodes");
        seq.foreach(new DslImpl$$anonfun$buildGraph$1());
        logTrace("buildGraph: Freezing everything");
        seq.foreach(new DslImpl$$anonfun$buildGraph$2());
        logTrace(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"buildGraph for nodes: ", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{seq.map(new DslImpl$$anonfun$buildGraph$3(), Seq$.MODULE$.canBuildFrom())})));
        // Scala `var` captured by a closure is compiled to an ObjectRef; starts as an empty immutable Map.
        ObjectRef create = ObjectRef.create(Predef$.MODULE$.Map().empty());
        seq.foreach(new DslImpl$$anonfun$buildGraph$4(create));
        GraphDef.Builder newBuilder = GraphDef.newBuilder();
        ((IterableLike) ((Map) create.elem).values().flatMap(new DslImpl$$anonfun$buildGraph$5(), Iterable$.MODULE$.canBuildFrom())).foreach(new DslImpl$$anonfun$buildGraph$6(newBuilder));
        return newBuilder.build();
    }

    /** Varargs-style overload: prepends {@code node} to {@code seq} and delegates to {@link #buildGraph(Seq)}. */
    public GraphDef buildGraph(Node node, Seq<Node> seq) {
        return buildGraph((Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Node[]{node})).$plus$plus(seq, Seq$.MODULE$.canBuildFrom()));
    }

    /**
     * Computes a by-name map containing {@code node}, the entries already in {@code map}, and
     * whatever is derived from {@code node}'s parents.
     *
     * <p>Steps:
     * <ol>
     *   <li>Parents rejected by anonfun$1(map) are skipped. Presumably these are parents already
     *       "treated" (present in {@code map}), per the trace message.</li>
     *   <li>Each remaining parent is flat-mapped through anonfun$2(node, map), presumably a
     *       recursive closure call.</li>
     *   <li>The resulting pairs go through {@code toMap}, and their values are concatenated with
     *       {@code map}'s values.</li>
     *   <li>{@code node} is prepended and the sequence is deduplicated by {@code uniqueByName}.</li>
     * </ol>
     *
     * @param node the node whose closure is requested
     * @param map nodes already processed, keyed by name (keys are logged as "treated")
     * @return name-to-node map of the closure
     */
    public Map<String, Node> org$tensorframes$dsl$DslImpl$$getClosure(Node node, Map<String, Node> map) {
        logTrace(new StringBuilder().append(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{"closure: n=", ", parents=", ","})).s(Predef$.MODULE$.genericWrapArray(new Object[]{node.name(), node.parents().map(new DslImpl$$anonfun$org$tensorframes$dsl$DslImpl$$getClosure$1(), Seq$.MODULE$.canBuildFrom())}))).append(new StringContext(Predef$.MODULE$.wrapRefArray(new String[]{" treated=", ""})).s(Predef$.MODULE$.genericWrapArray(new Object[]{map.keySet()}))).toString());
        return uniqueByName((Seq) ((SeqLike) ((TraversableOnce) ((TraversableLike) node.parents().filterNot(new DslImpl$$anonfun$1(map))).flatMap(new DslImpl$$anonfun$2(node, map), Seq$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms()).values().toSeq().$plus$plus(map.values().toSeq(), Seq$.MODULE$.canBuildFrom())).$plus$colon(node, Seq$.MODULE$.canBuildFrom()));
    }

    /**
     * Groups {@code seq} with uniqueByName$1 (presumably by node name) and maps each group
     * through uniqueByName$2 (presumably selecting one node per name). Confirm in DslImpl.scala.
     *
     * <p>NOTE(review): Scala 2.x {@code mapValues} returns a lazy view, so uniqueByName$2 is
     * re-applied on each lookup of the returned map.
     */
    private Map<String, Node> uniqueByName(Seq<Node> seq) {
        return seq.groupBy(new DslImpl$$anonfun$uniqueByName$1()).mapValues(new DslImpl$$anonfun$uniqueByName$2());
    }

    /**
     * Creates a {@code "Const"} node holding {@code denseTensor}.
     *
     * <p>The dtype is the Spark SQL type registered for the tensor's dtype in
     * {@code SupportedOperations}, and the shape is the tensor's shape. The tensor itself is
     * stored in the {@code "value"} attribute as a TensorProto. The fifth {@code build} argument
     * is {@code false} (its default is {@code true}; see {@link #build}).
     */
    public Node build_constant(DenseTensor denseTensor) {
        return build("Const", build$default$2(), build$default$3(), build$default$4(), false, (NumericType) SupportedOperations$.MODULE$.opsFor(denseTensor.dtype()).mo181sqlType(), denseTensor.shape(), build$default$8(), build$default$9(), (Map) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("value"), AttrValue.newBuilder().setTensor(DenseTensor$.MODULE$.toTensorProto(denseTensor)).build())})));
    }

    /**
     * Creates a {@code "Placeholder"} node with the given dtype and shape. The shape is also
     * recorded in the {@code "shape"} attribute. The fifth {@code build} argument is
     * {@code false}.
     */
    public Node placeholder(NumericType numericType, Shape shape) {
        return build("Placeholder", build$default$2(), build$default$3(), build$default$4(), false, numericType, shape, build$default$8(), build$default$9(), (Map) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("shape"), ShapeToAttr(shape).toAttr())})));
    }

    /**
     * Creates a placeholder matching a column of {@code dataset}.
     *
     * <p>The column's tensor metadata ({@code SparkTFColInfo}) supplies the dtype and shape.
     * When the column is missing or has no tensor info, the fallback functions anonfun$4 and
     * anonfun$5 are used; presumably they throw. Confirm in DslImpl.scala.
     *
     * @param dataset the source DataFrame
     * @param str name of the column to look up in the schema
     * @param str2 name given to the resulting placeholder
     * @param z if true, use the column's full shape; if false, drop the leading dimension
     *     ({@code shape().tail()})
     */
    public Operation extractPlaceholder(Dataset<Row> dataset, String str, String str2, boolean z) {
        StructField structField = (StructField) dataset.schema().find(new DslImpl$$anonfun$3(str)).getOrElse(new DslImpl$$anonfun$4(dataset, str));
        SparkTFColInfo sparkTFColInfo = (SparkTFColInfo) ColumnInformation$.MODULE$.apply(structField).stf().getOrElse(new DslImpl$$anonfun$5(str, structField));
        return placeholder((NumericType) SupportedOperations$.MODULE$.opsFor(sparkTFColInfo.dataType()).mo181sqlType(), z ? sparkTFColInfo.shape() : sparkTFColInfo.shape().tail()).named(str2);
    }

    /**
     * Returns the first shape after checking two preconditions with {@code require}: that
     * {@code seq} is non-empty, and that every element satisfies commonShape$2. The latter
     * presumably means "equals the head"; commonShape$1 supplies the failure message.
     */
    public Shape org$tensorframes$dsl$DslImpl$$commonShape(Seq<Shape> seq) {
        Predef$.MODULE$.require(seq.nonEmpty());
        Predef$.MODULE$.require(seq.forall(new DslImpl$$anonfun$org$tensorframes$dsl$DslImpl$$commonShape$2(seq)), new DslImpl$$anonfun$org$tensorframes$dsl$DslImpl$$commonShape$1(seq));
        return (Shape) seq.head();
    }

    /** Accessor for the value captured from {@code package$.MODULE$.Unknown()} at construction. */
    public long org$tensorframes$dsl$DslImpl$$U() {
        return this.org$tensorframes$dsl$DslImpl$$U;
    }

    /* JADX WARN: Code restructure failed: missing block: B:19:0x013f, code lost:
    
        throw new scala.MatchError(r0);
     */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    /**
     * Broadcast shape of {@code r8}. JADX could not decompile the original body (320
     * instructions; per the warning above it can throw {@code scala.MatchError}). This Java view
     * therefore only throws {@code UnsupportedOperationException}; see DslImpl.scala for the real
     * logic.
     */
    public org.tensorframes.Shape broadcastShape(scala.collection.Seq<org.tensorframes.Shape> r8) {
        /*
            Method dump skipped, instructions count: 320
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: org.tensorframes.dsl.DslImpl$.broadcastShape(scala.collection.Seq):org.tensorframes.Shape");
    }

    /**
     * Returns the first type after checking two preconditions with {@code require}: that
     * {@code seq} is non-empty, and that every element satisfies commonType$2. The latter
     * presumably means "equals the head"; commonType$1 supplies the failure message.
     */
    public NumericType org$tensorframes$dsl$DslImpl$$commonType(Seq<NumericType> seq) {
        Predef$.MODULE$.require(seq.nonEmpty());
        Predef$.MODULE$.require(seq.forall(new DslImpl$$anonfun$org$tensorframes$dsl$DslImpl$$commonType$2(seq)), new DslImpl$$anonfun$org$tensorframes$dsl$DslImpl$$commonType$1(seq));
        return (NumericType) seq.head();
    }

    /**
     * Generic node factory (Scala defaults are in the {@code build$default$N} methods below).
     *
     * @param str TensorFlow op type, e.g. "Const", "Placeholder", "Min", "Sum"
     * @param str2 node name, or null (wrapped in {@code Option}); default null
     * @param seq input nodes (the reducer passes the reduced node and its index constant);
     *     default empty
     * @param function1 passed through to {@code Node$.apply}; its role is not visible here
     * @param z passed through to {@code Node$.apply}; default true, and constants and
     *     placeholders pass false. Its meaning is not visible here.
     * @param numericType explicit dtype, or null to infer it through anonfun$7(seq, function12)
     * @param shape explicit shape, or null to infer it through anonfun$8(seq, function13)
     * @param function12 dtype-inference function used when {@code numericType} is null
     * @param function13 shape-inference function used when {@code shape} is null
     * @param map extra node attributes; default empty
     * @return the new node
     */
    // NOTE(review): @Nullable on the primitive `boolean z` has no effect; likely a decompiler/annotation artifact.
    public Node build(String str, @Nullable String str2, Seq<Node> seq, Function1<String, Seq<Node>> function1, @Nullable boolean z, @Nullable NumericType numericType, @Nullable Shape shape, Function1<Seq<NumericType>, NumericType> function12, Function1<Seq<Shape>, Shape> function13, Map<String, AttrValue> map) {
        return Node$.MODULE$.apply(Option$.MODULE$.apply(str2), str, (NumericType) Option$.MODULE$.apply(numericType).getOrElse(new DslImpl$$anonfun$7(seq, function12)), (Shape) Option$.MODULE$.apply(shape).getOrElse(new DslImpl$$anonfun$8(seq, function13)), seq, function1, z, map);
    }

    // ---- Scala default-argument values for build(...) (parameter positions 2..10) ----

    /** Default node name: null (no explicit name). */
    public String build$default$2() {
        return null;
    }

    /** Default inputs: empty sequence. */
    public Seq<Node> build$default$3() {
        return Seq$.MODULE$.empty();
    }

    /** Default for the {@code String => Seq<Node>} argument; body of the anonymous function not visible. */
    public Function1<String, Seq<Node>> build$default$4() {
        return new DslImpl$$anonfun$build$default$4$1();
    }

    /** Default for the boolean argument: true. */
    public boolean build$default$5() {
        return true;
    }

    /** Default dtype: null, i.e. inferred. */
    public NumericType build$default$6() {
        return null;
    }

    /** Default shape: null, i.e. inferred. */
    public Shape build$default$7() {
        return null;
    }

    /** Default dtype-inference function; body not visible (presumably commonType). */
    public Function1<Seq<NumericType>, NumericType> build$default$8() {
        return new DslImpl$$anonfun$build$default$8$1();
    }

    /** Default shape-inference function; body not visible (presumably commonShape). */
    public Function1<Seq<Shape>, Shape> build$default$9() {
        return new DslImpl$$anonfun$build$default$9$1();
    }

    /** Default extra attributes: empty map. */
    public Map<String, AttrValue> build$default$10() {
        return Predef$.MODULE$.Map().empty();
    }

    /** Builds a {@code "Min"} reduction of {@code node} over axes {@code seq}, named {@code str}. */
    public Node reduce_min(Node node, Seq<Object> seq, String str) {
        return build_reducer("Min", node, seq, str);
    }

    /** Default reduction axes for reduce_min: null. */
    public Seq<Object> reduce_min$default$2() {
        return null;
    }

    /** Default node name for reduce_min: null. */
    public String reduce_min$default$3() {
        return null;
    }

    /** Builds a {@code "Sum"} reduction of {@code node} over axes {@code seq}, named {@code str}. */
    public Node reduce_sum(Node node, Seq<Object> seq, String str) {
        return build_reducer("Sum", node, seq, str);
    }

    /** Default reduction axes for reduce_sum: null. */
    public Seq<Object> reduce_sum$default$2() {
        return null;
    }

    /** Default node name for reduce_sum: null. */
    public String reduce_sum$default$3() {
        return null;
    }

    /**
     * Builds a reduction node of op type {@code str}.
     *
     * <p>The axis list {@code seq} is emitted as an int constant named
     * {@code node.name() + "/reduction_indices"}. The result node has two inputs, {@code node}
     * and that constant. Its dtype is {@code node.scalarType()} and its shape comes from
     * {@code reduce_shape}. Its attributes are {@code Tidx = DT_INT32} and
     * {@code keep_dims = false}.
     *
     * <p>NOTE(review): {@code seq} defaults to null, and only the {@code reduce_shape} argument is
     * guarded (via {@code Option(seq).getOrElse(anonfun$9)}). {@code package$.constant} receives
     * the raw {@code seq}; confirm it tolerates null.
     */
    public Node build_reducer(String str, Node node, Seq<Object> seq, String str2) {
        Operation named = package$.MODULE$.constant(seq, sequenceTensor1Conversion(Numeric$IntIsIntegral$.MODULE$, scala.reflect.runtime.package$.MODULE$.universe().TypeTag().Int(), IntConversion())).named(new StringBuilder().append(node.name()).append("/reduction_indices").toString());
        // NOTE(review): result is discarded, so this line is dead; the attribute actually used is rebuilt in the map below.
        AttrValue.newBuilder().setB(false).build();
        return build(str, str2, (Seq) Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Node[]{node, package$.MODULE$.op2Node(named)})), build$default$4(), build$default$5(), node.scalarType(), reduce_shape(node.shape(), (Seq) Option$.MODULE$.apply(seq).getOrElse(new DslImpl$$anonfun$9())), build$default$8(), build$default$9(), (Map) Predef$.MODULE$.Map().apply(Predef$.MODULE$.wrapRefArray(new Tuple2[]{Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("Tidx"), AttrValue.newBuilder().setType(DataType.DT_INT32).build()), Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc("keep_dims"), AttrValue.newBuilder().setB(false).build())})));
    }

    /** Default reduction axes for build_reducer: null. */
    public Seq<Object> build_reducer$default$3() {
        return null;
    }

    /** Default node name for build_reducer: null. */
    public String build_reducer$default$4() {
        return null;
    }

    /**
     * Output shape of a reduction. It first requires that {@code shape} has at least as many
     * dimensions as {@code seq} has axes. If {@code seq} is empty, the result is the empty
     * (scalar) shape. Otherwise a {@code Shape} is built from the axis indices of {@code shape}
     * that anonfun$10(seq) does not reject (presumably: indices not contained in {@code seq}).
     *
     * <p>NOTE(review): the non-empty branch passes the retained axis *indices*
     * ({@code shape.dims().indices()}) to {@code Shape.apply}, not the dimension sizes at those
     * indices. If the intent is the reduced output shape, this looks wrong; confirm against
     * DslImpl.scala.
     */
    private Shape reduce_shape(Shape shape, Seq<Object> seq) {
        Predef$.MODULE$.require(shape.numDims() >= seq.size());
        if (seq.isEmpty()) {
            return Shape$.MODULE$.empty();
        }
        return Shape$.MODULE$.apply((Seq<Object>) shape.dims().indices().filterNot(new DslImpl$$anonfun$10(seq)));
    }

    /**
     * Singleton constructor, invoked once from the static initializer. It publishes
     * {@code MODULE$}, runs the mixed-in trait initializers (LazyLogging, Logging,
     * DefaultConversions), and captures {@code package$.MODULE$.Unknown()}.
     */
    private DslImpl$() {
        MODULE$ = this;
        LazyLogging.class.$init$(this);
        Logging.Cclass.$init$(this);
        DefaultConversions.Cclass.$init$(this);
        this.org$tensorframes$dsl$DslImpl$$U = package$.MODULE$.Unknown();
    }
}
