/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor;

import com.yahoo.data.disclosure.DataSink;
import com.yahoo.data.disclosure.DataSource;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.JsonFormat;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;

public class TensorDataSource
implements DataSource {
    private final Tensor tensor;
    private final JsonFormat.EncodeOptions options;
    private final boolean wrapAndType;
    private boolean inObject = false;
    private static final byte[] hexDigits = "0123456789ABCDEF".getBytes(StandardCharsets.US_ASCII);

    public TensorDataSource(Tensor tensor, JsonFormat.EncodeOptions options) {
        this.tensor = tensor;
        this.options = options;
        this.wrapAndType = !options.directValues();
    }

    @Override
    public void emit(DataSink sink) {
        this.wrapStart(sink);
        if (this.wrapAndType) {
            sink.fieldName("type");
            sink.stringValue(this.tensor.type().toString());
        }
        if (this.options.shortForm()) {
            this.emitShortForm(sink);
        } else {
            this.emitLongForm(sink);
        }
        this.ensureObjectEnded(sink);
    }

    /*
     * Enabled aggressive block sorting
     */
    private void emitShortForm(DataSink sink) {
        Tensor tensor = this.tensor;
        if (tensor instanceof IndexedTensor) {
            IndexedTensor denseTensor = (IndexedTensor)tensor;
            this.startField("values", sink);
            if (this.options.hexForDensePart()) {
                sink.stringValue(TensorDataSource.asHexString(denseTensor));
                return;
            }
            this.emitDenseValues(denseTensor, sink);
            return;
        }
        tensor = this.tensor;
        if (tensor instanceof MappedTensor) {
            MappedTensor mapped = (MappedTensor)tensor;
            if (this.tensor.type().dimensions().size() == 1) {
                this.startField("cells", sink);
                this.emitSingleDimensionCells(mapped, sink);
                return;
            }
        }
        if ((tensor = this.tensor) instanceof MixedTensor) {
            MixedTensor mixed = (MixedTensor)tensor;
            if (this.tensor.type().hasMappedDimensions()) {
                boolean singleMapped;
                boolean bl = singleMapped = this.tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() == 1L;
                if (singleMapped) {
                    this.startField("blocks", sink);
                    this.emitLabeledBlocks(mixed, sink);
                    return;
                }
                boolean startedObject = this.ensureObject(sink);
                this.startField("blocks", sink);
                this.emitAddressedBlocks(mixed, sink);
                if (!startedObject) return;
                this.ensureObjectEnded(sink);
                return;
            }
        }
        this.startField("cells", sink);
        this.emitCells(sink);
    }

    private void emitLongForm(DataSink sink) {
        this.startField("cells", sink);
        this.emitCells(sink);
    }

    private void emitCells(DataSink sink) {
        sink.startArray();
        Iterator<Tensor.Cell> i = this.tensor.cellIterator();
        while (i.hasNext()) {
            Tensor.Cell cell = i.next();
            sink.startObject();
            sink.fieldName("address");
            sink.startObject();
            TensorAddress address = cell.getKey();
            for (int j = 0; j < address.size(); ++j) {
                sink.fieldName(this.tensor.type().dimensions().get(j).name());
                sink.stringValue(address.label(j));
            }
            sink.endObject();
            sink.fieldName("value");
            this.emitValue(cell.getValue(), this.tensor.type().valueType(), sink);
            sink.endObject();
        }
        sink.endArray();
    }

    private void emitSingleDimensionCells(MappedTensor tensor, DataSink sink) {
        if (tensor.type().dimensions().size() != 1) {
            throw new IllegalStateException("Single dimension encoding requires exactly one dimension");
        }
        sink.startObject();
        tensor.cells().forEach((address, value) -> {
            sink.fieldName(address.label(0));
            this.emitValue((double)value, tensor.type().valueType(), sink);
        });
        sink.endObject();
    }

    private void emitDenseValues(IndexedTensor tensor, DataSink sink) {
        sink.startArray();
        this.emitDenseValuesRecursive(tensor, sink, new long[tensor.dimensionSizes().dimensions()], 0);
        sink.endArray();
    }

    private void emitDenseValuesRecursive(IndexedTensor tensor, DataSink sink, long[] indexes, int dimension) {
        DimensionSizes sizes = tensor.dimensionSizes();
        if (indexes.length == 0) {
            this.emitValue(tensor.get(0L), tensor.type().valueType(), sink);
        } else {
            indexes[dimension] = 0L;
            while (indexes[dimension] < sizes.size(dimension)) {
                if (dimension < sizes.dimensions() - 1) {
                    sink.startArray();
                    this.emitDenseValuesRecursive(tensor, sink, indexes, dimension + 1);
                    sink.endArray();
                } else {
                    this.emitValue(tensor.get(indexes), tensor.type().valueType(), sink);
                }
                int n = dimension;
                indexes[n] = indexes[n] + 1L;
            }
        }
    }

    private void emitLabeledBlocks(MixedTensor tensor, DataSink sink) {
        sink.startObject();
        TensorType denseSubType = tensor.type().indexedSubtype();
        for (MixedTensor.DenseSubspace subspace : tensor.getInternalDenseSubspaces()) {
            String label = subspace.sparseAddress.label(0);
            sink.fieldName(label);
            if (this.options.hexForDensePart()) {
                sink.stringValue(TensorDataSource.asHexString(subspace.cells.length, denseSubType.valueType(), i -> subspace.cells[i], i -> Float.valueOf((float)subspace.cells[i])));
                continue;
            }
            IndexedTensor denseSubspace = IndexedTensor.Builder.of(denseSubType, subspace.cells).build();
            this.emitDenseValues(denseSubspace, sink);
        }
        sink.endObject();
    }

    private void emitAddressedBlocks(MixedTensor tensor, DataSink sink) {
        sink.startArray();
        List<TensorType.Dimension> mappedDimensions = tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).toList();
        TensorType denseSubType = tensor.type().indexedSubtype();
        for (MixedTensor.DenseSubspace subspace : tensor.getInternalDenseSubspaces()) {
            sink.startObject();
            sink.fieldName("address");
            sink.startObject();
            for (int i2 = 0; i2 < mappedDimensions.size(); ++i2) {
                sink.fieldName(mappedDimensions.get(i2).name());
                sink.stringValue(subspace.sparseAddress.label(i2));
            }
            sink.endObject();
            sink.fieldName("values");
            if (this.options.hexForDensePart()) {
                sink.stringValue(TensorDataSource.asHexString(subspace.cells.length, denseSubType.valueType(), i -> subspace.cells[i], i -> Float.valueOf((float)subspace.cells[i])));
            } else {
                IndexedTensor denseSubspace = IndexedTensor.Builder.of(denseSubType, subspace.cells).build();
                this.emitDenseValues(denseSubspace, sink);
            }
            sink.endObject();
        }
        sink.endArray();
    }

    private void emitValue(double value, TensorType.Value valueType, DataSink sink) {
        switch (valueType) {
            case DOUBLE: {
                sink.doubleValue(value);
                break;
            }
            case FLOAT: 
            case BFLOAT16: {
                sink.floatValue((float)value);
                break;
            }
            case INT8: {
                sink.byteValue((byte)value);
            }
        }
    }

    private static byte[] asHexString(IndexedTensor tensor) {
        return TensorDataSource.asHexString(tensor.sizeAsInt(), tensor.type().valueType(), i -> tensor.get((long)i.intValue()), i -> Float.valueOf(tensor.getFloat((long)i.intValue())));
    }

    private static byte[] asHexString(int denseSize, TensorType.Value cellType, Function<Integer, Double> dblSrc, Function<Integer, Float> fltSrc) {
        int nibblesPerCell = switch (cellType) {
            default -> throw new IncompatibleClassChangeError();
            case TensorType.Value.DOUBLE -> 16;
            case TensorType.Value.FLOAT -> 8;
            case TensorType.Value.BFLOAT16 -> 4;
            case TensorType.Value.INT8 -> 2;
        };
        byte[] result = new byte[nibblesPerCell * denseSize];
        int idx = 0;
        for (int i = 0; i < denseSize; ++i) {
            long bits = switch (cellType) {
                default -> throw new IncompatibleClassChangeError();
                case TensorType.Value.DOUBLE -> Double.doubleToRawLongBits(dblSrc.apply(i));
                case TensorType.Value.FLOAT -> Float.floatToRawIntBits(fltSrc.apply(i).floatValue());
                case TensorType.Value.BFLOAT16 -> Float.floatToRawIntBits(fltSrc.apply(i).floatValue()) >>> 16;
                case TensorType.Value.INT8 -> fltSrc.apply(i).byteValue();
            };
            int nibble = nibblesPerCell;
            while (nibble-- > 0) {
                int digit = (int)(bits >> 4 * nibble) & 0xF;
                result[idx++] = hexDigits[digit];
            }
        }
        if (idx != result.length) {
            throw new IllegalStateException("Did not fill result[" + result.length + "], final idx=" + idx);
        }
        return result;
    }

    private void wrapStart(DataSink sink) {
        if (this.wrapAndType) {
            this.ensureObject(sink);
        }
    }

    private void ensureObjectEnded(DataSink sink) {
        if (this.inObject) {
            sink.endObject();
            this.inObject = false;
        }
    }

    private void startField(String fieldName, DataSink sink) {
        if (this.inObject) {
            sink.fieldName(fieldName);
        }
    }

    private boolean ensureObject(DataSink sink) {
        if (this.inObject) {
            return false;
        }
        sink.startObject();
        this.inObject = true;
        return true;
    }
}

