/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;

public class CellCast<NAMETYPE extends Name>
extends PrimitiveTensorFunction<NAMETYPE> {
    private final TensorFunction<NAMETYPE> argument;
    private final TensorType.Value valueType;

    public CellCast(TensorFunction<NAMETYPE> argument, TensorType.Value valueType) {
        Objects.requireNonNull(argument, "The argument tensor cannot be null");
        Objects.requireNonNull(valueType, "The value type cannot be null");
        this.argument = argument;
        this.valueType = valueType;
    }

    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(this.argument);
    }

    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 1) {
            throw new IllegalArgumentException("CellCast must have 1 argument, got " + arguments.size());
        }
        return new CellCast<NAMETYPE>(arguments.get(0), this.valueType);
    }

    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        return new CellCast<NAMETYPE>(this.argument.toPrimitive(), this.valueType);
    }

    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        return TypeResolver.cell_cast(this.argument.type(context), this.valueType);
    }

    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor tensor = this.argument.evaluate(context);
        if (tensor.type().valueType() == this.valueType) {
            return tensor;
        }
        TensorType type = TypeResolver.cell_cast(tensor.type(), this.valueType);
        return this.cast(tensor, type);
    }

    private Tensor cast(Tensor tensor, TensorType type) {
        TensorType.Value fromValueType = tensor.type().valueType();
        switch (fromValueType) {
            case DOUBLE: {
                return this.castFromDouble(tensor, type);
            }
            case FLOAT: 
            case BFLOAT16: 
            case INT8: {
                return this.castFromSomeFloat(tensor, type);
            }
        }
        throw new IllegalStateException("Unexpected value type " + String.valueOf((Object)fromValueType));
    }

    private Tensor castFromDouble(Tensor tensor, TensorType type) {
        Tensor.Builder builder = Tensor.Builder.of(type);
        Function<Float, Float> restrict = CellCast.selectRestrict(type.valueType());
        Iterator<Tensor.Cell> i = tensor.cellIterator();
        while (i.hasNext()) {
            Tensor.Cell cell = i.next();
            builder.cell(cell.getKey(), restrict.apply(Float.valueOf((float)cell.getDoubleValue())).floatValue());
        }
        return builder.build();
    }

    private Tensor castFromSomeFloat(Tensor tensor, TensorType type) {
        Tensor.Builder builder = Tensor.Builder.of(type);
        Function<Float, Float> restrict = CellCast.selectRestrict(type.valueType());
        Iterator<Tensor.Cell> i = tensor.cellIterator();
        while (i.hasNext()) {
            Tensor.Cell cell = i.next();
            builder.cell(cell.getKey(), restrict.apply(Float.valueOf(cell.getFloatValue())).floatValue());
        }
        return builder.build();
    }

    private static Function<Float, Float> selectRestrict(TensorType.Value toValueType) {
        return switch (toValueType) {
            case TensorType.Value.BFLOAT16 -> val -> Float.valueOf(Float.intBitsToFloat(Float.floatToRawIntBits(val.floatValue()) & 0xFFFF0000));
            case TensorType.Value.INT8 -> val -> Float.valueOf(val.byteValue());
            default -> val -> val;
        };
    }

    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return "cell_cast(" + this.argument.toString(context) + ", " + String.valueOf((Object)this.valueType) + ")";
    }

    @Override
    public int hashCode() {
        return Objects.hash(new Object[]{"cellcast", this.argument, this.valueType});
    }
}

