package com.yahoo.tensor;

import com.google.common.annotations.Beta;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.functions.Argmax;
import com.yahoo.tensor.functions.Argmin;
import com.yahoo.tensor.functions.Concat;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Diag;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.L1Normalize;
import com.yahoo.tensor.functions.L2Normalize;
import com.yahoo.tensor.functions.Matmul;
import com.yahoo.tensor.functions.Random;
import com.yahoo.tensor.functions.Range;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.Softmax;
import com.yahoo.tensor.functions.XwPlusB;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;

/**
 * A tensor: a collection of {@code double}-valued cells keyed by {@link TensorAddress}, whose
 * dimensions are described by a {@link TensorType}.
 *
 * <p>Most default and static methods are conveniences: they wrap this tensor (and any tensor
 * arguments) in a {@link ConstantTensor}, construct the corresponding function from
 * {@code com.yahoo.tensor.functions}, and return the result of evaluating it.
 */
@Beta
public interface Tensor {

    /** Incrementally builds a {@link Tensor} of a given type. */
    public interface Builder {

        /**
         * Builds the address of a single cell one label at a time, then stores a value for that
         * address in the owning tensor builder via {@link #value(double)}.
         */
        public static class CellBuilder {

            private final TensorAddress.Builder addressBuilder;
            private final Builder tensorBuilder;

            /**
             * @param tensorType the type the cell address is built against
             * @param builder    the tensor builder that receives the finished cell
             */
            public CellBuilder(TensorType tensorType, Builder builder) {
                this.addressBuilder = new TensorAddress.Builder(tensorType);
                this.tensorBuilder = builder;
            }

            /**
             * Adds label {@code str2} for dimension {@code str} to the address being built.
             *
             * @return this cell builder, for chaining
             */
            public CellBuilder label(String str, String str2) {
                this.addressBuilder.add(str, str2);
                return this;
            }

            /** Same as {@link #label(String, String)}, with the label given as an integer. */
            public CellBuilder label(String str, int i) {
                return label(str, String.valueOf(i));
            }

            /**
             * Completes the address and adds a cell with value {@code d} to the owning tensor builder.
             *
             * @return whatever the owning builder's {@code cell(TensorAddress, double)} returns
             */
            public Builder value(double d) {
                return this.tensorBuilder.cell(this.addressBuilder.build(), d);
            }
        }

        /**
         * Returns a builder for the given type.
         *
         * <p>The result is a {@link MappedTensor} builder if the type has any mapped (non-indexed)
         * dimension, and an {@link IndexedTensor} builder otherwise. This includes the case of a type
         * with no dimensions.
         *
         * @throws IllegalArgumentException if the type has both indexed and mapped dimensions
         */
        static Builder of(TensorType tensorType) {
            boolean hasIndexed = tensorType.dimensions().stream().anyMatch(d -> d.isIndexed());
            boolean hasMapped = tensorType.dimensions().stream().anyMatch(d -> !d.isIndexed());
            if (hasIndexed && hasMapped) {
                throw new IllegalArgumentException("Combining indexed and mapped dimensions is not supported yet");
            }
            return hasMapped ? MappedTensor.Builder.of(tensorType) : IndexedTensor.Builder.of(tensorType);
        }

        /**
         * Same as {@link #of(TensorType)}, but passes {@code dimensionSizes} to the indexed builder.
         *
         * <p>{@code dimensionSizes} is ignored when a mapped builder is returned.
         *
         * @throws IllegalArgumentException if the type has both indexed and mapped dimensions
         */
        static Builder of(TensorType tensorType, DimensionSizes dimensionSizes) {
            boolean hasIndexed = tensorType.dimensions().stream().anyMatch(d -> d.isIndexed());
            boolean hasMapped = tensorType.dimensions().stream().anyMatch(d -> !d.isIndexed());
            if (hasIndexed && hasMapped) {
                throw new IllegalArgumentException("Combining indexed and mapped dimensions is not supported yet");
            }
            return hasMapped
                   ? MappedTensor.Builder.of(tensorType)
                   : IndexedTensor.Builder.of(tensorType, dimensionSizes);
        }

