package com.yahoo.tensor;

import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Argmax;
import com.yahoo.tensor.functions.Argmin;
import com.yahoo.tensor.functions.CellCast;
import com.yahoo.tensor.functions.Concat;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Diag;
import com.yahoo.tensor.functions.Expand;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.L1Normalize;
import com.yahoo.tensor.functions.L2Normalize;
import com.yahoo.tensor.functions.Matmul;
import com.yahoo.tensor.functions.Merge;
import com.yahoo.tensor.functions.Random;
import com.yahoo.tensor.functions.Range;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.Softmax;
import com.yahoo.tensor.functions.XwPlusB;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;

/* loaded from: input_file:com/yahoo/tensor/Tensor.class */
public interface Tensor {

    /* loaded from: input_file:com/yahoo/tensor/Tensor$Builder.class */
    /**
     * Builder of {@link Tensor} instances: cells are added one at a time and the
     * finished tensor is obtained from {@link #build()}.
     */
    public interface Builder {

        /**
         * Fluent helper which collects the labels of one cell address and then hands the
         * finished address, together with a value, to the {@link Builder} it was created for.
         */
        public static class CellBuilder {
            private final TensorAddress.Builder addressBuilder;
            private final Builder tensorBuilder;

            /**
             * Creates a cell builder whose address is built for the given type and whose
             * cell is delivered to the given tensor builder.
             * Decompiler note: this constructor was package-private in the compiled class.
             */
            public CellBuilder(TensorType tensorType, Builder builder) {
                this.tensorBuilder = builder;
                this.addressBuilder = new TensorAddress.Builder(tensorType);
            }

            /**
             * Adds {@code str2} as a label under {@code str} (presumably the dimension name)
             * to the address being built.
             *
             * @return this, for chaining
             */
            public CellBuilder label(String str, String str2) {
                addressBuilder.add(str, str2);
                return this;
            }

            /** Returns the type reported by the tensor builder this cell builder belongs to. */
            public TensorType type() {
                return tensorBuilder.type();
            }

            /** Same as {@link #label(String, String)} with the number rendered in decimal. */
            public CellBuilder label(String str, long j) {
                String numericLabel = Long.toString(j);
                return label(str, numericLabel);
            }

            /** Builds the address and adds a cell with the given double value to the tensor builder. */
            public Builder value(double d) {
                TensorAddress address = addressBuilder.build();
                return tensorBuilder.cell(address, d);
            }

            /** Builds the address and adds a cell with the given float value to the tensor builder. */
            public Builder value(float f) {
                TensorAddress address = addressBuilder.build();
                return tensorBuilder.cell(address, f);
            }
        }

        /** Creates a builder for the type described by the given type spec string. */
        static Builder of(String str) {
            TensorType parsedType = TensorType.fromSpec(str);
            return of(parsedType);
        }

        /**
         * Creates a builder suited to the given type: a mixed builder when the type has both
         * indexed and non-indexed dimensions, a mapped builder when it has non-indexed
         * dimensions only, and an indexed builder otherwise (including no dimensions at all).
         */
        static Builder of(TensorType tensorType) {
            boolean hasIndexed = tensorType.dimensions().stream().anyMatch(dim -> dim.isIndexed());
            boolean hasNonIndexed = tensorType.dimensions().stream().anyMatch(dim -> !dim.isIndexed());
            if (hasIndexed && hasNonIndexed) {
                return MixedTensor.Builder.of(tensorType);
            }
            if (hasNonIndexed) {
                return MappedTensor.Builder.of(tensorType);
            }
            return IndexedTensor.Builder.of(tensorType);
        }

        /**
         * As {@link #of(TensorType)}, but passes the given dimension sizes on to the indexed
         * builder. The sizes are not used when a mixed or mapped builder is chosen.
         */
        static Builder of(TensorType tensorType, DimensionSizes dimensionSizes) {
            boolean hasIndexed = tensorType.dimensions().stream().anyMatch(dim -> dim.isIndexed());
            boolean hasNonIndexed = tensorType.dimensions().stream().anyMatch(dim -> !dim.isIndexed());
            if (hasIndexed && hasNonIndexed) {
                return MixedTensor.Builder.of(tensorType);
            }
            if (hasNonIndexed) {
                return MappedTensor.Builder.of(tensorType);
            }
            return IndexedTensor.Builder.of(tensorType, dimensionSizes);
        }

        /** Returns the type of the tensor being built. */
        TensorType type();

        /** Returns a {@link CellBuilder} for adding one cell. */
        CellBuilder cell();

