/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.Slice;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/**
 * Imported "Split" operation: produces one slice of its single input tensor along one axis.
 *
 * <p>Each instance represents a single output of the split, identified by {@code output}.
 * The slice covers the half-open index range {@code [start, end)} on the split axis; every
 * other dimension keeps the size it has in the input.
 *
 * <p>Attributes read from the attribute map:
 * <ul>
 *   <li>{@code "axis"} - the dimension index to split on, defaulting to 0. A negative value
 *       counts from the last dimension.</li>
 *   <li>{@code "split"} - optional list of slice lengths. When absent, the axis is cut into
 *       two pieces of {@code axisSize / 2} elements each.</li>
 * </ul>
 */
public class Split extends IntermediateOperation {
    private final IntermediateOperation.AttributeMap attributes;

    /** Index of the split output this instance computes. */
    private final int output;

    /** The axis exactly as given by the attributes; may be negative, see {@link #resolveAxis(int)}. */
    private final int axis;

    /** First index (inclusive) of this slice on the split axis; assigned by {@link #lazyGetType()}. */
    private int start;

    /** Last index (exclusive) of this slice on the split axis; assigned by {@link #lazyGetType()}. */
    private int end;

    /**
     * Creates the operation for one output of a split.
     *
     * @param modelName  passed on to the superclass
     * @param nodeName   passed on to the superclass
     * @param inputs     the inputs of this operation; only the first one is read
     * @param attributes source of the {@code "axis"} and {@code "split"} attributes
     * @param output     index of the split output this instance computes
     */
    public Split(String modelName, String nodeName, List<IntermediateOperation> inputs, IntermediateOperation.AttributeMap attributes, int output) {
        super(modelName, nodeName, inputs);
        this.attributes = attributes;
        this.output = output;
        this.axis = (int) attributes.get("axis").orElse(DoubleValue.zero).asDouble();
    }

    /**
     * Computes the type of this slice and, as a side effect, the {@code start} and {@code end}
     * indexes later used by {@link #lazyGetFunction()}.
     *
     * @return the input type with the split axis resized to {@code end - start},
     *         or {@code null} when the input type is not available yet
     * @throws IllegalArgumentException if the axis, the output index, or the resulting
     *         start/end index is out of range
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!allInputTypesPresent(1)) {
            return null;
        }
        IntermediateOperation input = inputs.get(0);
        OrderedTensorType inputType = input.type().get();

        // The generated expression refers to the input through its ranking function name,
        // presumably the reason the input has to be exported as a function - TODO confirm.
        input.exportAsRankingFunction = true;

        int splitAxis = resolveAxis(inputType.rank());
        int axisSize = inputType.dimensions().get(splitAxis).size().get().intValue();
        start = 0;
        end = axisSize;

        Optional<List<Value>> split = attributes.getList("split");
        if (split.isPresent()) {
            List<Value> splitList = split.get();
            // NOTE(review): an output index equal to the list length is accepted here and takes
            // whatever remains of the axis after all listed slices; kept for compatibility.
            if (output < 0 || output > splitList.size()) {
                throw new IllegalArgumentException("Split in " + name + ": output out of range of split list");
            }
            // The slice starts where the preceding slices end.
            for (int i = 0; i < output; ++i) {
                start += (int) splitList.get(i).asDouble();
            }
            if (output < splitList.size()) {
                end = start + (int) splitList.get(output).asDouble();
            }
        } else {
            // NOTE(review): without an explicit split list only a two-way split is supported;
            // with an odd axis size the integer division leaves the last element uncovered.
            int half = axisSize / 2;
            start = half * output;
            end = start + half;
        }

        if (start >= axisSize || start < 0) {
            throw new IllegalArgumentException("Split in " + name + ": split start index out of range (" + start + ")");
        }
        if (end > axisSize || end < 0) {
            throw new IllegalArgumentException("Split in " + name + ": split end index out of range (" + end + ")");
        }

        OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
        for (int i = 0; i < inputType.rank(); ++i) {
            TensorType.Dimension inputDimension = inputType.dimensions().get(i);
            long dimSize = i == splitAxis ? (long) (end - start) : inputDimension.size().get();
            typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), dimSize));
        }
        return typeBuilder.build();
    }

    /**
     * Builds a generate function over this operation's type in which each cell reads the input
     * cell at the same indexes, shifted by {@code start} on the split axis.
     *
     * <p>Relies on {@code start} and {@code type} having been set, i.e. on the type having
     * been computed first.
     *
     * @return the generating function, or {@code null} when the input function is not available yet
     * @throws IllegalArgumentException if the axis is out of range for the input rank
     */
    @Override
    protected TensorFunction<Reference> lazyGetFunction() {
        if (!allInputFunctionsPresent(1)) {
            return null;
        }
        IntermediateOperation input = inputs.get(0);
        OrderedTensorType inputType = input.type().get();
        String inputFunctionName = input.rankingExpressionFunctionName();
        int splitAxis = resolveAxis(inputType.rank());

        List<Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();
        for (int i = 0; i < inputType.rank(); ++i) {
            String inputDimensionName = inputType.dimensions().get(i).name();
            ExpressionNode reference = new ReferenceNode(inputDimensionName);
            // Only the split axis is offset; every other dimension adds 0.
            double shift = i == splitAxis ? (double) start : 0.0;
            ExpressionNode offset = new OperationNode(reference, Operator.plus, new ConstantNode(new DoubleValue(shift)));
            dimensionValues.add(new Slice.DimensionValue<>(Optional.of(inputDimensionName),
                                                           TensorFunctionNode.wrapScalar(new EmbracedNode(offset))));
        }

        TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName));
        Slice<Reference> sliceIndices = new Slice<>(inputIndices, dimensionValues);
        ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
        return Generate.bound(type.type(), TensorFunctionNode.wrapScalar(sliceExpression));
    }

    /** Returns a copy of this operation, for the same output, reading from the given inputs. */
    @Override
    public Split withInputs(List<IntermediateOperation> inputs) {
        return new Split(modelName(), name(), inputs, attributes, output);
    }

    @Override
    public String operationName() {
        return "Split";
    }

    /**
     * Translates the configured axis into a dimension index of an input with the given rank.
     * A negative axis counts from the last dimension.
     *
     * @param rank the rank of the input tensor
     * @return the axis as an index in {@code [0, rank)}
     * @throws IllegalArgumentException if the axis is outside {@code [-rank, rank - 1]}
     */
    private int resolveAxis(int rank) {
        int resolved = axis < 0 ? axis + rank : axis;
        if (resolved < 0 || resolved >= rank) {
            throw new IllegalArgumentException("Split in " + name + ": axis " + axis + " out of range for input of rank " + rank);
        }
        return resolved;
    }
}

