/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.vespa.indexinglanguage.expressions;

import com.yahoo.document.DataType;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.Tensors;
import com.yahoo.vespa.indexinglanguage.expressions.ExecutionContext;
import com.yahoo.vespa.indexinglanguage.expressions.Expression;
import com.yahoo.vespa.indexinglanguage.expressions.VerificationContext;
import com.yahoo.vespa.indexinglanguage.expressions.VerificationException;
import java.util.Optional;

public class PackBitsExpression
extends Expression {
    private TensorType outputTensorType;

    public PackBitsExpression() {
        super((DataType)TensorDataType.any());
    }

    @Override
    public DataType setInputType(DataType inputType, VerificationContext context) {
        super.setInputType(inputType, context);
        if (!this.validType(inputType)) {
            throw new VerificationException(this, "Require a tensor with one dense dimension, but got " + inputType.getName());
        }
        this.outputTensorType = this.outputType(((TensorDataType)inputType).getTensorType());
        return new TensorDataType(this.outputTensorType);
    }

    @Override
    public DataType setOutputType(DataType outputType, VerificationContext context) {
        super.setOutputType(outputType, context);
        if (!this.validType(outputType)) {
            throw new VerificationException(this, "Required to produce " + outputType.getName() + " but this produces a tensor with one dense dimension");
        }
        this.outputTensorType = ((TensorDataType)outputType).getTensorType();
        return new TensorDataType(this.inputType(this.outputTensorType));
    }

    private boolean validType(DataType type) {
        if (!(type instanceof TensorDataType)) {
            return false;
        }
        TensorDataType tensorType = (TensorDataType)type;
        return tensorType.getTensorType().indexedSubtype().dimensions().size() == 1;
    }

    @Override
    protected void doVerify(VerificationContext context) {
    }

    @Override
    protected void doExecute(ExecutionContext context) {
        Optional tensor = ((TensorFieldValue)context.getCurrentValue()).getTensor();
        if (tensor.isEmpty()) {
            return;
        }
        Tensor packed = Tensors.packBits((Tensor)((Tensor)tensor.get()));
        context.setCurrentValue((FieldValue)new TensorFieldValue(packed));
    }

    @Override
    public DataType createdOutputType() {
        return new TensorDataType(this.outputTensorType);
    }

    public String toString() {
        return "pack_bits";
    }

    public int hashCode() {
        return this.toString().hashCode();
    }

    public boolean equals(Object o) {
        return o instanceof PackBitsExpression;
    }

    private TensorType inputType(TensorType givenType) {
        TensorType.Builder builder = new TensorType.Builder(TensorType.Value.INT8);
        for (TensorType.Dimension d : givenType.dimensions()) {
            builder.dimension(d.size().isPresent() ? d.withSize((Long)d.size().get() * 8L) : d);
        }
        return builder.build();
    }

    private TensorType outputType(TensorType givenType) {
        TensorType.Builder builder = new TensorType.Builder(TensorType.Value.INT8);
        for (TensorType.Dimension d : givenType.dimensions()) {
            builder.dimension(d.size().isPresent() ? d.withSize((long)((int)Math.ceil((double)((Long)d.size().get()).longValue() / 8.0))) : d);
        }
        return builder.build();
    }
}

