/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.graph;

import java.util.Map;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.base.Preconditions;

/**
 * Graph vertex that applies dot-product attention to three inputs, declared in this order:
 * {@code "queries"}, {@code "keys"} and {@code "values"}.
 *
 * <p>When {@code projectInput} is {@code true} the vertex owns four learnable weights:
 * {@code Wq}, {@code Wk}, {@code Wv} with shape {@code [nHeads, headSize, nIn*]}, and {@code Wo}
 * with shape {@code [nHeads * headSize, nOut]}. It then delegates to SameDiff's multi-head
 * dot-product attention. When {@code projectInput} is {@code false} no parameters are created and
 * plain dot-product attention is applied to the raw inputs.
 *
 * <p>Instances are normally created through {@link Builder}, which validates the configuration.
 */
public class AttentionVertex extends SameDiffVertex {
    private long nInKeys = 0L;
    private long nInValues = 0L;
    private long nInQueries = 0L;
    private long nOut = 0L;
    private long headSize = 0L;
    private int nHeads = 1;
    private boolean projectInput;
    protected WeightInit weightInit;

    /** Parameter-table key of the query projection weights. */
    private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
    /** Parameter-table key of the key projection weights. */
    private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
    /** Parameter-table key of the value projection weights. */
    private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv";
    /** Parameter-table key of the output projection weights. */
    private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo";

    /** Creates a vertex from an already-validated builder (see {@link Builder#build()}). */
    protected AttentionVertex(Builder builder) {
        this.nInKeys = builder.nInKeys;
        this.nInValues = builder.nInValues;
        this.nInQueries = builder.nInQueries;
        this.nOut = builder.nOut;
        this.headSize = builder.headSize;
        this.projectInput = builder.projectInput;
        this.nHeads = builder.nHeads;
        this.weightInit = builder.weightInit;
    }

    /** Creates an unconfigured vertex; all sizes are zero, one head, no input projection. */
    public AttentionVertex() {
    }

    /**
     * Returns a copy of this vertex's attention configuration.
     *
     * <p>NOTE(review): only the fields declared in this class are copied. Any state held by
     * {@code SameDiffVertex} is not carried over; confirm that this is intended.
     */
    @Override
    public AttentionVertex clone() {
        AttentionVertex copy = new AttentionVertex();
        copy.nInKeys = this.nInKeys;
        copy.nInValues = this.nInValues;
        copy.nInQueries = this.nInQueries;
        copy.nOut = this.nOut;
        copy.headSize = this.headSize;
        copy.nHeads = this.nHeads;
        copy.projectInput = this.projectInput;
        copy.weightInit = this.weightInit;
        return copy;
    }

