package com.yahoo.tensor.functions;

import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.CellOrder;
import java.util.List;
import java.util.Objects;

/**
 * The composite tensor function {@code top(N, input)}, which selects the top N cells of a
 * sparse (mapped-dimensions-only) input tensor.
 *
 * <p>This function has no direct implementation. {@link #toPrimitive()} expands it into:
 * <ol>
 *   <li>a cell ordering of the input,</li>
 *   <li>a comparison of that ordering against N,</li>
 *   <li>a subspace filter, and</li>
 *   <li>a multiplicative join back onto the input.</li>
 * </ol>
 *
 * @param <NAMETYPE> the type of names used when resolving this function's types
 */
public class Top<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {

    /** Expression producing N; {@link #type} requires it to have rank 0. */
    private final TensorFunction<NAMETYPE> n;

    /** Expression producing the tensor to select from; {@link #type} requires it to be mapped-only. */
    private final TensorFunction<NAMETYPE> input;

    /**
     * Creates a {@code top(n, input)} function.
     *
     * @param n     expression giving the number of cells to keep
     * @param input expression giving the sparse tensor to select cells from
     * @throws NullPointerException if either argument is null
     */
    public Top(TensorFunction<NAMETYPE> n, TensorFunction<NAMETYPE> input) {
        // Fail fast here rather than with an NPE later in type() or toPrimitive().
        this.n = Objects.requireNonNull(n, "The N argument to top cannot be null");
        this.input = Objects.requireNonNull(input, "The input argument to top cannot be null");
    }

    /**
     * Validates the argument types and returns the type of this function.
     *
     * @param context the context used to resolve the argument types
     * @return the type computed by {@link CompositeTensorFunction#type}
     * @throws IllegalArgumentException if N is not rank 0, or if input has any indexed
     *         dimension or no mapped dimension
     */
    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        TensorType nType = n.type(context);
        TensorType inputType = input.type(context);
        if (nType.rank() > 0) {
            throw new IllegalArgumentException("the N argument to top(N,input) should be a number, but had type: " + nType);
        }
        if (inputType.hasIndexedDimensions() || !inputType.hasMappedDimensions()) {
            throw new IllegalArgumentException("the input argument to top(N,input) should be a sparse tensor, but had type: " + inputType);
        }
        return super.type(context);
    }

    /** Returns the two arguments in the order {@code (n, input)}. */
    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(n, input);
    }

    /**
     * Returns a new {@code Top} with the given arguments.
     *
     * @param arguments exactly two functions, in the order {@code (n, input)}
     * @throws IllegalArgumentException if {@code arguments} does not contain exactly two elements
     */
    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 2) {
            throw new IllegalArgumentException("Top must have 2 arguments, got " + arguments.size());
        }
        return new Top<>(arguments.get(0), arguments.get(1));
    }

    /**
     * Expands this function into primitive tensor functions.
     *
     * <p>The result is {@code join(input, filter_subspaces(join(cell_order(input, max), n, less), s), multiply)}.
     */
    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        PrimitiveTensorFunction<NAMETYPE> primitiveInput = input.toPrimitive();
        // NOTE(review): presumably assigns each cell its rank when ordered by descending value.
        TensorFunction<NAMETYPE> order = new CellOrder<>(primitiveInput, CellOrder.Order.MAX);
        // Compares each cell's order value with N using less().
        TensorFunction<NAMETYPE> withinN = new Join<>(order, n.toPrimitive(), ScalarFunctions.less());
        // NOTE(review): filters subspaces using the variable "s"; exact semantics live in FilterSubspaces.
        TensorFunction<NAMETYPE> mask = new FilterSubspaces<>(withinN, "s", new VariableTensor<>("s"));
        // Multiplying by the mask keeps input values only where the mask is non-zero.
        return new Join<>(primitiveInput, mask, ScalarFunctions.multiply());
    }

    /** Renders this function as {@code top(<n>, <input>)}. */
    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return "top(" + n.toString(context) + ", " + input.toString(context) + ")";
    }

    /** Hashes the function name together with both arguments. */
    @Override
    public int hashCode() {
        return Objects.hash("top_n", n, input);
    }
}
