/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.serialization;

import com.yahoo.lang.MutableInteger;
import com.yahoo.slime.Cursor;
import com.yahoo.slime.Inserter;
import com.yahoo.slime.Inspector;
import com.yahoo.slime.JsonDecoder;
import com.yahoo.slime.ObjectInserter;
import com.yahoo.slime.Slime;
import com.yahoo.slime.SlimeInserter;
import com.yahoo.slime.Type;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.function.Function;

public class JsonFormat {

    /** Upper-case hexadecimal digits, indexed by nibble value. */
    private static final char[] hexDigits = "0123456789ABCDEF".toCharArray();

    /**
     * Encodes a tensor as JSON, with hex encoding of dense parts turned off.
     *
     * @param tensor       the tensor to encode
     * @param shortForm    whether to use the compact "values"/"blocks"/object-of-cells forms where the tensor allows it
     * @param directValues whether to write the content at the JSON root instead of inside an object that also
     *                     carries the tensor "type"
     * @return the encoded JSON
     */
    public static byte[] encode(Tensor tensor, boolean shortForm, boolean directValues) {
        return encode(tensor, new EncodeOptions(shortForm, directValues, false));
    }

    /**
     * Encodes a tensor as JSON according to the given options.
     * <p>
     * In the long form every cell is written as an {@code {"address": ..., "value": ...}} object in a "cells" array.
     * In the short form:
     * <ul>
     *   <li>an indexed tensor becomes a (possibly nested) "values" array, or a hex string when
     *       {@link EncodeOptions#hexForDensePart()} is set;</li>
     *   <li>a mapped tensor with one dimension becomes a "cells" object keyed by label;</li>
     *   <li>a mixed tensor with mapped dimensions becomes "blocks": an object keyed by label when there is exactly one
     *       mapped dimension, otherwise an array of addressed blocks;</li>
     *   <li>anything else falls back to the long form.</li>
     * </ul>
     *
     * @param tensor  the tensor to encode
     * @param options the encoding options
     * @return the encoded JSON
     */
    public static byte[] encode(Tensor tensor, EncodeOptions options) {
        Slime slime = new Slime();
        Function<String, Inserter> target;
        if (options.directValues()) {
            // The content is the root itself, so the requested field name is ignored.
            target = key -> new SlimeInserter(slime);
        } else {
            Cursor root = slime.setObject();
            root.setString("type", tensor.type().toString());
            target = key -> new ObjectInserter(root, key);
        }
        if (options.shortForm()) {
            encodeInShortForm(tensor, target, options.hexForDensePart());
        } else {
            encodeCells(tensor, target.apply("cells").insertARRAY());
        }
        return com.yahoo.slime.JsonFormat.toJsonBytes(slime);
    }

    /** Writes the short-form content of a tensor, obtaining the inserter for each field name from {@code target}. */
    private static void encodeInShortForm(Tensor tensor, Function<String, Inserter> target, boolean hexForDensePart) {
        TensorType type = tensor.type();
        if (tensor instanceof IndexedTensor denseTensor) {
            if (hexForDensePart) {
                target.apply("values").insertSTRING(asHexString(denseTensor));
            } else {
                encodeDenseValues(denseTensor, target.apply("values").insertARRAY());
            }
        } else if (tensor instanceof MappedTensor mapped && type.dimensions().size() == 1) {
            encodeSingleDimensionCells(mapped, target.apply("cells").insertOBJECT());
        } else if (tensor instanceof MixedTensor mixed && type.hasMappedDimensions()) {
            long mappedDimensionCount = type.dimensions().stream().filter(TensorType.Dimension::isMapped).count();
            if (mappedDimensionCount == 1) {
                encodeLabeledBlocks(mixed, target.apply("blocks").insertOBJECT(), hexForDensePart);
            } else {
                encodeAddressedBlocks(mixed, target.apply("blocks").insertARRAY(), hexForDensePart);
            }
        } else {
            encodeCells(tensor, target.apply("cells").insertARRAY());
        }
    }

