/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.java;

import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.java.XGBoostJNI;

/**
 * Static Java facade over the native XGBoost communicator exposed through {@link XGBoostJNI}.
 *
 * <p>Each method forwards to a native call. Non-zero native status codes are converted into
 * exceptions carrying {@link XGBoostJNI#XGBGetLastError()}.
 *
 * <p>NOTE(review): the public static fields below are plain mutable state with no
 * synchronization; callers presumably configure them from a single thread before
 * {@link #init(Map)} — confirm.
 */
public class Communicator {
    /**
     * The map most recently passed to {@link #init(Map)}. It is assigned at the start of every
     * call, even if the native initialization then fails; {@code null} until the first call.
     */
    public static Map<String, String> communicatorEnvs;

    /**
     * Values forwarded by {@link #init(Map)} as extra {@code "mock", value} argument pairs.
     * NOTE(review): presumably a testing hook; the native meaning of these values is not visible
     * here — confirm.
     */
    public static List<String> mockList = new LinkedList<String>();

    /**
     * Converts a native status code into an exception.
     *
     * @param ret status returned by an {@code XGBoostJNI} call; {@code 0} means success
     * @throws XGBoostError carrying the native last-error message when {@code ret != 0}
     */
    private static void checkCall(int ret) throws XGBoostError {
        if (ret != 0) {
            throw new XGBoostError(XGBoostJNI.XGBGetLastError());
        }
    }

    /**
     * Initializes the native communicator.
     *
     * <p>The entries of {@code envs} are flattened into an alternating key/value array. One
     * {@code "mock", value} pair is then appended for each element of {@link #mockList}.
     *
     * @param envs configuration handed to the native layer; also stored in
     *     {@link #communicatorEnvs}
     * @throws XGBoostError if the native initialization reports failure
     */
    public static void init(Map<String, String> envs) throws XGBoostError {
        communicatorEnvs = envs;
        String[] args = new String[envs.size() * 2 + mockList.size() * 2];
        int idx = 0;
        for (Map.Entry<String, String> e : envs.entrySet()) {
            args[idx++] = e.getKey();
            args[idx++] = e.getValue();
        }
        for (String mock : mockList) {
            args[idx++] = "mock";
            args[idx++] = mock;
        }
        checkCall(XGBoostJNI.CommunicatorInit(args));
    }

    /**
     * Finalizes the native communicator.
     *
     * @throws XGBoostError if the native call reports failure
     */
    public static void shutdown() throws XGBoostError {
        checkCall(XGBoostJNI.CommunicatorFinalize());
    }

    /**
     * Passes {@code msg} to the native communicator's print routine.
     *
     * @param msg text to print
     * @throws XGBoostError if the native call reports failure
     */
    public static void communicatorPrint(String msg) throws XGBoostError {
        checkCall(XGBoostJNI.CommunicatorPrint(msg));
    }

    /**
     * Returns the rank reported by the native communicator.
     *
     * @return the value written by the native call
     * @throws XGBoostError if the native call reports failure
     */
    public static int getRank() throws XGBoostError {
        int[] out = new int[1];
        checkCall(XGBoostJNI.CommunicatorGetRank(out));
        return out[0];
    }

    /**
     * Returns the world size reported by the native communicator.
     *
     * @return the value written by the native call
     * @throws XGBoostError if the native call reports failure
     */
    public static int getWorldSize() throws XGBoostError {
        int[] out = new int[1];
        checkCall(XGBoostJNI.CommunicatorGetWorldSize(out));
        return out[0];
    }

    /**
     * All-reduces a float array through the native communicator.
     *
     * <p>The input is copied into a direct buffer in native byte order, and the native allreduce
     * runs on that buffer. The buffer's contents after the call are returned in a new array;
     * {@code elements} itself is not modified.
     *
     * @param elements values contributed by this worker
     * @param op reduction to apply
     * @return a new array holding the buffer contents after the native call
     * @throws IllegalStateException if the native allreduce reports failure; its cause is an
     *     {@link XGBoostError} carrying the native error message
     */
    public static float[] allReduce(float[] elements, OpType op) {
        DataType dataType = DataType.FLOAT32;
        ByteBuffer buffer = ByteBuffer.allocateDirect(dataType.getSize() * elements.length)
                .order(ByteOrder.nativeOrder());
        for (float el : elements) {
            buffer.putFloat(el);
        }
        buffer.flip();
        try {
            // Ignoring the native status would hand back this worker's unreduced input as if the
            // reduction had succeeded. XGBoostError is checked, and adding it to this method's
            // signature would break existing callers, so it is rethrown unchecked.
            checkCall(XGBoostJNI.CommunicatorAllreduce(
                    buffer, elements.length, dataType.getEnumOp(), op.getOperand()));
        } catch (XGBoostError e) {
            throw new IllegalStateException("Communicator allReduce failed", e);
        }
        float[] results = new float[elements.length];
        // The view inherits the buffer's native byte order, matching how the values were written.
        buffer.asFloatBuffer().get(results);
        return results;
    }

    /**
     * Element types understood by the native communicator.
     *
     * <p>{@link #getEnumOp()} is the integer code passed to the native layer. {@link #getSize()}
     * is the width of one element in bytes, used to size the native buffer.
     */
    public static enum DataType implements Serializable
    {
        INT8(0, 1),
        UINT8(1, 1),
        INT32(2, 4),
        UINT32(3, 4),
        INT64(4, 8),
        UINT64(5, 8),
        FLOAT32(6, 4),
        FLOAT64(7, 8);

        private final int enumOp;
        private final int size;

        /** @return the integer code identifying this type to the native layer */
        public int getEnumOp() {
            return this.enumOp;
        }

        /** @return the size of one element of this type, in bytes */
        public int getSize() {
            return this.size;
        }

        private DataType(int enumOp, int size) {
            this.enumOp = enumOp;
            this.size = size;
        }
    }

    /** Reduction operators supported by {@link #allReduce(float[], OpType)}. */
    public static enum OpType implements Serializable
    {
        MAX(0),
        MIN(1),
        SUM(2);

        private final int op;

        /** @return the integer code identifying this operator to the native layer */
        public int getOperand() {
            return this.op;
        }

        private OpType(int op) {
            this.op = op;
        }
    }
}