        /** Adds a cell with the given address and double value. */
        Builder cell(TensorAddress tensorAddress, double d);

        /** Adds a cell with the given address and float value. */
        Builder cell(TensorAddress tensorAddress, float f);

        /** Adds a cell with the given double value at the position given by the numeric arguments. */
        Builder cell(double d, long... jArr);

        /** Adds a cell with the given float value at the position given by the numeric arguments. */
        Builder cell(float f, long... jArr);

        /** Adds a cell at the address of the given cell, using the supplied double value instead of the cell's own. */
        default Builder cell(Cell cell, double d) {
            TensorAddress address = cell.getKey();
            return cell(address, d);
        }

        /** Adds a cell at the address of the given cell, using the supplied float value instead of the cell's own. */
        default Builder cell(Cell cell, float f) {
            TensorAddress address = cell.getKey();
            return cell(address, f);
        }

        /** Adds a cell with the address and (double) value of the given cell. */
        default Builder cell(Cell cell) {
            TensorAddress address = cell.getKey();
            double cellValue = cell.getValue().doubleValue();
            return cell(address, cellValue);
        }

        /** Returns the tensor built from the cells added so far. */
        Tensor build();
    }

    /* loaded from: input_file:com/yahoo/tensor/Tensor$Cell.class */
    /**
     * One (address, value) pair of a tensor, exposed as a {@link Map.Entry}.
     * This base implementation keeps both parts in final fields and rejects {@link #setValue}.
     */
    public static class Cell implements Map.Entry<TensorAddress, Double> {
        private final TensorAddress address;
        private final double value;

        /**
         * Creates a cell from any {@link Number}; the value is stored as a double.
         * Decompiler note: this constructor was package-private in the compiled class.
         */
        public Cell(TensorAddress tensorAddress, Number number) {
            this(tensorAddress, number.doubleValue());
        }

        /**
         * Creates a cell with the given address and value.
         * Decompiler note: this constructor was package-private in the compiled class.
         */
        public Cell(TensorAddress tensorAddress, double d) {
            this.value = d;
            this.address = tensorAddress;
        }

        /** Returns the address of this cell. */
        @Override
        public TensorAddress getKey() {
            return address;
        }

        /**
         * Always -1 in this base class (presumably meaning "no direct index" - confirm against subclasses).
         * Decompiler note: this method was package-private in the compiled class.
         */
        public long getDirectIndex() {
            return -1L;
        }

        /** Returns the value of this cell, boxed. */
        @Override
        public Double getValue() {
            return Double.valueOf(value);
        }

        /** Returns {@link #getValue()} narrowed to a float. */
        public float getFloatValue() {
            Double boxed = getValue();
            return boxed.floatValue();
        }

        /** Returns {@link #getValue()} as a primitive double. */
        public double getDoubleValue() {
            Double boxed = getValue();
            return boxed.doubleValue();
        }

