/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.parameterserver.distributed.messages.intercom;

import java.util.Arrays;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.enums.ExecutionMode;
import org.nd4j.parameterserver.distributed.logic.storage.WordVectorStorage;
import org.nd4j.parameterserver.distributed.messages.BaseVoidMessage;
import org.nd4j.parameterserver.distributed.messages.DistributedMessage;
import org.nd4j.parameterserver.distributed.messages.aggregations.DotAggregation;
import org.nd4j.parameterserver.distributed.messages.requests.CbowRequestMessage;
import org.nd4j.parameterserver.distributed.training.impl.CbowTrainer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Intercom message carrying one CBOW dot-product request for a single task.
 *
 * <p>When processed it:
 * <ol>
 *   <li>hands an equivalent {@link CbowRequestMessage} to the local {@link CbowTrainer};</li>
 *   <li>computes dot products between the mean of the {@code SYN_0} rows listed in
 *       {@link #rowsA} and the {@code SYN_1} / {@code SYN_1_NEGATIVE} rows listed in
 *       {@link #rowsB};</li>
 *   <li>ships the resulting column of dot products as a {@link DotAggregation}.</li>
 * </ol>
 *
 * @deprecated kept only for compatibility with existing senders/receivers of message type 22.
 */
@Deprecated
public class DistributedCbowDotMessage
extends BaseVoidMessage
implements DistributedMessage {
    private static final Logger log = LoggerFactory.getLogger(DistributedCbowDotMessage.class);

    /** Row indices into {@code SYN_0}; their mean is the vector every dot product is taken against. */
    protected int[] rowsA;
    /**
     * Row indices for the right-hand side of the dot products: the first {@code codes.length}
     * entries index {@code SYN_1}; when {@code negSamples > 0}, the following
     * {@code negSamples + 1} entries index {@code SYN_1_NEGATIVE}.
     */
    protected int[] rowsB;
    /** Forwarded unchanged to {@link CbowRequestMessage}; the deprecated 3-arg constructor sets it to {@code rowA}. */
    protected int w1;
    /** Stored and compared in {@link #equals}; not read by {@link #processMessage()}. */
    protected boolean useHS;
    /** Number of negative samples; when positive, {@code negSamples + 1} extra dot products are computed. */
    protected short negSamples;
    /** Forwarded unchanged to {@link CbowRequestMessage} (presumably the learning rate — TODO confirm). */
    protected float alpha;
    /** Its length is the number of {@code SYN_1} dot products; also forwarded to {@link CbowRequestMessage}. */
    protected byte[] codes;

    /** Creates an empty message with message type 22 (used for deserialization and by the other constructors). */
    public DistributedCbowDotMessage() {
        this.messageType = 22;
    }

    /**
     * Single-row convenience constructor: no codes, no hierarchical softmax, no negative sampling.
     *
     * @deprecated use the full constructor.
     */
    @Deprecated
    public DistributedCbowDotMessage(long taskId, int rowA, int rowB) {
        // (short) cast is required: invocation contexts do not narrow int constants implicitly
        this(taskId, new int[]{rowA}, new int[]{rowB}, rowA, new byte[0], false, (short) 0, 0.001f);
    }

    /**
     * Creates a fully specified message.
     *
     * @param taskId     id of the task this message belongs to
     * @param rowsA      {@code SYN_0} row indices to average; must not be null
     * @param rowsB      {@code SYN_1}/{@code SYN_1_NEGATIVE} row indices (see field docs); must not be null
     * @param w1         forwarded unchanged to {@link CbowRequestMessage}
     * @param codes      codes whose length sets the number of {@code SYN_1} dot products; must not be null
     * @param useHS      stored flag; not consulted by {@link #processMessage()}
     * @param negSamples number of negative samples
     * @param alpha      forwarded unchanged to {@link CbowRequestMessage}
     * @throws NullPointerException if {@code rowsA}, {@code rowsB} or {@code codes} is null
     */
    public DistributedCbowDotMessage(long taskId, @NonNull int[] rowsA, @NonNull int[] rowsB, int w1, @NonNull byte[] codes, boolean useHS, short negSamples, float alpha) {
        this();
        if (rowsA == null) {
            throw new NullPointerException("rowsA is marked non-null but is null");
        }
        if (rowsB == null) {
            throw new NullPointerException("rowsB is marked non-null but is null");
        }
        if (codes == null) {
            throw new NullPointerException("codes is marked non-null but is null");
        }
        this.rowsA = rowsA;
        this.rowsB = rowsB;
        this.taskId = taskId;
        this.w1 = w1;
        this.useHS = useHS;
        this.negSamples = negSamples;
        this.alpha = alpha;
        this.codes = codes;
    }

    /**
     * Starts the local training round and publishes this shard's dot products.
     *
     * <p>The result is a {@code [resultLength, 1]} array where
     * {@code resultLength = codes.length + (negSamples > 0 ? negSamples + 1 : 0)}; entry {@code e}
     * is {@code dot(mean(SYN_0[rowsA]), table[rowsB[e]])}, with {@code table} being {@code SYN_1}
     * for the first {@code codes.length} entries and {@code SYN_1_NEGATIVE} for the rest.
     * {@code rowsB} must therefore hold at least {@code resultLength} entries.
     *
     * <p>In {@code AVERAGING} mode the aggregation is handed to {@code transport.putMessage}
     * with a shard count of 1; in {@code SHARDED} mode it is handed to {@code transport.sendMessage}
     * with the configured number of shards. Other modes publish nothing.
     */
    @Override
    public void processMessage() {
        CbowRequestMessage cbrm = new CbowRequestMessage(this.rowsA, this.rowsB, this.w1, this.codes, this.negSamples, this.alpha, 119L);
        if (this.negSamples > 0) {
            // everything in rowsB after the code rows is a negative sample
            int[] negatives = Arrays.copyOfRange(this.rowsB, this.codes.length, this.rowsB.length);
            cbrm.setNegatives(negatives);
        }
        cbrm.setFrameId(-119L);
        cbrm.setTaskId(this.taskId);
        cbrm.setOriginatorId(this.getOriginatorId());

        CbowTrainer cbt = (CbowTrainer) this.trainer;
        cbt.pickTraining(cbrm);

        // context vector: mean over the SYN_0 rows selected by rowsA
        INDArray words = Nd4j.pullRows(this.storage.getArray(WordVectorStorage.SYN_0), 1, this.rowsA, 'c');
        INDArray mean = words.mean(0);

        int resultLength = this.codes.length + (this.negSamples > 0 ? this.negSamples + 1 : 0);
        INDArray result = Nd4j.createUninitialized(new int[]{resultLength, 1});

        int e = 0;
        // code rows are taken from SYN_1
        for (; e < this.codes.length; ++e) {
            double dot = Nd4j.getBlasWrapper().dot(mean, this.storage.getArray(WordVectorStorage.SYN_1).getRow(this.rowsB[e]));
            result.putScalar(e, dot);
        }
        // remaining rows (negative sampling) are taken from SYN_1_NEGATIVE
        for (; e < resultLength; ++e) {
            double dot = Nd4j.getBlasWrapper().dot(mean, this.storage.getArray(WordVectorStorage.SYN_1_NEGATIVE).getRow(this.rowsB[e]));
            result.putScalar(e, dot);
        }

        if (this.voidConfiguration.getExecutionMode() == ExecutionMode.AVERAGING) {
            // (short) cast is required: invocation contexts do not narrow int constants implicitly
            DotAggregation aggregation = new DotAggregation(this.taskId, (short) 1, this.shardIndex, result);
            aggregation.setTargetId((short) -1);
            aggregation.setOriginatorId(this.getOriginatorId());
            this.transport.putMessage(aggregation);
        } else if (this.voidConfiguration.getExecutionMode() == ExecutionMode.SHARDED) {
            DotAggregation aggregation = new DotAggregation(this.taskId, (short) this.voidConfiguration.getNumberOfShards(), this.shardIndex, result);
            aggregation.setTargetId((short) -1);
            aggregation.setOriginatorId(this.getOriginatorId());
            this.transport.sendMessage(aggregation);
        }
    }

    public int[] getRowsA() {
        return this.rowsA;
    }

    public int[] getRowsB() {
        return this.rowsB;
    }

    public int getW1() {
        return this.w1;
    }

    public boolean isUseHS() {
        return this.useHS;
    }

    public short getNegSamples() {
        return this.negSamples;
    }

    public float getAlpha() {
        return this.alpha;
    }

    public byte[] getCodes() {
        return this.codes;
    }

    public void setRowsA(int[] rowsA) {
        this.rowsA = rowsA;
    }

    public void setRowsB(int[] rowsB) {
        this.rowsB = rowsB;
    }

    public void setW1(int w1) {
        this.w1 = w1;
    }

    public void setUseHS(boolean useHS) {
        this.useHS = useHS;
    }

    public void setNegSamples(short negSamples) {
        this.negSamples = negSamples;
    }

    public void setAlpha(float alpha) {
        this.alpha = alpha;
    }

    public void setCodes(byte[] codes) {
        this.codes = codes;
    }

    /**
     * Compares the payload fields declared in this class (w1, useHS, negSamples, alpha and the
     * contents of rowsA, rowsB and codes). Fields inherited from the superclass, such as
     * {@code taskId}, are not compared.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof DistributedCbowDotMessage)) {
            return false;
        }
        DistributedCbowDotMessage other = (DistributedCbowDotMessage) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (this.getW1() != other.getW1()) {
            return false;
        }
        if (this.isUseHS() != other.isUseHS()) {
            return false;
        }
        if (this.getNegSamples() != other.getNegSamples()) {
            return false;
        }
        if (Float.compare(this.getAlpha(), other.getAlpha()) != 0) {
            return false;
        }
        if (!Arrays.equals(this.getRowsA(), other.getRowsA())) {
            return false;
        }
        if (!Arrays.equals(this.getRowsB(), other.getRowsB())) {
            return false;
        }
        return Arrays.equals(this.getCodes(), other.getCodes());
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof DistributedCbowDotMessage;
    }

    /** Hash over the same fields compared by {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        result = result * PRIME + this.getW1();
        result = result * PRIME + (this.isUseHS() ? 79 : 97);
        result = result * PRIME + this.getNegSamples();
        result = result * PRIME + Float.floatToIntBits(this.getAlpha());
        result = result * PRIME + Arrays.hashCode(this.getRowsA());
        result = result * PRIME + Arrays.hashCode(this.getRowsB());
        result = result * PRIME + Arrays.hashCode(this.getCodes());
        return result;
    }

    @Override
    public String toString() {
        return "DistributedCbowDotMessage(rowsA=" + Arrays.toString(this.getRowsA()) + ", rowsB=" + Arrays.toString(this.getRowsB()) + ", w1=" + this.getW1() + ", useHS=" + this.isUseHS() + ", negSamples=" + this.getNegSamples() + ", alpha=" + this.getAlpha() + ", codes=" + Arrays.toString(this.getCodes()) + ")";
    }
}

