/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.CellOrder;
import com.yahoo.tensor.functions.CompositeTensorFunction;
import com.yahoo.tensor.functions.FilterSubspaces;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.List;
import java.util.Objects;

/**
 * The composite tensor function {@code top(N, input)}.
 * <p>
 * {@link #toPrimitive()} expands this function into primitive tensor functions:
 * <ol>
 *   <li>the cells of {@code input} are ranked with {@link CellOrder} using {@link CellOrder.Order#MAX},</li>
 *   <li>each rank is compared against {@code N} with {@link ScalarFunctions#less()}, giving a mask,</li>
 *   <li>the mask is passed through {@link FilterSubspaces}, and</li>
 *   <li>the filtered mask is multiplied into {@code input}.</li>
 * </ol>
 * NOTE(review): assuming {@code CellOrder} assigns rank 0 to the largest cell and {@code FilterSubspaces}
 * drops subspaces whose mask is zero, this keeps the {@code N} largest cells of {@code input} — confirm
 * against those classes.
 * <p>
 * NOTE(review): {@link #hashCode()} is overridden without a matching {@code equals}; this presumably follows
 * the convention of the sibling tensor-function classes — confirm before relying on value equality.
 *
 * @param <NAMETYPE> the {@link Name} subtype this function is parameterized over
 */
public class Top<NAMETYPE extends Name>
extends CompositeTensorFunction<NAMETYPE> {
    private final TensorFunction<NAMETYPE> n;
    private final TensorFunction<NAMETYPE> input;

    /**
     * Creates a {@code top(N, input)} function.
     *
     * @param n the function producing N; its type must have rank 0 (validated by {@link #type(TypeContext)})
     * @param input the function producing the tensor to select cells from; its type must be sparse
     *              (validated by {@link #type(TypeContext)})
     * @throws NullPointerException if {@code n} or {@code input} is null
     */
    public Top(TensorFunction<NAMETYPE> n, TensorFunction<NAMETYPE> input) {
        // Fail fast here rather than with an unexplained NPE later in type()/toString()/toPrimitive().
        this.n = Objects.requireNonNull(n, "n");
        this.input = Objects.requireNonNull(input, "input");
    }

    /**
     * Validates the argument types and returns the type of this function.
     *
     * @param context the context used to resolve the argument types
     * @return the type computed by {@link CompositeTensorFunction#type(TypeContext)}
     *         (presumably derived from {@link #toPrimitive()} — confirm in the superclass)
     * @throws IllegalArgumentException if the type of N has rank above 0, or if the input type has any
     *                                  indexed dimension or no mapped dimension
     */
    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        TensorType nType = this.n.type(context);
        TensorType inputType = this.input.type(context);
        if (nType.rank() > 0) {
            throw new IllegalArgumentException("the N argument to top(N,input) should be a number, but had type: " + String.valueOf(nType));
        }
        // "Sparse" here means: only mapped dimensions, and at least one of them.
        if (inputType.hasIndexedDimensions() || !inputType.hasMappedDimensions()) {
            throw new IllegalArgumentException("the input argument to top(N,input) should be a sparse tensor, but had type: " + String.valueOf(inputType));
        }
        return super.type(context);
    }

    /** Returns the arguments of this function, in order: {@code [n, input]}. */
    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(this.n, this.input);
    }

    /**
     * Returns a new {@code Top} with the given arguments.
     *
     * @param arguments exactly two functions: {@code [n, input]}
     * @return a new {@code Top} built from {@code arguments}
     * @throws IllegalArgumentException if {@code arguments} does not contain exactly two elements
     */
    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 2) {
            throw new IllegalArgumentException("Top must have 2 arguments, got " + arguments.size());
        }
        return new Top<NAMETYPE>(arguments.get(0), arguments.get(1));
    }

    /**
     * Expands this function into primitive tensor functions: rank, mask by {@code rank < N}, filter
     * subspaces, then multiply the mask into the input.
     */
    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        // Name of the variable bound by FilterSubspaces and referenced by its lambda body; declared once
        // so the binding and the reference cannot drift apart.
        final String subspaceVariable = "s";
        PrimitiveTensorFunction<NAMETYPE> primitiveI = this.input.toPrimitive();
        PrimitiveTensorFunction<NAMETYPE> primitiveN = this.n.toPrimitive();
        CellOrder<NAMETYPE> ranks = new CellOrder<NAMETYPE>(primitiveI, CellOrder.Order.MAX);
        // Per-cell mask: ScalarFunctions.less() applied to (rank, N).
        Join<NAMETYPE> masks = new Join<NAMETYPE>(ranks, primitiveN, ScalarFunctions.less());
        // Typed VariableTensor (was a raw type) so no unchecked conversion is needed.
        FilterSubspaces<NAMETYPE> filter = new FilterSubspaces<NAMETYPE>(
                masks, subspaceVariable, new VariableTensor<NAMETYPE>(subspaceVariable));
        return new Join<NAMETYPE>(primitiveI, filter, ScalarFunctions.multiply());
    }

    /** Returns this function in the form {@code top(<n>, <input>)}. */
    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return "top(" + this.n.toString(context) + ", " + this.input.toString(context) + ")";
    }

    /** Returns a hash code derived from a fixed tag and both arguments. */
    @Override
    public int hashCode() {
        return Objects.hash("top_n", this.n, this.input);
    }
}