        /**
         * Not supported.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public Double setValue(Double d) {
            throw new UnsupportedOperationException("A tensor cannot be modified");
        }

        /** Equal to any {@link Map.Entry} having an equal value and an equal key. */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj instanceof Map.Entry) {
                Map.Entry<?, ?> other = (Map.Entry<?, ?>) obj;
                // Value is compared first, then key.
                return getValue().equals(other.getValue()) && getKey().equals(other.getKey());
            }
            return false;
        }

        /** XOR of the key hash and the value hash. */
        @Override
        public int hashCode() {
            int keyHash = getKey().hashCode();
            int valueHash = getValue().hashCode();
            return keyHash ^ valueHash;
        }

        /** Returns {@code <address rendered for the given type>:<value>}. */
        public String toString(TensorType tensorType) {
            return address.toString(tensorType) + ":" + value;
        }

        /** Returns this same instance in this base class. */
        public Cell detach() {
            return this;
        }
    }

    /** Returns the type of this tensor. */
    TensorType type();

    /** Returns whether {@link #size()} reports zero cells. */
    default boolean isEmpty() {
        long cellCount = size();
        return cellCount == 0;
    }

    /** Returns the number of cells in this tensor. */
    long size();

    /** Returns the value of the cell at the given address. What happens for an absent address is up to the implementation. */
    double get(TensorAddress tensorAddress);

    /** Returns whether this tensor has a cell at the given address. */
    boolean has(TensorAddress tensorAddress);

    /** Returns an iterator over the cells (address and value) of this tensor. */
    Iterator<Cell> cellIterator();

    /** Returns an iterator over the cell values of this tensor. */
    Iterator<Double> valueIterator();

    /** Returns the cells of this tensor as a map from address to value. */
    Map<TensorAddress, Double> cells();

    /**
     * Returns the single value of a dimensionless tensor.
     *
     * @return the first value from {@link #valueIterator()}, or NaN when the tensor has no cells
     * @throws IllegalStateException if the type of this tensor has one or more dimensions
     */
    default double asDouble() {
        boolean dimensionless = type().dimensions().isEmpty();
        if (!dimensionless) {
            throw new IllegalStateException("Require a dimensionless tensor but has " + type());
        }
        return size() == 0 ? Double.NaN : valueIterator().next().doubleValue();
    }

    /** Returns this tensor's content with the given type. Any compatibility requirements on the type are up to the implementation. */
    Tensor withType(TensorType tensorType);

    /**
     * Returns a new tensor of this tensor's type in which every cell whose address is a key of
     * {@code map} gets the value {@code doubleBinaryOperator(current value, map value)}, and all
     * other cells are copied unchanged. Addresses in {@code map} that are not cells of this
     * tensor are ignored, since only this tensor's cells are visited.
     */
    default Tensor modify(DoubleBinaryOperator doubleBinaryOperator, Map<TensorAddress, Double> map) {
        Builder builder = Builder.of(type());
        for (Iterator<Cell> cells = cellIterator(); cells.hasNext(); ) {
            Cell cell = cells.next();
            TensorAddress address = cell.getKey();
            double current = cell.getValue().doubleValue();
            double updated = current;
            if (map.containsKey(address)) {
                updated = doubleBinaryOperator.applyAsDouble(current, map.get(address).doubleValue());
            }
            builder.cell(address, updated);
        }
        return builder.build();
    }

    /** Returns a tensor without the cells at the given addresses (exact semantics are up to the implementation). */
    Tensor remove(Set<TensorAddress> set);

    /** Evaluates the {@code Map} tensor function over this tensor with the given operator and returns the result. */
    default Tensor map(DoubleUnaryOperator doubleUnaryOperator) {
        // Fully qualified because java.util.Map is imported in this file.
        com.yahoo.tensor.functions.Map function =
                new com.yahoo.tensor.functions.Map(new ConstantTensor(this), doubleUnaryOperator);
        return function.evaluate();
    }

    /** Evaluates {@code Reduce} over this tensor with the given aggregator and the given names as a list. */
    default Tensor reduce(Reduce.Aggregator aggregator, String... strArr) {
        List<String> dimensions = Arrays.asList(strArr);
        Reduce function = new Reduce(new ConstantTensor(this), aggregator, dimensions);
        return function.evaluate();
    }

    /** Evaluates {@code Reduce} over this tensor with the given aggregator and name list. */
    default Tensor reduce(Reduce.Aggregator aggregator, List<String> list) {
        Reduce function = new Reduce(new ConstantTensor(this), aggregator, list);
        return function.evaluate();
    }

    /** Evaluates {@code Join} of this tensor (first argument) and the given tensor with the given combinator. */
    default Tensor join(Tensor tensor, DoubleBinaryOperator doubleBinaryOperator) {
        Join function = new Join(new ConstantTensor(this), new ConstantTensor(tensor), doubleBinaryOperator);
        return function.evaluate();
    }

    /** Evaluates {@code Merge} of this tensor (first argument) and the given tensor with the given operator. */
    default Tensor merge(Tensor tensor, DoubleBinaryOperator doubleBinaryOperator) {
        Merge function = new Merge(new ConstantTensor(this), new ConstantTensor(tensor), doubleBinaryOperator);
        return function.evaluate();
    }

    /** Evaluates {@code Rename} on this tensor with single-element "from" ({@code str}) and "to" ({@code str2}) lists. */
    default Tensor rename(String str, String str2) {
        List<String> from = Collections.singletonList(str);
        List<String> to = Collections.singletonList(str2);
        Rename function = new Rename(new ConstantTensor(this), from, to);
        return function.evaluate();
    }

    /**
     * Wraps the number in a tensor of type {@code TensorType.empty} holding that single value,
     * then concatenates it as in {@link #concat(Tensor, String)}.
     */
    default Tensor concat(double d, String str) {
        Tensor scalar = Builder.of(TensorType.empty).cell(d, new long[0]).build();
        return concat(scalar, str);
    }

    /** Evaluates {@code Concat} of this tensor (first argument) and the given tensor, passing {@code str} along. */
    default Tensor concat(Tensor tensor, String str) {
        Concat function = new Concat(new ConstantTensor(this), new ConstantTensor(tensor), str);
        return function.evaluate();
    }

    /** Evaluates {@code Rename} on this tensor with the given "from" ({@code list}) and "to" ({@code list2}) lists. */
    default Tensor rename(List<String> list, List<String> list2) {
        Rename function = new Rename(new ConstantTensor(this), list, list2);
        return function.evaluate();
    }

    /** Evaluates {@code Generate} for the given type and generator function and returns the resulting tensor. */
    static Tensor generate(TensorType tensorType, Function<List<Long>, Double> function) {
        Generate generator = new Generate(tensorType, function);
        return generator.evaluate();
    }

    /** Evaluates {@code CellCast} on this tensor with the given value type. */
    default Tensor cellCast(TensorType.Value value) {
        CellCast function = new CellCast(new ConstantTensor(this), value);
        return function.evaluate();
    }

    /** Evaluates {@code L1Normalize} on this tensor, passing {@code str} (presumably a dimension name) along. */
    default Tensor l1Normalize(String str) {
        L1Normalize function = new L1Normalize(new ConstantTensor(this), str);
        return function.evaluate();
    }

    /** Evaluates {@code L2Normalize} on this tensor, passing {@code str} (presumably a dimension name) along. */
    default Tensor l2Normalize(String str) {
        L2Normalize function = new L2Normalize(new ConstantTensor(this), str);
        return function.evaluate();
    }

    /** Evaluates {@code Matmul} of this tensor (first argument) and the given tensor, passing {@code str} along. */
    default Tensor matmul(Tensor tensor, String str) {
        Matmul function = new Matmul(new ConstantTensor(this), new ConstantTensor(tensor), str);
        return function.evaluate();
    }

    /** Evaluates {@code Softmax} on this tensor, passing {@code str} (presumably a dimension name) along. */
    default Tensor softmax(String str) {
        Softmax function = new Softmax(new ConstantTensor(this), str);
        return function.evaluate();
    }

    /**
     * Evaluates {@code XwPlusB} with this tensor as the first argument, {@code tensor} as the
     * second and {@code tensor2} as the third, passing {@code str} along.
     */
    default Tensor xwPlusB(Tensor tensor, Tensor tensor2, String str) {
        XwPlusB function = new XwPlusB(
                new ConstantTensor(this),
                new ConstantTensor(tensor),
                new ConstantTensor(tensor2),
                str);
        return function.evaluate();
    }

    /** Evaluates {@code Expand} on this tensor, passing {@code str} (presumably a dimension name) along. */
    default Tensor expand(String str) {
        Expand function = new Expand(new ConstantTensor(this), str);
        return function.evaluate();
    }

    /** Evaluates {@code Argmax} on this tensor, passing {@code str} (presumably a dimension name) along. */
    default Tensor argmax(String str) {
        Argmax function = new Argmax(new ConstantTensor(this), str);
        return function.evaluate();
    }

    /** Evaluates {@code Argmin} on this tensor, passing {@code str} (presumably a dimension name) along. */
    default Tensor argmin(String str) {
        Argmin function = new Argmin(new ConstantTensor(this), str);
        return function.evaluate();
    }

    /** Evaluates {@code Diag} for the given type and returns the resulting tensor. */
    static Tensor diag(TensorType tensorType) {
        Diag function = new Diag(tensorType);
        return function.evaluate();
    }

    /** Evaluates {@code Random} for the given type and returns the resulting tensor. */
    static Tensor random(TensorType tensorType) {
        Random function = new Random(tensorType);
        return function.evaluate();
    }

    /** Evaluates {@code Range} for the given type and returns the resulting tensor. */
    static Tensor range(TensorType tensorType) {
        Range function = new Range(tensorType);
        return function.evaluate();
    }

    /** Joins this tensor with the given one, combining values as {@code this * argument}. */
    default Tensor multiply(Tensor tensor) {
        return join(tensor, (left, right) -> left * right);
    }

    /** Joins this tensor with the given one, combining values as {@code this + argument}. */
    default Tensor add(Tensor tensor) {
        return join(tensor, (left, right) -> left + right);
    }

    /** Joins this tensor with the given one, combining values as {@code this / argument}. */
    default Tensor divide(Tensor tensor) {
        return join(tensor, (left, right) -> left / right);
    }

    /** Joins this tensor with the given one, combining values as {@code this - argument}. */
    default Tensor subtract(Tensor tensor) {
        return join(tensor, (left, right) -> left - right);
    }

    /**
     * Joins this tensor with the given one, keeping this tensor's value when it is strictly
     * greater and the argument's value otherwise.
     */
    default Tensor max(Tensor tensor) {
        // Plain comparison on purpose: Math.max treats NaN and signed zero differently.
        return join(tensor, (left, right) -> left > right ? left : right);
    }

    /**
     * Joins this tensor with the given one, keeping this tensor's value when it is strictly
     * smaller and the argument's value otherwise.
     */
    default Tensor min(Tensor tensor) {
        // Plain comparison on purpose: Math.min treats NaN and signed zero differently.
        return join(tensor, (left, right) -> left < right ? left : right);
    }

    /** Joins this tensor with the given one, combining values as {@code Math.atan2(this, argument)}. */
    default Tensor atan2(Tensor tensor) {
        return join(tensor, (left, right) -> Math.atan2(left, right));
    }

    /** Joins this tensor with the given one, combining values as {@code Math.pow(this, argument)}. */
    default Tensor pow(Tensor tensor) {
        return join(tensor, (base, exponent) -> Math.pow(base, exponent));
    }

    /** Joins this tensor with the given one, combining values as the Java remainder {@code this % argument}. */
    default Tensor fmod(Tensor tensor) {
        return join(tensor, (left, right) -> left % right);
    }

    /**
     * Joins this tensor with the given one, combining values as {@code this * 2^(int) argument}.
     * The exponent is truncated to an int before use.
     */
    default Tensor ldexp(Tensor tensor) {
        return join(tensor, (mantissa, exponent) -> mantissa * Math.pow(2.0d, (int) exponent));
    }

    /** Joins this tensor with the given one, yielding 1.0 where {@code this > argument} and 0.0 elsewhere. */
    default Tensor larger(Tensor tensor) {
        return join(tensor, (left, right) -> left > right ? 1.0d : 0.0d);
    }

    /** Joins this tensor with the given one, yielding 1.0 where {@code this >= argument} and 0.0 elsewhere. */
    default Tensor largerOrEqual(Tensor tensor) {
        return join(tensor, (left, right) -> left >= right ? 1.0d : 0.0d);
    }

    /** Joins this tensor with the given one, yielding 1.0 where {@code this < argument} and 0.0 elsewhere. */
    default Tensor smaller(Tensor tensor) {
        return join(tensor, (left, right) -> left < right ? 1.0d : 0.0d);
    }

    /** Joins this tensor with the given one, yielding 1.0 where {@code this <= argument} and 0.0 elsewhere. */
    default Tensor smallerOrEqual(Tensor tensor) {
        return join(tensor, (left, right) -> left <= right ? 1.0d : 0.0d);
    }

    /** Joins this tensor with the given one, yielding 1.0 where the values are {@code ==} and 0.0 elsewhere. */
    default Tensor equal(Tensor tensor) {
        return join(tensor, (left, right) -> left == right ? 1.0d : 0.0d);
    }

    /** Joins this tensor with the given one, yielding 1.0 where the values are {@code !=} and 0.0 elsewhere. */
    default Tensor notEqual(Tensor tensor) {
        return join(tensor, (left, right) -> left != right ? 1.0d : 0.0d);
    }

    /**
     * Joins this tensor with the given one, yielding 1.0 where the two-argument
     * {@code approxEquals(this, argument)} of this interface holds and 0.0 elsewhere.
     */
    default Tensor approxEqual(Tensor tensor) {
        return join(tensor, (left, right) -> Tensor.approxEquals(left, right) ? 1.0d : 0.0d);
    }

    /**
     * Joins this tensor with the given one, yielding 1.0 where the bit selected by the argument
     * value is set in this tensor's value, and 0.0 elsewhere. Both values are truncated to int
     * first, and bit positions outside 0..7 always yield 0.0.
     */
    default Tensor bit(Tensor tensor) {
        return join(tensor, (value, position) -> {
            int bitIndex = (int) position;
            if (bitIndex < 0 || bitIndex >= 8) {
                return 0.0d;
            }
            return (((int) value) & (1 << bitIndex)) != 0 ? 1.0d : 0.0d;
        });
    }

    /** Joins this tensor with the given one, combining values with {@code ScalarFunctions.Hamming.hamming}. */
    default Tensor hamming(Tensor tensor) {
        DoubleBinaryOperator hammingOperator = ScalarFunctions.Hamming::hamming;
        return join(tensor, hammingOperator);
    }

    /** Calls {@link #avg(List)} with an empty list. */
    default Tensor avg() {
        List<String> noNames = Collections.emptyList();
        return avg(noNames);
    }

    /** Calls {@link #avg(List)} with a single-element list holding {@code str}. */
    default Tensor avg(String str) {
        List<String> singleName = Collections.singletonList(str);
        return avg(singleName);
    }

    /** Reduces this tensor with {@code Reduce.Aggregator.avg} and the given list. */
    default Tensor avg(List<String> list) {
        Reduce.Aggregator aggregator = Reduce.Aggregator.avg;
        return reduce(aggregator, list);
    }

    /** Calls {@link #count(List)} with an empty list. */
    default Tensor count() {
        List<String> noNames = Collections.emptyList();
        return count(noNames);
    }

    /** Calls {@link #count(List)} with a single-element list holding {@code str}. */
    default Tensor count(String str) {
        List<String> singleName = Collections.singletonList(str);
        return count(singleName);
    }

    /** Reduces this tensor with {@code Reduce.Aggregator.count} and the given list. */
    default Tensor count(List<String> list) {
        Reduce.Aggregator aggregator = Reduce.Aggregator.count;
        return reduce(aggregator, list);
    }

    /** Calls {@link #max(List)} with an empty list. */
    default Tensor max() {
        List<String> noNames = Collections.emptyList();
        return max(noNames);
    }

    /** Calls {@link #max(List)} with a single-element list holding {@code str}. */
    default Tensor max(String str) {
        List<String> singleName = Collections.singletonList(str);
        return max(singleName);
    }

    /** Reduces this tensor with {@code Reduce.Aggregator.max} and the given list. */
    default Tensor max(List<String> list) {
        Reduce.Aggregator aggregator = Reduce.Aggregator.max;
        return reduce(aggregator, list);
    }

    /** Calls {@link #median(List)} with an empty list. */
    default Tensor median() {
        List<String> noNames = Collections.emptyList();
        return median(noNames);
    }

    /** Calls {@link #median(List)} with a single-element list holding {@code str}. */
    default Tensor median(String str) {
        List<String> singleName = Collections.singletonList(str);
        return median(singleName);
    }

    /** Reduces this tensor with {@code Reduce.Aggregator.median} and the given list. */
    default Tensor median(List<String> list) {
        Reduce.Aggregator aggregator = Reduce.Aggregator.median;
        return reduce(aggregator, list);
    }

    /** Calls {@link #min(List)} with an empty list. */
    default Tensor min() {
        List<String> noNames = Collections.emptyList();
        return min(noNames);
    }

    /** Calls {@link #min(List)} with a single-element list holding {@code str}. */
    default Tensor min(String str) {
        List<String> singleName = Collections.singletonList(str);
        return min(singleName);
    }

    /** Reduces this tensor with {@code Reduce.Aggregator.min} and the given list. */
    default Tensor min(List<String> list) {
        Reduce.Aggregator aggregator = Reduce.Aggregator.min;
        return reduce(aggregator, list);
    }

    /** Calls {@link #prod(List)} with an empty list. */
    default Tensor prod() {
        List<String> noNames = Collections.emptyList();
        return prod(noNames);
    }

    /** Calls {@link #prod(List)} with a single-element list holding {@code str}. */
    default Tensor prod(String str) {
        List<String> singleName = Collections.singletonList(str);
        return prod(singleName);
    }

    /** Reduces this tensor with {@code Reduce.Aggregator.prod} and the given list. */
    default Tensor prod(List<String> list) {
        Reduce.Aggregator aggregator = Reduce.Aggregator.prod;
        return reduce(aggregator, list);
    }

    /** Calls {@link #sum(List)} with an empty list. */
    default Tensor sum() {
        List<String> noNames = Collections.emptyList();
        return sum(noNames);
    }

    /** Calls {@link #sum(List)} with a single-element list holding {@code str}. */
    default Tensor sum(String str) {
        List<String> singleName = Collections.singletonList(str);
        return sum(singleName);
    }

    /** Reduces this tensor with {@code Reduce.Aggregator.sum} and the given list. */
    default Tensor sum(List<String> list) {
        Reduce.Aggregator aggregator = Reduce.Aggregator.sum;
        return reduce(aggregator, list);
    }

    /**
     * Returns the cell(s) of this tensor holding the largest value.
     *
     * @return detached copies of every cell whose value equals the maximum, in iteration order;
     *         empty when this tensor has no cells or all values are NaN (NaN never compares
     *         greater than or equal to anything)
     */
    default List<Cell> largest() {
        List<Cell> largestCells = new ArrayList<>(1);
        // Start from negative infinity. Double.MIN_VALUE is the smallest *positive* double, so
        // seeding with it would drop every cell of a tensor whose values are all zero or negative.
        double largestValue = Double.NEGATIVE_INFINITY;
        for (Iterator<Cell> cells = cellIterator(); cells.hasNext(); ) {
            Cell cell = cells.next();
            double value = cell.getDoubleValue();
            if (value > largestValue) {
                // New maximum: forget the ties collected for the previous one.
                largestCells.clear();
                largestCells.add(cell.detach());
                largestValue = value;
            } else if (value == largestValue) {
                largestCells.add(cell.detach());
            }
        }
        return largestCells;
    }

    /**
     * Returns the cell(s) of this tensor holding the smallest value.
     *
     * @return detached copies of every cell whose value equals the minimum, in iteration order;
     *         empty when this tensor has no cells or all values are NaN (NaN never compares
     *         smaller than or equal to anything)
     */
    default List<Cell> smallest() {
        List<Cell> smallestCells = new ArrayList<>(1);
        // Start from positive infinity. Seeding with Double.MAX_VALUE would drop every cell of a
        // tensor whose values are all +Infinity, since those are neither below nor equal to it.
        double smallestValue = Double.POSITIVE_INFINITY;
        for (Iterator<Cell> cells = cellIterator(); cells.hasNext(); ) {
            Cell cell = cells.next();
            double value = cell.getDoubleValue();
            if (value < smallestValue) {
                // New minimum: forget the ties collected for the previous one.
                smallestCells.clear();
                smallestCells.add(cell.detach());
                smallestValue = value;
            } else if (value == smallestValue) {
                smallestCells.add(cell.detach());
            }
        }
        return smallestCells;
    }

    /** Returns a string representation of this tensor. */
    String toString();

    /**
     * Returns a string representation of this tensor controlled by two flags.
     * NOTE(review): the meaning of the flags is not visible in this file. The static
     * toStandardString helper takes flags in similar positions for "prefix with type" and
     * "short cell form" - confirm against the implementations.
     */
    String toString(boolean z, boolean z2);

    /** Returns {@link #toAbbreviatedString(boolean, boolean)} with both flags set to true. */
    default String toAbbreviatedString() {
        final boolean enabled = true;
        return toAbbreviatedString(enabled, enabled);
    }

    /**
     * Returns an abbreviated string representation of this tensor controlled by two flags.
     * NOTE(review): the meaning of the flags and the abbreviation rule are not visible in this
     * file - confirm against the implementations.
     */
    String toAbbreviatedString(boolean z, boolean z2);

    /**
     * Renders the given tensor as {@code [<type>:]<value>}.
     *
     * @param tensor the tensor to render
     * @param z      whether to prefix the output with the tensor type followed by a colon
     * @param z2     passed on to {@code valueToString} as its short-form flag
     * @param j      passed on to {@code valueToString} as the maximum number of cells to print
     */
    static String toStandardString(Tensor tensor, boolean z, boolean z2, long j) {
        StringBuilder text = new StringBuilder();
        if (z) {
            text.append(tensor.type()).append(":");
        }
        text.append(valueToString(tensor, z2, j));
        return text.toString();
    }

    /**
     * Renders the cells of the given tensor, sorted by address, as {@code {cell, cell, ...}}.
     *
     * A tensor without dimensions is rendered as {@code {}} when it has no cells and as
     * {@code {<value>}} otherwise. For other tensors at most {@code j} cells are printed; when
     * the limit is reached and the tensor has more cells, an ellipsis is appended.
     *
     * @param tensor the tensor whose cells are rendered
     * @param z      short-form flag, passed on to the per-cell rendering
     * @param j      the maximum number of cells to print
     */
    static String valueToString(Tensor tensor, boolean z, long j) {
        // Typed copy of the entries, so neither sorting nor rendering needs unchecked casts.
        List<Map.Entry<TensorAddress, Double>> entries = new ArrayList<>(tensor.cells().entrySet());
        entries.sort(Map.Entry.comparingByKey());
        if (tensor.type().dimensions().isEmpty()) {
            return entries.isEmpty() ? "{}" : "{" + entries.get(0).getValue() + "}";
        }
        StringBuilder sb = new StringBuilder("{");
        int printed = 0;
        for (; printed < entries.size() && printed < j; printed++) {
            if (printed > 0) {
                sb.append(", ");
            }
            sb.append(cellToString(entries.get(printed), tensor.type(), z));
        }
        if (printed == j && printed < tensor.size()) {
            // Only separate the ellipsis from a preceding cell; with a limit of 0 nothing
            // has been printed and a leading ", " would give the malformed "{, ...}".
            sb.append(printed > 0 ? ", ..." : "...");
        }
        sb.append("}");
        return sb.toString();
    }

    /**
     * Renders one cell as {@code <address>:<value>}. When {@code z} is set and the type has
     * rank 1, the address is just the first label run through {@code TensorAddress.labelToString};
     * otherwise it is the address rendered for the given type.
     */
    private static String cellToString(Map.Entry<TensorAddress, Double> entry, TensorType tensorType, boolean z) {
        StringBuilder text = new StringBuilder();
        if (z && tensorType.rank() == 1) {
            text.append(TensorAddress.labelToString(entry.getKey().label(0)));
        } else {
            text.append(entry.getKey().toString(tensorType));
        }
        return text.append(":").append(entry.getValue()).toString();
    }

    /** Redeclaration of {@link Object#equals(Object)}; the equality rule is defined by the implementations. */
    boolean equals(Object obj);

    /** Redeclaration of {@link Object#hashCode()}; the hash is defined by the implementations. */
    int hashCode();

    /**
     * Compares two tensors for approximate equality.
     *
     * Returns true when both references are the same object. Otherwise the types must be
     * mathematically equal, the sizes must match, and every cell of {@code tensor} must be
     * within 1.0E-4 of the value {@code tensor2} returns for the same address (two NaNs
     * count as equal).
     */
    static boolean equals(Tensor tensor, Tensor tensor2) {
        if (tensor == tensor2) {
            return true;
        }
        boolean sameType = tensor.type().mathematicallyEquals(tensor2.type());
        if (!sameType || tensor.size() != tensor2.size()) {
            return false;
        }
        for (Iterator<Cell> cells = tensor.cellIterator(); cells.hasNext(); ) {
            Cell cell = cells.next();
            double expected = cell.getValue().doubleValue();
            double actual = tensor2.get(cell.getKey());
            if (!approxEquals(expected, actual, 1.0E-4d)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns whether two doubles are equal within an absolute tolerance.
     *
     * @param d  the first value
     * @param d2 the second value
     * @param d3 the tolerance: the values match when their absolute difference is strictly below it
     * @return true when the values are {@code ==}, when both are NaN, or when they differ by less than {@code d3}
     */
    static boolean approxEquals(double d, double d2, double d3) {
        if (d == d2) {
            return true;
        }
        if (Double.isNaN(d) && Double.isNaN(d2)) {
            return true;
        }
        return Math.abs(d - d2) < d3;
    }

    /**
     * Returns whether {@code d} equals {@code d2} or is one representable step away from it.
     *
     * When {@code d2} lies outside [-1, 1], the ratio {@code d / d2} is moved one step towards
     * 1.0 and compared with 1.0. Otherwise {@code d} is moved one step towards {@code d2} and
     * compared with {@code d2}. A NaN on either side yields false.
     */
    static boolean approxEquals(double d, double d2) {
        boolean outsideUnitRange = d2 < -1.0d || d2 > 1.0d;
        if (outsideUnitRange) {
            return Math.nextAfter(d / d2, 1.0d) == 1.0d;
        }
        return Math.nextAfter(d, d2) == d2;
    }

    /** Parses the given tensor string with {@code TensorParser}, supplying the given type. */
    static Tensor from(TensorType tensorType, String str) {
        Optional<TensorType> explicitType = Optional.of(tensorType);
        return TensorParser.tensorFrom(str, explicitType);
    }

    /** Parses the tensor string {@code str2} with {@code TensorParser}, supplying the type parsed from the spec {@code str}. */
    static Tensor from(String str, String str2) {
        Optional<TensorType> explicitType = Optional.of(TensorType.fromSpec(str));
        return TensorParser.tensorFrom(str2, explicitType);
    }

    /** Parses the given tensor string with {@code TensorParser} without supplying a type. */
    static Tensor from(String str) {
        Optional<TensorType> noType = Optional.empty();
        return TensorParser.tensorFrom(str, noType);
    }

    /** Returns a tensor of type {@code TensorType.empty} holding the given number as its single cell. */
    static Tensor from(double d) {
        Builder builder = Builder.of(TensorType.empty);
        builder = builder.cell(d, new long[0]);
        return builder.build();
    }
}