        /** Returns the type of the tensor being built. */
        TensorType type();

        /** Returns a {@link CellBuilder} for specifying one cell's address label by label. */
        CellBuilder cell();

        /**
         * Adds a cell with value {@code d} at the given address.
         *
         * @return a builder to continue with (presumably this one; confirm per implementation)
         */
        Builder cell(TensorAddress tensorAddress, double d);

        /**
         * Adds a cell with value {@code d} at the position given by numeric indexes.
         *
         * <p>NOTE(review): the indexes are presumably one per dimension, in type order. Confirm this
         * against the implementations.
         */
        Builder cell(double d, int... iArr);

        /**
         * Adds a cell with value {@code d} at the address of {@code cell}.
         *
         * <p>The given cell's own value is ignored.
         */
        default Builder cell(Cell cell, double d) {
            return cell(cell.getKey(), d);
        }

        /** Returns the tensor built so far. */
        Tensor build();
    }

    /**
     * A read-only (address, value) entry of a tensor.
     *
     * <p>{@link #setValue} always throws. Equality and hash code follow the {@link Map.Entry}
     * contract.
     */
    public static class Cell implements Map.Entry<TensorAddress, Double> {

        private final TensorAddress address;
        private final Double value;

        public Cell(TensorAddress tensorAddress, Double d) {
            this.address = tensorAddress;
            this.value = d;
        }

        /** Returns the address of this cell. */
        @Override
        public TensorAddress getKey() {
            return this.address;
        }

        /**
         * Returns -1 in this base class.
         *
         * <p>NOTE(review): subclasses presumably override this to expose a position in their backing
         * storage.
         */
        public int getDirectIndex() {
            return -1;
        }

        /** Returns the value of this cell. */
        @Override
        public Double getValue() {
            return this.value;
        }