    /** Encodes a tensor in the long form, inside an object that also carries its type. */
    public static byte[] encode(Tensor tensor) {
        return encode(tensor, false, false);
    }

    /** @deprecated use {@link #encode(Tensor)}, which produces the same output */
    @Deprecated
    public static byte[] encodeWithType(Tensor tensor) {
        return encode(tensor, false, false);
    }

    /** @deprecated use {@link #encode(Tensor, boolean, boolean)} with {@code shortForm = true} */
    @Deprecated
    public static byte[] encodeShortForm(Tensor tensor) {
        return encode(tensor, true, false);
    }

    /** Appends one {@code {"address": ..., "value": ...}} object per cell to {@code cellsArray}. */
    private static void encodeCells(Tensor tensor, Cursor cellsArray) {
        TensorType type = tensor.type();
        for (Iterator<Tensor.Cell> cells = tensor.cellIterator(); cells.hasNext(); ) {
            Tensor.Cell cell = cells.next();
            Cursor cellObject = cellsArray.addObject();
            encodeAddress(type, cell.getKey(), cellObject.setObject("address"));
            setValue("value", cell.getValue(), type.valueType(), cellObject);
        }
    }

    /** Writes each cell of a one-dimensional mapped tensor as a {@code label: value} field of {@code cells}. */
    private static void encodeSingleDimensionCells(MappedTensor tensor, Cursor cells) {
        if (tensor.type().dimensions().size() > 1) {
            throw new IllegalStateException("JSON encode of mapped tensor can only contain a single dimension");
        }
        TensorType.Value valueType = tensor.type().valueType();
        tensor.cells().forEach((address, value) -> setValue(address.label(0), value, valueType, cells));
    }

    /** Writes {@code dimensionName: label} for each label of the address, naming dimensions by position in {@code type}. */
    private static void encodeAddress(TensorType type, TensorAddress address, Cursor addressObject) {
        for (int i = 0; i < address.size(); i++) {
            addressObject.setString(type.dimensions().get(i).name(), address.label(i));
        }
    }

    /** Returns the raw bits of all cells of a dense tensor as a hex string, in direct index order. */
    private static String asHexString(IndexedTensor tensor) {
        return asHexString(tensor.sizeAsInt(),
                           tensor.type().valueType(),
                           i -> tensor.get(i.longValue()),
                           i -> tensor.getFloat(i.longValue()));
    }

    /**
     * Returns the raw bit patterns of {@code denseSize} values as upper-case hex, most significant digit first.
     * Doubles use 16 digits, floats 8, bfloat16 4 (the upper half of the float bits) and int8 2.
     *
     * @param denseSize the number of values
     * @param cellType  the cell value type, which decides the width and which source is read
     * @param dblSrc    value source used for DOUBLE cells
     * @param fltSrc    value source used for all other cell types
     */
    private static String asHexString(int denseSize, TensorType.Value cellType,
                                      Function<Integer, Double> dblSrc, Function<Integer, Float> fltSrc) {
        StringBuilder buf = new StringBuilder();
        for (int i = 0; i < denseSize; i++) {
            switch (cellType) {
                case DOUBLE -> appendHexDigits(buf, Double.doubleToRawLongBits(dblSrc.apply(i)), 16);
                case FLOAT -> appendHexDigits(buf, Float.floatToRawIntBits(fltSrc.apply(i)), 8);
                // bfloat16 is the upper 16 bits of the float bit pattern
                case BFLOAT16 -> appendHexDigits(buf, Float.floatToRawIntBits(fltSrc.apply(i)) >>> 16, 4);
                case INT8 -> appendHexDigits(buf, fltSrc.apply(i).byteValue(), 2);
            }
        }
        return buf.toString();
    }

    /** Appends the lowest {@code digitCount} nibbles of {@code bits}, most significant first. */
    private static void appendHexDigits(StringBuilder buf, long bits, int digitCount) {
        for (int nibble = digitCount - 1; nibble >= 0; nibble--) {
            buf.append(hexDigits[(int) (bits >> (4 * nibble)) & 0xF]);
        }
    }