    /**
     * Computes the output type from the queries input (input 0).
     *
     * <p>The result is recurrent, with the queries' time-series length. Its size is {@code nOut}
     * when inputs are projected and {@code nInValues} otherwise.
     *
     * @throws InvalidInputTypeException if no inputs are given or the queries input is not recurrent
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
        if (vertexInputs == null || vertexInputs.length == 0) {
            throw new InvalidInputTypeException("AttentionVertex (index " + layerIndex
                    + "): expected queries, keys and values inputs but got none");
        }
        if (!(vertexInputs[0] instanceof InputType.InputTypeRecurrent)) {
            throw new InvalidInputTypeException("AttentionVertex (index " + layerIndex
                    + "): queries input must be recurrent, got " + vertexInputs[0]);
        }
        InputType.InputTypeRecurrent queries = (InputType.InputTypeRecurrent) vertexInputs[0];
        long outputSize = this.projectInput ? this.nOut : this.nInValues;
        return InputType.recurrent(outputSize, queries.getTimeSeriesLength());
    }

    /**
     * Declares the three named inputs and, when {@code projectInput} is set, the four projection
     * weights.
     */
    @Override
    public void defineParametersAndInputs(SDVertexParams params) {
        params.clear();
        params.defineInputs("queries", "keys", "values");
        if (this.projectInput) {
            params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, this.nHeads, this.headSize, this.nInQueries);
            params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, this.nHeads, this.headSize, this.nInKeys);
            params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, this.nHeads, this.headSize, this.nInValues);
            // The output projection maps the concatenated heads back to nOut.
            params.addWeightParam(WEIGHT_KEY_OUT_PROJECTION, (long) this.nHeads * this.headSize, this.nOut);
        }
    }

    /**
     * Initializes the projection weights in place using {@link #weightInit}.
     *
     * <p>Each weight uses its own fan-in and fan-out. Keys other than the four projection weights
     * are left untouched.
     */
    @Override
    public void initializeParameters(Map<String, INDArray> params) {
        // Runs inside scopeOutOfWorkspaces(), presumably so initialization does not allocate into a
        // caller's active workspace.
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            for (Map.Entry<String, INDArray> entry : params.entrySet()) {
                INDArray param = entry.getValue();
                switch (entry.getKey()) {
                    case WEIGHT_KEY_QUERY_PROJECTION:
                        initWeight(param, this.nInQueries, this.headSize);
                        break;
                    case WEIGHT_KEY_KEY_PROJECTION:
                        initWeight(param, this.nInKeys, this.headSize);
                        break;
                    case WEIGHT_KEY_VALUE_PROJECTION:
                        initWeight(param, this.nInValues, this.headSize);
                        break;
                    case WEIGHT_KEY_OUT_PROJECTION:
                        initWeight(param, (long) this.nHeads * this.headSize, this.nOut);
                        break;
                    default:
                        // Not a parameter owned by this vertex; leave it as-is.
                        break;
                }
            }
        }
    }

    /** Fills {@code param} in place (c order) using {@link #weightInit} with the given fan-in and fan-out. */
    private void initWeight(INDArray param, long fanIn, long fanOut) {
        WeightInitUtil.initWeights((double) fanIn, (double) fanOut, param.shape(), this.weightInit, null, 'c', param);
    }

    /**
     * Propagates masks so that the output follows the queries mask (input 0).
     *
     * @return a pair of the queries mask and {@code currentMaskState} when that mask is present;
     *     {@code null} when mask arrays are supplied but the queries mask is absent; a pair with a
     *     {@code null} mask when no mask arrays are supplied at all
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) {
        if (maskArrays == null) {
            return Pair.of(null, currentMaskState);
        }
        INDArray queriesMask = maskArrays[0];
        if (queriesMask == null) {
            // Queries are unmasked, so no mask needs to be passed on.
            return null;
        }
        return Pair.of(queriesMask, currentMaskState);
    }

    /**
     * Builds the attention graph.
     *
     * <p>With input projection this is multi-head dot-product attention using {@code Wq/Wk/Wv/Wo};
     * otherwise it is plain dot-product attention. When masks are provided, the key and value masks
     * are combined elementwise with {@code min} and passed to the attention op. The result is then
     * multiplied by the queries mask after it is expanded at dimension 1.
     *
     * <p>NOTE(review): the trailing {@code true} argument is presumably the "scaled" flag of the
     * SameDiff attention ops; confirm against the SDNN documentation.
     */
    @Override
    public SDVariable defineVertex(SameDiff sameDiff, Map<String, SDVariable> layerInput, Map<String, SDVariable> paramTable, Map<String, SDVariable> maskVars) {
        SDVariable queries = layerInput.get("queries");
        SDVariable keys = layerInput.get("keys");
        SDVariable values = layerInput.get("values");

        // For 0/1 masks, min keeps a position only if both keys and values are unmasked there.
        SDVariable mask = maskVars != null ? sameDiff.min(maskVars.get("keys"), maskVars.get("values")) : null;

        SDVariable attention;
        if (this.projectInput) {
            SDVariable wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
            SDVariable wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
            SDVariable wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
            SDVariable wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
            attention = sameDiff.nn.multiHeadDotProductAttention(this.getLayerName(), queries, keys, values, wq, wk, wv, wo, mask, true);
        } else {
            attention = sameDiff.nn.dotProductAttention(this.getLayerName(), queries, keys, values, mask, true);
        }

        if (maskVars != null) {
            // Zero out the output at masked query positions (assumes the time axis follows dim 1;
            // TODO confirm the output layout).
            return attention.mul(sameDiff.expandDims(maskVars.get("queries"), 1));
        }
        return attention;
    }

    public long getNInKeys() {
        return this.nInKeys;
    }

    public long getNInValues() {
        return this.nInValues;
    }

    public long getNInQueries() {
        return this.nInQueries;
    }

    public long getNOut() {
        return this.nOut;
    }

    public long getHeadSize() {
        return this.headSize;
    }

    public int getNHeads() {
        return this.nHeads;
    }

    public boolean isProjectInput() {
        return this.projectInput;
    }

    public WeightInit getWeightInit() {
        return this.weightInit;
    }

    public void setNInKeys(long nInKeys) {
        this.nInKeys = nInKeys;
    }

    public void setNInValues(long nInValues) {
        this.nInValues = nInValues;
    }

    public void setNInQueries(long nInQueries) {
        this.nInQueries = nInQueries;
    }

    public void setNOut(long nOut) {
        this.nOut = nOut;
    }

    public void setHeadSize(long headSize) {
        this.headSize = headSize;
    }

    public void setNHeads(int nHeads) {
        this.nHeads = nHeads;
    }

    public void setProjectInput(boolean projectInput) {
        this.projectInput = projectInput;
    }

    public void setWeightInit(WeightInit weightInit) {
        this.weightInit = weightInit;
    }

    /** Two vertices are equal when the superclass state and every attention setting match. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof AttentionVertex)) {
            return false;
        }
        AttentionVertex other = (AttentionVertex) o;
        if (!other.canEqual(this) || !super.equals(o)) {
            return false;
        }
        WeightInit thisInit = this.getWeightInit();
        WeightInit otherInit = other.getWeightInit();
        return this.getNInKeys() == other.getNInKeys()
                && this.getNInValues() == other.getNInValues()
                && this.getNInQueries() == other.getNInQueries()
                && this.getNOut() == other.getNOut()
                && this.getHeadSize() == other.getHeadSize()
                && this.getNHeads() == other.getNHeads()
                && this.isProjectInput() == other.isProjectInput()
                && (thisInit == null ? otherInit == null : thisInit.equals(otherInit));
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof AttentionVertex;
    }

    /** Hash consistent with {@link #equals(Object)}; combines all fields with multiplier 59. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        result = result * prime + Long.hashCode(this.getNInKeys());
        result = result * prime + Long.hashCode(this.getNInValues());
        result = result * prime + Long.hashCode(this.getNInQueries());
        result = result * prime + Long.hashCode(this.getNOut());
        result = result * prime + Long.hashCode(this.getHeadSize());
        result = result * prime + this.getNHeads();
        result = result * prime + (this.isProjectInput() ? 79 : 97);
        WeightInit init = this.getWeightInit();
        result = result * prime + (init == null ? 43 : init.hashCode());
        return result;
    }

    @Override
    public String toString() {
        return "AttentionVertex(nInKeys=" + this.getNInKeys() + ", nInValues=" + this.getNInValues() + ", nInQueries=" + this.getNInQueries() + ", nOut=" + this.getNOut() + ", headSize=" + this.getHeadSize() + ", nHeads=" + this.getNHeads() + ", projectInput=" + this.isProjectInput() + ", weightInit=" + this.getWeightInit() + ")";
    }

    /**
     * Builder for {@link AttentionVertex}.
     *
     * <p>{@code nOut}, {@code nInKeys}, {@code nInQueries} and {@code nInValues} must be set. An
     * unset head count defaults to 1, an unset weight init defaults to {@code XAVIER}, and an unset
     * {@code headSize} defaults to {@code nOut / nHeads}.
     */
    public static class Builder {
        private long nInKeys = 0L;
        private long nInValues = 0L;
        private long nInQueries = 0L;
        private long nOut = 0L;
        private long headSize = 0L;
        private int nHeads = 1;
        private boolean projectInput;
        protected WeightInit weightInit;

        public Builder nInKeys(long nInKeys) {
            this.nInKeys = nInKeys;
            return this;
        }

        public Builder nInQueries(long nInQueries) {
            this.nInQueries = nInQueries;
            return this;
        }

        public Builder nInValues(long nInValues) {
            this.nInValues = nInValues;
            return this;
        }

        public Builder headSize(long headSize) {
            this.headSize = headSize;
            return this;
        }

        public Builder nHeads(int nHeads) {
            this.nHeads = nHeads;
            return this;
        }

        public Builder nOut(long nOut) {
            this.nOut = nOut;
            return this;
        }

        public Builder weightInit(WeightInit weightInit) {
            this.weightInit = weightInit;
            return this;
        }

        public Builder projectInput(boolean projectInput) {
            this.projectInput = projectInput;
            return this;
        }

        /**
         * Applies defaults, validates the configuration and creates the vertex.
         *
         * <p>Note that the defaults (head count, weight init, derived head size) are written back
         * into this builder.
         *
         * @throws IllegalArgumentException if a required size is missing, {@code nHeads} is
         *     negative, {@code headSize} is negative, a head size cannot be derived, or
         *     {@code projectInput} is disabled for an incompatible configuration
         */
        public AttentionVertex build() {
            this.nHeads = this.nHeads == 0 ? 1 : this.nHeads;
            this.weightInit = this.weightInit == null ? WeightInit.XAVIER : this.weightInit;
            Preconditions.checkArgument(this.nOut > 0L, "You have to set nOut");
            Preconditions.checkArgument(this.nInKeys > 0L, "You have to set nInKeys");
            Preconditions.checkArgument(this.nInQueries > 0L, "You have to set nInQueries");
            Preconditions.checkArgument(this.nInValues > 0L, "You have to set nInValues");
            // A negative head count would otherwise pass the divisibility check below and yield a
            // negative derived headSize.
            Preconditions.checkArgument(this.nHeads > 0, "nHeads must be positive, got " + this.nHeads);
            Preconditions.checkArgument(this.headSize >= 0L, "headSize must not be negative, got " + this.headSize);
            Preconditions.checkArgument(this.headSize > 0L || this.nOut % (long) this.nHeads == 0L,
                    "You have to set a head size if nOut isn't cleanly divided by nHeads");
            Preconditions.checkArgument(this.projectInput
                            || (this.nInQueries == this.nInKeys && this.nInKeys == this.nInValues
                                && this.nInValues == this.nOut && this.nHeads == 1),
                    "You may only disable projectInput if all nIn* equal to nOut and you want to use only a single attention head");
            this.headSize = this.headSize == 0L ? this.nOut / (long) this.nHeads : this.headSize;
            return new AttentionVertex(this);
        }

        public long getNInKeys() {
            return this.nInKeys;
        }

        public long getNInValues() {
            return this.nInValues;
        }

        public long getNInQueries() {
            return this.nInQueries;
        }

        public long getNOut() {
            return this.nOut;
        }

        public long getHeadSize() {
            return this.headSize;
        }

        public int getNHeads() {
            return this.nHeads;
        }

        public boolean isProjectInput() {
            return this.projectInput;
        }

        public WeightInit getWeightInit() {
            return this.weightInit;
        }

        public void setNInKeys(long nInKeys) {
            this.nInKeys = nInKeys;
        }

        public void setNInValues(long nInValues) {
            this.nInValues = nInValues;
        }

        public void setNInQueries(long nInQueries) {
            this.nInQueries = nInQueries;
        }

        public void setNOut(long nOut) {
            this.nOut = nOut;
        }

        public void setHeadSize(long headSize) {
            this.headSize = headSize;
        }

        public void setNHeads(int nHeads) {
            this.nHeads = nHeads;
        }

        public void setProjectInput(boolean projectInput) {
            this.projectInput = projectInput;
        }

        public void setWeightInit(WeightInit weightInit) {
            this.weightInit = weightInit;
        }
    }
}