        /**
         * Always throws, since tensor cells are read-only.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public Double setValue(Double d) {
            throw new UnsupportedOperationException("A tensor cannot be modified");
        }

        /**
         * Returns true if {@code obj} is a {@link Map.Entry} with an equal key and an equal value.
         *
         * <p>This is the {@code Map.Entry} contract.
         */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Map.Entry)) {
                return false;
            }
            Map.Entry<?, ?> other = (Map.Entry<?, ?>) obj;
            return getValue().equals(other.getValue()) && getKey().equals(other.getKey());
        }

        /** Returns {@code key.hashCode() ^ value.hashCode()}, as {@link Map.Entry} specifies. */
        @Override
        public int hashCode() {
            return getKey().hashCode() ^ getValue().hashCode();
        }
    }

    /** Returns the type of this tensor. */
    TensorType type();

    /** Returns true if this tensor has no cells, i.e. {@code size() == 0}. */
    default boolean isEmpty() {
        return size() == 0;
    }

    /** Returns the number of cells in this tensor. */
    int size();

    /** Returns the value of the cell at the given address. */
    double get(TensorAddress tensorAddress);

    /** Returns an iterator over the cells of this tensor. */
    Iterator<Cell> cellIterator();

    /** Returns an iterator over the cell values of this tensor. */
    Iterator<Double> valueIterator();

    /** Returns the cells of this tensor as an address-to-value map. */
    Map<TensorAddress, Double> cells();

    /**
     * Returns the value of this dimensionless tensor.
     *
     * <p>This is the first value yielded by {@link #valueIterator()}, or {@link Double#NaN} if the
     * tensor has no cells.
     *
     * @throws IllegalStateException if the type of this tensor has any dimensions
     */
    default double asDouble() {
        int dimensionCount = type().dimensions().size();
        if (dimensionCount > 0) {
            throw new IllegalStateException("This tensor is not dimensionless. Dimensions: " + dimensionCount);
        }
        if (size() == 0) {
            return Double.NaN;
        }
        return valueIterator().next().doubleValue();
    }

    /** Evaluates {@link com.yahoo.tensor.functions.Map} of this tensor with the given function. */
    default Tensor map(DoubleUnaryOperator doubleUnaryOperator) {
        return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), doubleUnaryOperator).evaluate();
    }

    /** Evaluates {@link Reduce} of this tensor with the given aggregator over the named dimensions. */
    default Tensor reduce(Reduce.Aggregator aggregator, String... strArr) {
        return new Reduce(new ConstantTensor(this), aggregator, Arrays.asList(strArr)).evaluate();
    }

    /** Evaluates {@link Reduce} of this tensor with the given aggregator over the listed dimensions. */
    default Tensor reduce(Reduce.Aggregator aggregator, List<String> list) {
        return new Reduce(new ConstantTensor(this), aggregator, list).evaluate();
    }

    /** Evaluates {@link Join} of this tensor and the argument, combining values with the given operator. */
    default Tensor join(Tensor tensor, DoubleBinaryOperator doubleBinaryOperator) {
        return new Join(new ConstantTensor(this), new ConstantTensor(tensor), doubleBinaryOperator).evaluate();
    }

    /** Evaluates {@link Rename} of this tensor, mapping the single dimension name {@code str} to {@code str2}. */
    default Tensor rename(String str, String str2) {
        return new Rename(new ConstantTensor(this), Collections.singletonList(str), Collections.singletonList(str2)).evaluate();
    }

    /** Evaluates {@link Concat} of this tensor and the argument along dimension {@code str}. */
    default Tensor concat(Tensor tensor, String str) {
        return new Concat(new ConstantTensor(this), new ConstantTensor(tensor), str).evaluate();
    }

    /** Evaluates {@link Rename} of this tensor, mapping each name in {@code list} to the name at the same position in {@code list2}. */
    default Tensor rename(List<String> list, List<String> list2) {
        return new Rename(new ConstantTensor(this), list, list2).evaluate();
    }

    /**
     * Evaluates {@link Generate} for the given type.
     *
     * <p>{@code function} computes a cell value from a list of integers, presumably the cell's
     * indexes.
     */
    static Tensor generate(TensorType tensorType, Function<List<Integer>, Double> function) {
        return new Generate(tensorType, function).evaluate();
    }

    /** Evaluates {@link L1Normalize} of this tensor over dimension {@code str}. */
    default Tensor l1Normalize(String str) {
        return new L1Normalize(new ConstantTensor(this), str).evaluate();
    }

    /** Evaluates {@link L2Normalize} of this tensor over dimension {@code str}. */
    default Tensor l2Normalize(String str) {
        return new L2Normalize(new ConstantTensor(this), str).evaluate();
    }

    /** Evaluates {@link Matmul} of this tensor and the argument over dimension {@code str}. */
    default Tensor matmul(Tensor tensor, String str) {
        return new Matmul(new ConstantTensor(this), new ConstantTensor(tensor), str).evaluate();
    }

    /** Evaluates {@link Softmax} of this tensor over dimension {@code str}. */
    default Tensor softmax(String str) {
        return new Softmax(new ConstantTensor(this), str).evaluate();
    }

    /**
     * Evaluates {@link XwPlusB} with this tensor as x, {@code tensor} as w and {@code tensor2} as b,
     * over dimension {@code str}.
     */
    default Tensor xwPlusB(Tensor tensor, Tensor tensor2, String str) {
        return new XwPlusB(new ConstantTensor(this), new ConstantTensor(tensor), new ConstantTensor(tensor2), str).evaluate();
    }

    /** Evaluates {@link Argmax} of this tensor over dimension {@code str}. */
    default Tensor argmax(String str) {
        return new Argmax(new ConstantTensor(this), str).evaluate();
    }

    /** Evaluates {@link Argmin} of this tensor over dimension {@code str}. */
    default Tensor argmin(String str) {
        return new Argmin(new ConstantTensor(this), str).evaluate();
    }

    /** Evaluates {@link Diag} for the given type. */
    static Tensor diag(TensorType tensorType) {
        return new Diag(tensorType).evaluate();
    }

    /** Evaluates {@link Random} for the given type. */
    static Tensor random(TensorType tensorType) {
        return new Random(tensorType).evaluate();
    }

    /** Evaluates {@link Range} for the given type. */
    static Tensor range(TensorType tensorType) {
        return new Range(tensorType).evaluate();
    }

    /** Joins with the argument using {@code a * b}. */
    default Tensor multiply(Tensor tensor) {
        return join(tensor, (a, b) -> a * b);
    }

    /** Joins with the argument using {@code a + b}. */
    default Tensor add(Tensor tensor) {
        return join(tensor, (a, b) -> a + b);
    }

    /** Joins with the argument using {@code a / b}. */
    default Tensor divide(Tensor tensor) {
        return join(tensor, (a, b) -> a / b);
    }

    /** Joins with the argument using {@code a - b}. */
    default Tensor subtract(Tensor tensor) {
        return join(tensor, (a, b) -> a - b);
    }

    /**
     * Joins with the argument, keeping the larger of each pair of values.
     *
     * <p>Unlike {@link Math#max}, a NaN on either side yields the argument's value, because the
     * comparison {@code >} is false.
     */
    default Tensor max(Tensor tensor) {
        return join(tensor, (a, b) -> a > b ? a : b);
    }

    /**
     * Joins with the argument, keeping the smaller of each pair of values.
     *
     * <p>Unlike {@link Math#min}, a NaN on either side yields the argument's value, because the
     * comparison {@code <} is false.
     */
    default Tensor min(Tensor tensor) {
        return join(tensor, (a, b) -> a < b ? a : b);
    }

    /** Joins with the argument using {@link Math#atan2}. */
    default Tensor atan2(Tensor tensor) {
        return join(tensor, Math::atan2);
    }

    /** Joins with the argument, producing 1.0 where {@code a > b} and 0.0 otherwise. */
    default Tensor larger(Tensor tensor) {
        return join(tensor, (a, b) -> a > b ? 1.0 : 0.0);
    }

    /** Joins with the argument, producing 1.0 where {@code a >= b} and 0.0 otherwise. */
    default Tensor largerOrEqual(Tensor tensor) {
        return join(tensor, (a, b) -> a >= b ? 1.0 : 0.0);
    }

    /** Joins with the argument, producing 1.0 where {@code a < b} and 0.0 otherwise. */
    default Tensor smaller(Tensor tensor) {
        return join(tensor, (a, b) -> a < b ? 1.0 : 0.0);
    }

    /** Joins with the argument, producing 1.0 where {@code a <= b} and 0.0 otherwise. */
    default Tensor smallerOrEqual(Tensor tensor) {
        return join(tensor, (a, b) -> a <= b ? 1.0 : 0.0);
    }

    /** Joins with the argument, producing 1.0 where {@code a == b} (primitive comparison) and 0.0 otherwise. */
    default Tensor equal(Tensor tensor) {
        return join(tensor, (a, b) -> a == b ? 1.0 : 0.0);
    }

    /** Joins with the argument, producing 1.0 where {@code a != b} (primitive comparison) and 0.0 otherwise. */
    default Tensor notEqual(Tensor tensor) {
        return join(tensor, (a, b) -> a != b ? 1.0 : 0.0);
    }

    /** Reduces over the listed dimensions with {@link Reduce.Aggregator#avg}. */
    default Tensor avg(List<String> list) {
        return reduce(Reduce.Aggregator.avg, list);
    }

    /** Reduces over the listed dimensions with {@link Reduce.Aggregator#count}. */
    default Tensor count(List<String> list) {
        return reduce(Reduce.Aggregator.count, list);
    }

    /** Reduces over the listed dimensions with {@link Reduce.Aggregator#max}. */
    default Tensor max(List<String> list) {
        return reduce(Reduce.Aggregator.max, list);
    }

    /** Reduces over the listed dimensions with {@link Reduce.Aggregator#min}. */
    default Tensor min(List<String> list) {
        return reduce(Reduce.Aggregator.min, list);
    }

    /** Reduces over the listed dimensions with {@link Reduce.Aggregator#prod}. */
    default Tensor prod(List<String> list) {
        return reduce(Reduce.Aggregator.prod, list);
    }

    /** Reduces over the listed dimensions with {@link Reduce.Aggregator#sum}. */
    default Tensor sum(List<String> list) {
        return reduce(Reduce.Aggregator.sum, list);
    }

    /** Redeclared for documentation; see {@link #toStandardString(Tensor)} for a shared format. */
    String toString();

    /**
     * Returns the standard string form of a tensor: {@link #contentToString(Tensor)}.
     *
     * <p>The output is prefixed by {@code type + ":"} when the tensor is empty but its type has
     * dimensions.
     */
    static String toStandardString(Tensor tensor) {
        // An empty tensor with dimensions would otherwise print as just "{}", losing its type.
        if (tensor.isEmpty() && !tensor.type().dimensions().isEmpty()) {
            return tensor.type() + ":" + contentToString(tensor);
        }
        return contentToString(tensor);
    }

    /**
     * Returns the cells of a tensor as {@code {address:value,...}}, sorted by address.
     *
     * <p>For a dimensionless type the output is {@code {value}}, or {@code {}} when there are no
     * cells.
     */
    static String contentToString(Tensor tensor) {
        List<Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet());
        if (tensor.type().dimensions().isEmpty()) {
            return cellEntries.isEmpty() ? "{}" : "{" + cellEntries.get(0).getValue() + "}";
        }
        // Sort so that the output is deterministic regardless of the cell map's iteration order.
        Collections.sort(cellEntries, Map.Entry.<TensorAddress, Double>comparingByKey());
        StringBuilder sb = new StringBuilder("{");
        String separator = "";
        for (Map.Entry<TensorAddress, Double> entry : cellEntries) {
            sb.append(separator);
            sb.append(entry.getKey().toString(tensor.type())).append(":").append(entry.getValue());
            separator = ",";
        }
        sb.append("}");
        return sb.toString();
    }

    /** Redeclared for documentation; see {@link #equals(Tensor, Tensor)} for a reusable comparison. */
    boolean equals(Object obj);

    /**
     * Returns true if the two tensors are the same instance, or if all of the following hold:
     *
     * <ul>
     *   <li>their types are {@link TensorType#mathematicallyEquals mathematically equal};</li>
     *   <li>they have the same size;</li>
     *   <li>every cell of {@code tensor} has a value equal to {@code tensor2.get} at that address.</li>
     * </ul>
     *
     * <p>Values are compared with {@link Double#equals}, so NaN equals NaN while 0.0 and -0.0 are
     * considered different.
     */
    static boolean equals(Tensor tensor, Tensor tensor2) {
        if (tensor == tensor2) {
            return true;
        }
        if (!tensor.type().mathematicallyEquals(tensor2.type()) || tensor.size() != tensor2.size()) {
            return false;
        }
        for (Iterator<Cell> cells = tensor.cellIterator(); cells.hasNext(); ) {
            Cell cell = cells.next();
            if (!cell.getValue().equals(Double.valueOf(tensor2.get(cell.getKey())))) {
                return false;
            }
        }
        return true;
    }

    /** Parses a tensor from the string {@code str}, interpreted as being of the given type. */
    static Tensor from(TensorType tensorType, String str) {
        return TensorParser.tensorFrom(str, Optional.of(tensorType));
    }

    /**
     * Parses a tensor from the string {@code str2}.
     *
     * <p>The type is given by the type specification {@code str}, which is parsed with
     * {@link TensorType#fromSpec}.
     */
    static Tensor from(String str, String str2) {
        return TensorParser.tensorFrom(str2, Optional.of(TensorType.fromSpec(str)));
    }

    /** Parses a tensor from the string {@code str}, with no type given up front. */
    static Tensor from(String str) {
        return TensorParser.tensorFrom(str, Optional.empty());
    }
}