    /** Appends the cells of a dense tensor to {@code target} as arrays nested one level per dimension. */
    private static void encodeDenseValues(IndexedTensor tensor, Cursor target) {
        encodeValues(tensor, target, new long[tensor.dimensionSizes().dimensions()], 0);
    }

    /**
     * Recursively appends the values along {@code dimension}, using {@code indexes} as the scratch position.
     * A tensor without dimensions contributes its single value directly.
     */
    private static void encodeValues(IndexedTensor tensor, Cursor cursor, long[] indexes, int dimension) {
        TensorType.Value valueType = tensor.type().valueType();
        if (indexes.length == 0) {
            addValue(tensor.get(0L), valueType, cursor);
            return;
        }
        DimensionSizes sizes = tensor.dimensionSizes();
        boolean hasInnerDimension = dimension < sizes.dimensions() - 1;
        for (indexes[dimension] = 0; indexes[dimension] < sizes.size(dimension); indexes[dimension]++) {
            if (hasInnerDimension) {
                encodeValues(tensor, cursor.addArray(), indexes, dimension + 1);
            } else {
                addValue(tensor.get(indexes), valueType, cursor);
            }
        }
    }

    /** Writes one dense subspace as field {@code label} of {@code cursor}, either as a hex string or nested arrays. */
    private static void encodeLabeledSubspace(String label, MixedTensor.DenseSubspace subspace, TensorType denseSubType,
                                              Cursor cursor, boolean hexForDensePart) {
        if (hexForDensePart) {
            double[] cells = subspace.cells;
            cursor.setString(label, asHexString(cells.length, denseSubType.valueType(),
                                                i -> cells[i], i -> (float) cells[i]));
        } else {
            IndexedTensor denseSubspace = IndexedTensor.Builder.of(denseSubType, subspace.cells).build();
            encodeDenseValues(denseSubspace, cursor.setArray(label));
        }
    }

    /** Writes each dense subspace of a tensor with one mapped dimension as a field named by its mapped label. */
    private static void encodeLabeledBlocks(MixedTensor tensor, Cursor cursor, boolean hexForDensePart) {
        TensorType denseSubType = tensor.type().indexedSubtype();
        for (MixedTensor.DenseSubspace subspace : tensor.getInternalDenseSubspaces()) {
            encodeLabeledSubspace(subspace.sparseAddress.label(0), subspace, denseSubType, cursor, hexForDensePart);
        }
    }

