package com.yahoo.tensor.functions;

import com.yahoo.api.annotations.Beta;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/**
 * Returns the subspace of an argument tensor selected by a (partial) address.
 * <p>
 * The address is a list of {@link DimensionValue}s, each giving an optional dimension name and either a label
 * or an index function. The dimension name may be omitted (the "short form") only when the address has a single
 * element: it then refers to the single indexed dimension of the argument when an index is given, or the single
 * mapped dimension when a label is given. The result type is computed by {@link TypeResolver#peek} over the
 * sliced dimensions.
 */
@Beta
public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {

    private final TensorFunction<NAMETYPE> argument;
    private final List<DimensionValue<NAMETYPE>> subspaceAddress;

    /**
     * Creates a slice function.
     *
     * @param argument the tensor to slice, must not be null
     * @param subspaceAddress the address of the subspace to return. This is not a TensorAddress since
     *                        the argument type is not known until this is type checked or evaluated
     * @throws NullPointerException if argument is null
     * @throws IllegalArgumentException if the address has multiple elements and some lack a dimension name
     */
    public Slice(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> subspaceAddress) {
        this.argument = Objects.requireNonNull(argument, "Argument cannot be null");
        if (subspaceAddress.size() > 1 && subspaceAddress.stream().anyMatch(value -> value.dimension().isEmpty()))
            throw new IllegalArgumentException("Short form of subspace addresses is only supported with a single dimension: " +
                                               "Specify dimension names explicitly instead");
        this.subspaceAddress = subspaceAddress;
    }

    @Override
    public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); }

    /**
     * Returns the index functions of this slice's address which can be represented as tensor functions,
     * in address order.
     */
    public List<TensorFunction<NAMETYPE>> selectorFunctions() {
        List<TensorFunction<NAMETYPE>> selectors = new ArrayList<>();
        for (DimensionValue<NAMETYPE> value : subspaceAddress)
            value.index().flatMap(ScalarFunction::asTensorFunction).ifPresent(selectors::add);
        return selectors;
    }

    /** Returns a copy of this where every index function of the address is replaced by transformer(function). */
    @Override
    public TensorFunction<NAMETYPE> withTransformedFunctions(Function<ScalarFunction<NAMETYPE>, ScalarFunction<NAMETYPE>> transformer) {
        List<DimensionValue<NAMETYPE>> transformedAddress = new ArrayList<>(subspaceAddress.size());
        for (DimensionValue<NAMETYPE> value : subspaceAddress) {
            Optional<ScalarFunction<NAMETYPE>> index = value.index();
            if (index.isPresent())
                transformedAddress.add(new DimensionValue<>(value.dimension(), transformer.apply(index.get())));
            else
                transformedAddress.add(value); // label values contain no functions to transform
        }
        return new Slice<>(argument, transformedAddress);
    }

    @Override
    public Slice<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 1)
            throw new IllegalArgumentException("Slice takes exactly one argument but got " + arguments.size());
        return new Slice<>(arguments.get(0), subspaceAddress);
    }

    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }

    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor tensor = argument.evaluate(context);
        TensorType argumentType = tensor.type();
        TensorType resultType = resultType(argumentType);
        PartialAddress address = subspaceToAddress(argumentType, context);

        if (resultType.rank() == 0) // shortcut: the result is the single addressed cell value
            return Tensor.from(tensor.get(address.asAddress(argumentType)));

        Tensor.Builder builder = Tensor.Builder.of(resultType);
        for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
            Tensor.Cell cell = i.next();
            if (matches(address, cell.getKey(), argumentType))
                builder.cell(remaining(resultType, cell.getKey(), argumentType), cell.getValue().doubleValue());
        }
        return builder.build();
    }

    /** Resolves the subspace address against the given argument type, evaluating any index functions. */
    private PartialAddress subspaceToAddress(TensorType type, EvaluationContext<NAMETYPE> context) {
        PartialAddress.Builder builder = new PartialAddress.Builder(subspaceAddress.size());
        for (int i = 0; i < subspaceAddress.size(); i++) {
            DimensionValue<NAMETYPE> value = subspaceAddress.get(i);
            String dimension = dimensionName(value, i, type);
            if (value.label().isPresent())
                builder.add(dimension, value.label().get());
            else
                builder.add(dimension, value.index().get().apply(context).intValue());
        }
        return builder.build();
    }

    /**
     * Returns the name of the dimension the address value at position i refers to in the given type.
     * An explicit dimension name is returned as-is. For the short form, the same rule as {@link #resultType}
     * is used (the single indexed dimension for an index, the single mapped dimension for a label), so that
     * the address and the result type agree; if there is no single such dimension, the dimension at
     * position i of the type is used.
     */
    private String dimensionName(DimensionValue<NAMETYPE> value, int i, TensorType type) {
        if (value.dimension().isPresent()) return value.dimension().get();
        List<String> candidates = shortFormCandidates(value, type);
        return candidates.size() == 1 ? candidates.get(0) : type.dimensions().get(i).name();
    }

    /** Returns true if the cell address has the same labels as the partial address in all its dimensions. */
    private boolean matches(PartialAddress address, TensorAddress cellAddress, TensorType type) {
        for (int i = 0; i < address.size(); i++) {
            int dimensionIndex = type.indexOfDimension(address.dimension(i)).get();
            if ( ! cellAddress.objectLabel(dimensionIndex).isEqualTo(address.objectLabel(i)))
                return false;
        }
        return true;
    }

    /** Returns the part of the cell address which is in the dimensions of the result type. */
    private TensorAddress remaining(TensorType resultType, TensorAddress cellAddress, TensorType argumentType) {
        TensorAddress.Builder builder = new TensorAddress.Builder(resultType);
        for (int i = 0; i < cellAddress.size(); i++) {
            String name = argumentType.dimensions().get(i).name();
            if (resultType.dimension(name).isPresent())
                builder.add(name, cellAddress.objectLabel(i));
        }
        return builder.build();
    }

    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        return resultType(argument.type(context));
    }

    private List<String> findDimensions(List<TensorType.Dimension> dims, Predicate<TensorType.Dimension> pred) {
        return dims.stream().filter(pred).map(TensorType.Dimension::name).toList();
    }

    /** Returns true if the address is a single value without a dimension name. */
    private boolean isShortForm() {
        return subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty();
    }

    /** Returns the dimensions of the type a short-form value may refer to: indexed ones for an index, else mapped ones. */
    private List<String> shortFormCandidates(DimensionValue<NAMETYPE> value, TensorType type) {
        return value.index().isPresent()
               ? findDimensions(type.dimensions(), TensorType.Dimension::isIndexed)
               : findDimensions(type.dimensions(), TensorType.Dimension::isMapped);
    }

    /**
     * Returns the type of slicing a tensor of the given type with this address.
     *
     * @throws IllegalArgumentException if a short-form address is ambiguous for this type, or the type cannot be sliced
     */
    private TensorType resultType(TensorType argumentType) {
        List<String> slicedDimensions;
        if ( ! isShortForm()) {
            slicedDimensions = subspaceAddress.stream().map(value -> value.dimension().get()).toList();
        }
        else {
            DimensionValue<NAMETYPE> value = subspaceAddress.get(0);
            slicedDimensions = shortFormCandidates(value, argumentType);
            if (slicedDimensions.size() > 1) {
                String kind = value.index().isPresent() ? "indexed" : "mapped";
                throw new IllegalArgumentException(this + " slices a single " + kind + " dimension, cannot be applied to " +
                                                   argumentType + ", which has multiple");
            }
        }

        try {
            return TypeResolver.peek(argumentType, slicedDimensions);
        }
        catch (IllegalArgumentException e) {
            throw new IllegalArgumentException(this + " cannot slice type " + argumentType, e);
        }
    }

    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        StringBuilder b = new StringBuilder(argument.toString(context));
        if (context.typeContext().isPresent() || ! isShortForm()) {
            // Long form: with a type context, short-form dimension names are resolved and written out
            b.append("{")
             .append(subspaceAddress.stream()
                                    .map(value -> value.toString(context, this))
                                    .collect(Collectors.joining(", ")))
             .append("}");
        }
        else if (subspaceAddress.get(0).index().isPresent()) {
            b.append("[").append(subspaceAddress.get(0).index().get().toString(context)).append("]");
        }
        else {
            // Escape the label the same way as the long form does
            b.append("{").append(TensorAddress.labelToString(subspaceAddress.get(0).label().get())).append("}");
        }
        return b.toString();
    }

    @Override
    public int hashCode() { return Objects.hash("slice", argument, subspaceAddress); }

    /** A scalar function returning a constant integer value (as a double). */
    private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> {

        private final int value;

        public ConstantIntegerFunction(int value) {
            this.value = value;
        }

        @Override
        public Double apply(EvaluationContext<NAMETYPE> context) {
            return (double) value;
        }

        @Override
        public String toString() { return String.valueOf(value); }

        @Override
        public int hashCode() { return Objects.hash("constantIntegerFunction", value); }

    }

    /**
     * One element of a slice address: an optional dimension name together with either a label
     * or an index function.
     */
    public static class DimensionValue<NAMETYPE extends Name>  {

        private final Optional<String> dimension;

        /** The label, or null if this value is an index */
        private final String label;

        /** The index function, or null if this value is a label */
        private final ScalarFunction<NAMETYPE> index;

        public DimensionValue(String dimension, String label) {
            this(Optional.of(dimension), label, null);
        }

        public DimensionValue(String dimension, int index) {
            this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index));
        }

        public DimensionValue(int index) {
            this(Optional.empty(), null, new ConstantIntegerFunction<>(index));
        }

        public DimensionValue(String label) {
            this(Optional.empty(), label, null);
        }

        public DimensionValue(ScalarFunction<NAMETYPE> index) {
            this(Optional.empty(), null, index);
        }

        public DimensionValue(Optional<String> dimension, String label) {
            this(dimension, label, null);
        }

        public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) {
            this(dimension, null, index);
        }

        public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) {
            this(Optional.of(dimension), null, index);
        }

        private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) {
            this.dimension = dimension;
            this.label = label;
            this.index = index;
        }

        /** Returns the explicit dimension name of this, or empty if this is a short-form value */
        public Optional<String> dimension() { return dimension; }

        /** Returns the label of this, or empty if this is an index value */
        public Optional<String> label() { return Optional.ofNullable(label); }

        /** Returns the index function of this, or empty if this is a label value */
        public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); }

        @Override
        public String toString() {
            return toString(null, null);
        }

        /**
         * Returns this as "dimension:value", or just "value" when there is no dimension name and no context.
         * When a context is given and the dimension name is missing, it is resolved from the type of the
         * slice argument, which must then have exactly one dimension.
         *
         * @throws IllegalArgumentException if the dimension name is missing and cannot be resolved this way
         */
        String toString(ToStringContext<NAMETYPE> context, Slice<NAMETYPE> owner) {
            StringBuilder b = new StringBuilder();
            Optional<String> dimensionName = dimension;
            if (context != null && dimensionName.isEmpty()) {
                TensorType type = context.typeContext().isPresent() ? owner.argument.type(context.typeContext().get())
                                                                    : null;
                if (type == null || type.dimensions().size() != 1)
                    throw new IllegalArgumentException("The tensor dimension name being sliced by " + owner +
                                                       " cannot be uniquely resolved. Use the full form: " +
                                                       "'slice{myDimensionName:" + valueToString(context) + "}'");
                dimensionName = Optional.of(type.dimensions().get(0).name());
            }
            if (dimensionName.isPresent())
                b.append(dimensionName.get()).append(":");
            b.append(valueToString(context));
            return b.toString();
        }

        private String valueToString(ToStringContext<NAMETYPE> context) {
            if (label != null)
                return TensorAddress.labelToString(label);
            return index.toString(context);
        }

        @Override
        public int hashCode() { return Objects.hash(dimension, label, index); }

    }

}