    /** Appends one {@code {"address": ..., "values": ...}} object per dense subspace to {@code cursor}. */
    private static void encodeAddressedBlocks(MixedTensor tensor, Cursor cursor, boolean hexForDensePart) {
        List<TensorType.Dimension> mappedDimensions =
                tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).toList();
        if (mappedDimensions.isEmpty()) {
            throw new IllegalArgumentException("Should be ensured by caller");
        }
        TensorType mappedSubType = new TensorType.Builder(mappedDimensions).build();
        TensorType denseSubType = tensor.type().indexedSubtype();
        for (MixedTensor.DenseSubspace subspace : tensor.getInternalDenseSubspaces()) {
            Cursor block = cursor.addObject();
            encodeAddress(mappedSubType, subspace.sparseAddress, block.setObject("address"));
            encodeLabeledSubspace("values", subspace, denseSubType, block, hexForDensePart);
        }
    }

    /** Appends a value to an array, as an integer for INT8 cells and as a double otherwise. */
    private static void addValue(double value, TensorType.Value valueType, Cursor cursor) {
        if (valueType == TensorType.Value.INT8) {
            cursor.addLong((long) value);
        } else {
            cursor.addDouble(value);
        }
    }

    /** Sets a field to a value, as an integer for INT8 cells and as a double otherwise. */
    private static void setValue(String field, double value, TensorType.Value valueType, Cursor cursor) {
        if (valueType == TensorType.Value.INT8) {
            cursor.setLong(field, (long) value);
        } else {
            cursor.setDouble(field, value);
        }
    }

    /**
     * Decodes a tensor of the given type from JSON.
     * <p>
     * A non-primitive "cells" field is read as cells. Otherwise a "values" field is read as dense values (when the type
     * has no mapped dimensions), and otherwise a "blocks" field is read as mixed blocks. If none of these apply, the
     * root itself is taken to be the content.
     *
     * @param type             the type of the tensor to build
     * @param jsonTensorValue  the JSON to decode
     * @return the decoded tensor
     * @throws IllegalArgumentException if the JSON does not describe a tensor of the given type
     */
    public static Tensor decode(TensorType type, byte[] jsonTensorValue) {
        Tensor.Builder builder = Tensor.Builder.of(type);
        Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get();
        Inspector cells = root.field("cells");
        if (cells.valid() && !primitiveContent(cells)) {
            decodeCells(cells, builder);
        } else if (root.field("values").valid() && !builder.type().hasMappedDimensions()) {
            decodeValuesAtTop(root.field("values"), builder);
        } else if (root.field("blocks").valid()) {
            decodeBlocks(root.field("blocks"), builder);
        } else {
            decodeDirectValue(root, builder);
        }
        return builder.build();
    }

    /** Returns whether the inspector is a number, or a non-empty array whose first entry is a number. */
    private static boolean primitiveContent(Inspector cellsValue) {
        if (isNumber(cellsValue.type())) {
            return true;
        }
        return cellsValue.type() == Type.ARRAY
               && cellsValue.entries() > 0
               && isNumber(cellsValue.entry(0).type());
    }

    private static boolean isNumber(Type type) {
        return type == Type.DOUBLE || type == Type.LONG;
    }

    /** Decodes cells given either as an array of address/value objects or as a label-to-value object. */
    private static void decodeCells(Inspector cells, Tensor.Builder builder) {
        if (cells.type() == Type.ARRAY) {
            cells.traverse((int index, Inspector cell) -> decodeCell(cell, builder));
        } else if (cells.type() == Type.OBJECT) {
            cells.traverse((String key, Inspector value) -> decodeSingleDimensionCell(key, value, builder));
        } else {
            throw new IllegalArgumentException("Excepted 'cells' to contain an array or object, not " + cells.type());
        }
    }

    /** Decodes one {@code {"address": ..., "value": ...}} cell object into the builder. */
    private static void decodeCell(Inspector cell, Tensor.Builder builder) {
        TensorAddress address = decodeAddress(cell.field("address"), builder.type());
        Inspector value = cell.field("value");
        if (!value.valid()) {
            throw new IllegalArgumentException("Excepted a cell to contain a numeric value called 'value'");
        }
        builder.cell(address, decodeNumeric(value));
    }

    private static void decodeSingleDimensionCell(String key, Inspector value, Tensor.Builder builder) {
        builder.cell(asAddress(key, builder.type()), decodeNumeric(value));
    }

    private static void decodeValuesAtTop(Inspector values, Tensor.Builder builder) {
        decodeNestedValues(values, builder, new MutableInteger(0));
    }

    /**
     * Decodes dense values, given as a hex string or as (nested) arrays, into consecutive direct indexes of the
     * builder, starting at the current value of {@code index}.
     */
    private static void decodeNestedValues(Inspector values, Tensor.Builder builder, MutableInteger index) {
        if (!(builder instanceof IndexedTensor.BoundBuilder indexedBuilder)) {
            throw new IllegalArgumentException("An array of values can only be used with a dense tensor. Use a map instead");
        }
        if (values.type() == Type.STRING) {
            double[] decoded = decodeHexString(values.asString(), builder.type().valueType());
            if (decoded.length == 0) {
                throw new IllegalArgumentException("The values string does not contain any values");
            }
            for (int i = 0; i < decoded.length; i++) {
                indexedBuilder.cellByDirectIndex(i, decoded[i]);
            }
            return;
        }
        if (values.type() != Type.ARRAY) {
            throw new IllegalArgumentException("Excepted values to be an array, not " + values.type());
        }
        if (values.entries() == 0) {
            throw new IllegalArgumentException("The values array does not contain any values");
        }
        values.traverse((int position, Inspector value) -> {
            Type valueType = value.type();
            if (valueType == Type.ARRAY) {
                decodeNestedValues(value, builder, index);
            } else if (isNumber(valueType) || valueType == Type.STRING || valueType == Type.NIX) {
                indexedBuilder.cellByDirectIndex(index.next(), decodeNumeric(value));
            } else {
                throw new IllegalArgumentException("Excepted the values array to contain numbers or nested arrays, not " + valueType);
            }
        });
    }

    /** Decodes mixed-tensor blocks given as an array of addressed blocks or as a label-to-values object. */
    private static void decodeBlocks(Inspector values, Tensor.Builder builder) {
        if (!(builder instanceof MixedTensor.BoundBuilder mixedBuilder)) {
            throw new IllegalArgumentException("Blocks of values can only be used with mixed (sparse and dense) tensors.Use an array of cell values instead.");
        }
        if (values.type() == Type.ARRAY) {
            values.traverse((int position, Inspector value) -> decodeBlock(value, mixedBuilder));
        } else if (values.type() == Type.OBJECT) {
            values.traverse((String key, Inspector value) -> decodeSingleDimensionBlock(key, value, mixedBuilder));
        } else {
            throw new IllegalArgumentException("Excepted the block to contain an array or object, not " + values.type());
        }
    }

    /** Decodes one {@code {"address": ..., "values": ...}} block object into the builder. */
    private static void decodeBlock(Inspector block, MixedTensor.BoundBuilder mixedBuilder) {
        if (block.type() != Type.OBJECT) {
            throw new IllegalArgumentException("Expected an item in a blocks array to be an object, not " + block.type());
        }
        TensorAddress address = decodeAddress(block.field("address"), mixedBuilder.type().mappedSubtype());
        mixedBuilder.block(address, decodeValuesInBlock(block.field("values"), mixedBuilder));
    }

    /** Decodes tensor content found directly at the root, choosing the form from the content and the tensor type. */
    private static void decodeDirectValue(Inspector root, Tensor.Builder builder) {
        boolean hasIndexed = builder.type().hasIndexedDimensions();
        boolean hasMapped = builder.type().hasMappedDimensions();
        if (isArrayOfObjects(root)) {
            decodeCells(root, builder);
        } else if (!hasMapped) {
            decodeValuesAtTop(root, builder);
        } else if (hasIndexed) {
            decodeBlocks(root, builder);
        } else {
            decodeCells(root, builder);
        }
    }

    /** Returns whether following first entries through nested non-empty arrays ends at an object. */
    private static boolean isArrayOfObjects(Inspector inspector) {
        Inspector current = inspector;
        while (current.type() == Type.ARRAY && current.entries() > 0) {
            Inspector firstItem = current.entry(0);
            if (firstItem.type() == Type.OBJECT) {
                return true;
            }
            if (firstItem.type() != Type.ARRAY) {
                return false;
            }
            current = firstItem;
        }
        return false;
    }

    /** Decodes the block for mapped label {@code key}, given as a values array or hex string. */
    private static void decodeSingleDimensionBlock(String key, Inspector value, MixedTensor.BoundBuilder mixedBuilder) {
        if (value.type() != Type.ARRAY && value.type() != Type.STRING) {
            throw new IllegalArgumentException("Expected an item in a blocks array to be an array, not " + value.type());
        }
        TensorAddress address = asAddress(key, mixedBuilder.type().mappedSubtype());
        mixedBuilder.block(address, decodeValuesInBlock(value, mixedBuilder));
    }

    /** Returns the value (0-15) of the hex digit at {@code index}. */
    private static int decodeHex(String input, int index) {
        int digit = Character.digit(input.charAt(index), 16);
        if (digit < 0) {
            throw new IllegalArgumentException("Invalid digit '" + input.charAt(index) + "' at index " + index + " in input " + input);
        }
        return digit;
    }

    /** Reads {@code digitCount} hex digits starting at {@code offset}, most significant first. */
    private static long decodeHexBits(String input, int offset, int digitCount) {
        long bits = 0L;
        for (int j = 0; j < digitCount; j++) {
            bits = (bits << 4) | decodeHex(input, offset + j);
        }
        return bits;
    }

    /**
     * Returns how many values of {@code digitsPerValue} hex digits the input holds.
     * Rejects trailing digits that would otherwise be silently ignored.
     */
    private static int hexValueCount(String input, int digitsPerValue) {
        if (input.length() % digitsPerValue != 0) {
            throw new IllegalArgumentException("Expected the hex string length to be a multiple of " + digitsPerValue +
                                               " but it is " + input.length() + " in input " + input);
        }
        return input.length() / digitsPerValue;
    }

    private static double[] decodeHexStringAsBytes(String input) {
        double[] result = new double[hexValueCount(input, 2)];
        for (int i = 0; i < result.length; i++) {
            result[i] = (byte) decodeHexBits(input, 2 * i, 2); // narrow to a signed int8
        }
        return result;
    }

    private static double[] decodeHexStringAsBFloat16s(String input) {
        double[] result = new double[hexValueCount(input, 4)];
        for (int i = 0; i < result.length; i++) {
            // bfloat16 is the upper half of a float bit pattern
            result[i] = Float.intBitsToFloat((int) decodeHexBits(input, 4 * i, 4) << 16);
        }
        return result;
    }

    private static double[] decodeHexStringAsFloats(String input) {
        double[] result = new double[hexValueCount(input, 8)];
        for (int i = 0; i < result.length; i++) {
            result[i] = Float.intBitsToFloat((int) decodeHexBits(input, 8 * i, 8));
        }
        return result;
    }

    private static double[] decodeHexStringAsDoubles(String input) {
        double[] result = new double[hexValueCount(input, 16)];
        for (int i = 0; i < result.length; i++) {
            result[i] = Double.longBitsToDouble(decodeHexBits(input, 16 * i, 16));
        }
        return result;
    }

    /**
     * Decodes a string of hex digits holding raw cell bit patterns, most significant digit first.
     * Each value takes 2 digits for INT8, 4 for BFLOAT16, 8 for FLOAT and 16 for DOUBLE.
     *
     * @param input     the hex digits
     * @param valueType the cell value type
     * @return the decoded values, in order
     * @throws IllegalArgumentException if the input contains a non-hex character, or its length is not a multiple of
     *                                  the digits per value
     */
    public static double[] decodeHexString(String input, TensorType.Value valueType) {
        return switch (valueType) {
            case INT8 -> decodeHexStringAsBytes(input);
            case BFLOAT16 -> decodeHexStringAsBFloat16s(input);
            case FLOAT -> decodeHexStringAsFloats(input);
            case DOUBLE -> decodeHexStringAsDoubles(input);
        };
    }

    /** Decodes (nested) block value arrays into consecutive positions of {@code target}, starting at {@code index}. */
    private static void decodeMaybeNestedValuesInBlock(Inspector arrayField, double[] target, MutableInteger index) {
        if (arrayField.entries() == 0) {
            throw new IllegalArgumentException("The block value array does not contain any values");
        }
        arrayField.traverse((int position, Inspector value) -> {
            if (value.type() == Type.ARRAY) {
                decodeMaybeNestedValuesInBlock(value, target, index);
            } else {
                int next = index.next();
                if (next >= target.length) {
                    throw new IllegalArgumentException("The block value array contains more than the " + target.length +
                                                       " values of a dense subspace");
                }
                target[next] = decodeNumeric(value);
            }
        });
    }

    /**
     * Decodes the values of one block, given as (nested) arrays or a hex string, into a new array sized to the dense
     * subspace. Positions not given by the input are left at 0.
     */
    private static double[] decodeValuesInBlock(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) {
        double[] values = new double[(int) mixedBuilder.denseSubspaceSize()];
        if (valuesField.type() == Type.ARRAY) {
            decodeMaybeNestedValuesInBlock(valuesField, values, new MutableInteger(0));
        } else if (valuesField.type() == Type.STRING) {
            double[] decoded = decodeHexString(valuesField.asString(), mixedBuilder.type().valueType());
            if (decoded.length == 0) {
                throw new IllegalArgumentException("The block value string does not contain any values");
            }
            if (decoded.length > values.length) {
                throw new IllegalArgumentException("The block value string contains " + decoded.length +
                                                   " values, but a dense subspace holds only " + values.length);
            }
            System.arraycopy(decoded, 0, values, 0, decoded.length);
        } else {
            throw new IllegalArgumentException("Expected a block to contain an array of values");
        }
        return values;
    }

    /** Decodes a {@code {dimension: label, ...}} object into an address of the given type. */
    private static TensorAddress decodeAddress(Inspector addressField, TensorType type) {
        if (addressField.type() != Type.OBJECT) {
            throw new IllegalArgumentException("Expected an 'address' object, not " + addressField.type());
        }
        TensorAddress.Builder builder = new TensorAddress.Builder(type);
        addressField.traverse((String dimension, Inspector label) -> builder.add(dimension, label.asString()));
        return builder.build();
    }

    /** Returns the address of {@code label} in the single dimension of {@code type}. */
    private static TensorAddress asAddress(String label, TensorType type) {
        if (type.dimensions().size() != 1) {
            throw new IllegalArgumentException("Expected a tensor with a single dimension but got " + type);
        }
        return new TensorAddress.Builder(type).add(type.dimensions().get(0).name(), label).build();
    }

    /** Decodes a number, a numeric string (see {@link #decodeNumberString}), or null (as NaN). */
    private static double decodeNumeric(Inspector numericField) {
        Type type = numericField.type();
        if (isNumber(type)) {
            return numericField.asDouble();
        }
        if (type == Type.STRING) {
            return decodeNumberString(numericField.asString());
        }
        if (type == Type.NIX) {
            return Double.NaN;
        }
        throw new IllegalArgumentException("Excepted a number, not " + type);
    }

    /**
     * Parses a number given as a string. Case-insensitively accepts "inf"/"infinity" (optionally signed) and
     * "nan"/"+nan"/"-nan", and otherwise uses {@link Double#parseDouble}.
     * <p>
     * Case folding uses {@link Locale#ROOT}, so the result does not depend on the default locale. For example, in a
     * Turkish locale "INF" would otherwise fold to a dotless "ınf".
     *
     * @param input the string to parse
     * @return the parsed value
     * @throws NumberFormatException if the input is not a special value and not parseable as a double
     */
    public static double decodeNumberString(String input) {
        String s = input.toLowerCase(Locale.ROOT);
        switch (s) {
            case "infinity", "+infinity", "inf", "+inf" -> {
                return Double.POSITIVE_INFINITY;
            }
            case "-infinity", "-inf" -> {
                return Double.NEGATIVE_INFINITY;
            }
            case "nan", "+nan" -> {
                return Double.NaN;
            }
            case "-nan" -> {
                return Math.copySign(Double.NaN, -1.0);
            }
            default -> {
                return Double.parseDouble(input);
            }
        }
    }

    /**
     * Options for {@link #encode(Tensor, EncodeOptions)}.
     *
     * @param shortForm       use the compact forms where the tensor allows it
     * @param directValues    write the content at the JSON root instead of in an object with a "type" field
     * @param hexForDensePart in the short form, write dense values as hex strings of the raw cell bits
     */
    public record EncodeOptions(boolean shortForm, boolean directValues, boolean hexForDensePart) {

        /** Long form, wrapped with the type, no hex. */
        public EncodeOptions() {
            this(false);
        }

        /** The given form, wrapped with the type, no hex. */
        public EncodeOptions(boolean shortForm) {
            this(shortForm, false);
        }

        /** The given form and wrapping, no hex. */
        public EncodeOptions(boolean shortForm, boolean directValues) {
            this(shortForm, directValues, false);
        }
    }
}

