Package org.deeplearning4j.nn.conf.graph
Class AttentionVertex
- java.lang.Object
  - org.deeplearning4j.nn.conf.graph.GraphVertex
    - org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex
      - org.deeplearning4j.nn.conf.graph.AttentionVertex
- All Implemented Interfaces:
Serializable, Cloneable, TrainingConfig
public class AttentionVertex extends SameDiffVertex
- See Also:
- Serialized Form
Nested Class Summary
- static class AttentionVertex.Builder
Field Summary
- protected WeightInit weightInit
Fields inherited from class org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex
biasUpdater, dataType, gradientNormalization, gradientNormalizationThreshold, regularization, regularizationBias, updater
Constructor Summary
- protected AttentionVertex(AttentionVertex.Builder builder)
Method Summary
- AttentionVertex clone()
- void defineParametersAndInputs(SDVertexParams params): Define the parameters (and inputs) for the network.
- SDVariable defineVertex(SameDiff sameDiff, Map<String,SDVariable> layerInput, Map<String,SDVariable> paramTable, Map<String,SDVariable> maskVars): Define the vertex.
- Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
- InputType getOutputType(int layerIndex, InputType... vertexInputs): Determine the type of output for this GraphVertex, given the specified inputs.
- void initializeParameters(Map<String,INDArray> params): Set the initial parameter values for this layer, if required.
Methods inherited from class org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex
applyGlobalConfig, applyGlobalConfigToLayer, getGradientNormalization, getGradientNormalizationThreshold, getLayerName, getMemoryReport, getRegularizationByParam, getUpdaterByParam, getVertexParams, instantiate, isPretrainParam, maxVertexInputs, minVertexInputs, numParams, paramReshapeOrder, setDataType, validateInput
Methods inherited from class org.deeplearning4j.nn.conf.graph.GraphVertex
equals, hashCode
Field Detail
weightInit
protected WeightInit weightInit
Constructor Detail
AttentionVertex
protected AttentionVertex(AttentionVertex.Builder builder)
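Because the constructor is protected, instances are normally obtained through AttentionVertex.Builder. The dependency-free sketch below illustrates that protected-constructor-plus-nested-Builder pattern on its own. The class AttentionConfigSketch, its fields (nOut, nHeads, projectInput) and the validation rule in build() are hypothetical illustrations, not the actual DL4J builder API.

```java
/**
 * Hypothetical, dependency-free illustration of the "protected constructor + static Builder"
 * pattern. This is NOT the DL4J AttentionVertex; field names and the validation rule are assumptions.
 */
class AttentionConfigSketch {
    private final int nOut;             // assumed: output feature size
    private final int nHeads;           // assumed: number of attention heads
    private final boolean projectInput; // assumed: whether inputs are projected

    // Only the Builder (or a subclass) can construct instances.
    protected AttentionConfigSketch(Builder builder) {
        this.nOut = builder.nOut;
        this.nHeads = builder.nHeads;
        this.projectInput = builder.projectInput;
    }

    int getNOut() { return nOut; }
    int getNHeads() { return nHeads; }
    boolean isProjectInput() { return projectInput; }

    static class Builder {
        private int nOut;
        private int nHeads = 1;
        private boolean projectInput;

        Builder nOut(int nOut) { this.nOut = nOut; return this; }
        Builder nHeads(int nHeads) { this.nHeads = nHeads; return this; }
        Builder projectInput(boolean projectInput) { this.projectInput = projectInput; return this; }

        AttentionConfigSketch build() {
            // Hypothetical cross-field check, run once before any instance exists.
            if (nHeads > 1 && !projectInput) {
                throw new IllegalStateException("multiple heads require projectInput(true) in this sketch");
            }
            return new AttentionConfigSketch(this);
        }
    }

    public static void main(String[] args) {
        AttentionConfigSketch cfg = new AttentionConfigSketch.Builder()
                .nOut(16).nHeads(4).projectInput(true).build();
        System.out.println(cfg.getNOut() + " " + cfg.getNHeads() + " " + cfg.isProjectInput()); // 16 4 true
    }
}
```

Keeping the constructor protected forces callers through build(), which is the single place where cross-field checks can run before an instance exists.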
Method Detail
clone
public AttentionVertex clone()
- Specified by:
clone in class GraphVertex
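"Specified by" indicates that clone() is abstract in GraphVertex, and AttentionVertex implements it with a covariant return type, so callers receive an AttentionVertex without a cast. A minimal plain-Java sketch of that pattern follows; BaseVertexSketch and AttentionVertexCloneSketch are hypothetical classes, and the deep copy of a mutable array is an illustrative choice, not a description of what DL4J actually copies.

```java
/** Hypothetical base class: clone() is abstract, so each subclass must supply its own copy logic. */
abstract class BaseVertexSketch {
    @Override
    public abstract BaseVertexSketch clone(); // widens Object.clone() to public, drops the checked exception
}

/** Hypothetical subclass overriding clone() with a covariant (narrower) return type. */
final class AttentionVertexCloneSketch extends BaseVertexSketch {
    private final double[] weights; // mutable state that clones must not share
    private final int nHeads;

    AttentionVertexCloneSketch(double[] weights, int nHeads) {
        this.weights = weights;
        this.nHeads = nHeads;
    }

    double[] getWeights() { return weights; }
    int getNHeads() { return nHeads; }

    @Override
    public AttentionVertexCloneSketch clone() {
        // Copy the array so that mutating the clone leaves the original untouched.
        return new AttentionVertexCloneSketch(weights.clone(), nHeads);
    }

    public static void main(String[] args) {
        AttentionVertexCloneSketch original = new AttentionVertexCloneSketch(new double[] {0.5, -0.5}, 2);
        AttentionVertexCloneSketch copy = original.clone(); // no cast needed
        copy.getWeights()[0] = 9.0;
        System.out.println(original.getWeights()[0]); // 0.5
    }
}
```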
getOutputType
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException
Description copied from class: GraphVertex
Determine the type of output for this GraphVertex, given the specified inputs. Because a GraphVertex may perform arbitrary processing or modification of its inputs, the output type can differ considerably from the input type(s).
This is generally used to determine when to add preprocessors, and to set the input sizes etc. for layers.
- Overrides:
getOutputType in class SameDiffVertex
- Parameters:
layerIndex - The index of the layer (if appropriate/necessary).
vertexInputs - The inputs to this vertex.
- Returns:
The type of output for this vertex.
- Throws:
InvalidInputTypeException - If the input type is invalid for this type of GraphVertex.
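For attention over recurrent inputs, the output type follows from the shapes of the queries, keys and values: there is one output step per query step. The sketch below assumes DL4J's [size, timeSeriesLength] recurrent convention, three inputs in the order queries, keys, values, and that a projected vertex emits nOut features. RecurrentType is a hypothetical stand-in for InputType, not the DL4J class.

```java
/** Hypothetical shape-inference sketch for dot-product attention over recurrent inputs. */
final class AttentionOutputTypeSketch {
    /** Minimal stand-in for a recurrent input type: feature size and sequence length. */
    static final class RecurrentType {
        final long size;
        final long timeSteps;

        RecurrentType(long size, long timeSteps) {
            this.size = size;
            this.timeSteps = timeSteps;
        }

        @Override
        public String toString() {
            return "Recurrent(size=" + size + ", timeSteps=" + timeSteps + ")";
        }
    }

    /**
     * queries [nQ, tQ], keys [nK, tK], values [nV, tV].
     * Each key pairs with one value, so tK must equal tV. Without projection, the query and key
     * feature sizes must match for the dot product. Output: tQ steps, nV (or nOut if projected) features.
     */
    static RecurrentType outputType(RecurrentType queries, RecurrentType keys, RecurrentType values,
                                    boolean projectInput, long nOut) {
        if (keys.timeSteps != values.timeSteps) {
            throw new IllegalArgumentException("keys and values must have the same number of time steps");
        }
        if (!projectInput && queries.size != keys.size) {
            throw new IllegalArgumentException("without projection, queries and keys need equal feature sizes");
        }
        long outSize = projectInput ? nOut : values.size;
        return new RecurrentType(outSize, queries.timeSteps);
    }

    public static void main(String[] args) {
        RecurrentType q = new RecurrentType(8, 5);
        RecurrentType k = new RecurrentType(8, 12);
        RecurrentType v = new RecurrentType(6, 12);
        System.out.println(outputType(q, k, v, false, 0)); // Recurrent(size=6, timeSteps=5)
        System.out.println(outputType(q, k, v, true, 16)); // Recurrent(size=16, timeSteps=5)
    }
}
```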
defineParametersAndInputs
public void defineParametersAndInputs(SDVertexParams params)
Description copied from class: SameDiffVertex
Define the parameters (and inputs) for the network. Use SDLayerParams.addWeightParam(String, long...) and SDLayerParams.addBiasParam(String, long...). Note also that you must define (and optionally name) the inputs to the vertex; this is required so that DL4J knows how many inputs exist for the vertex.
- Specified by:
defineParametersAndInputs in class SameDiffVertex
- Parameters:
params - Object used to set parameters for this layer
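defineParametersAndInputs both registers weight shapes and names the inputs, and the number of named inputs is how DL4J learns the vertex's input count. The sketch below mimics that contract with a hypothetical VertexParamsSketch class whose addWeightParam(String, long...) mirrors the signature quoted above. The input names ("queries", "keys", "values"), the parameter names Wq, Wk, Wv, Wo and their per-head shapes are assumptions about a projected multi-head layout, not taken from this page.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical stand-in for SDVertexParams: records named weight shapes and named inputs. */
final class VertexParamsSketch {
    private final Map<String, long[]> weightShapes = new LinkedHashMap<>();
    private final List<String> inputs = new ArrayList<>();

    void addWeightParam(String name, long... shape) { weightShapes.put(name, shape); }
    void defineInputs(String... names) { inputs.addAll(Arrays.asList(names)); }

    Map<String, long[]> weightShapes() { return weightShapes; }
    List<String> inputs() { return inputs; }

    /** Total number of scalar parameters across all registered weights. */
    long numParams() {
        long total = 0;
        for (long[] shape : weightShapes.values()) {
            long n = 1;
            for (long d : shape) n *= d;
            total += n;
        }
        return total;
    }

    /** Assumed layout: one projection per head for Q, K and V, plus an output projection. */
    static VertexParamsSketch defineAttention(long nInQueries, long nInKeys, long nInValues,
                                              long nHeads, long headSize, long nOut) {
        VertexParamsSketch p = new VertexParamsSketch();
        p.defineInputs("queries", "keys", "values"); // three named inputs, so the vertex takes three inputs
        p.addWeightParam("Wq", nHeads, headSize, nInQueries);
        p.addWeightParam("Wk", nHeads, headSize, nInKeys);
        p.addWeightParam("Wv", nHeads, headSize, nInValues);
        p.addWeightParam("Wo", nHeads * headSize, nOut);
        return p;
    }

    public static void main(String[] args) {
        VertexParamsSketch p = defineAttention(10, 10, 10, 2, 4, 8);
        // 3 * (2 * 4 * 10) + (8 * 8) = 304
        System.out.println(p.inputs() + " params=" + p.numParams()); // [queries, keys, values] params=304
    }
}
```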
initializeParameters
public void initializeParameters(Map<String,INDArray> params)
Description copied from class: SameDiffVertex
Set the initial parameter values for this layer, if required.
- Specified by:
initializeParameters in class SameDiffVertex
- Parameters:
params - Parameter arrays that may be initialized
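initializeParameters receives the parameter arrays and fills them in place, if required. The sketch below uses plain double[] arrays in place of INDArray and applies Xavier (Glorot) uniform initialization, U(-b, b) with b = sqrt(6 / (fanIn + fanOut)). The choice of Xavier and the fans map are assumptions for illustration; the scheme actually used depends on the vertex's weightInit field.

```java
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;

/** Hypothetical sketch of in-place parameter initialization (double[] stands in for INDArray). */
final class InitParamsSketch {
    /** Xavier/Glorot uniform bound: b = sqrt(6 / (fanIn + fanOut)). */
    static double xavierBound(long fanIn, long fanOut) {
        return Math.sqrt(6.0 / (fanIn + fanOut));
    }

    /** Overwrites each array in place, so the caller's references see the new values. */
    static void initializeParameters(Map<String, double[]> params, Map<String, long[]> fans, long seed) {
        Random rng = new Random(seed);
        for (Map.Entry<String, double[]> e : params.entrySet()) {
            long[] fan = fans.get(e.getKey());
            if (fan == null) continue; // parameters without fan information (e.g. biases) are left unchanged
            double bound = xavierBound(fan[0], fan[1]);
            double[] arr = e.getValue();
            for (int i = 0; i < arr.length; i++) {
                arr[i] = (rng.nextDouble() * 2.0 - 1.0) * bound; // uniform in [-bound, bound)
            }
        }
    }

    public static void main(String[] args) {
        double[] wq = new double[4 * 10];
        Map<String, double[]> params = new LinkedHashMap<>();
        params.put("Wq", wq);
        Map<String, long[]> fans = new HashMap<>();
        fans.put("Wq", new long[] {10, 4});
        initializeParameters(params, fans, 42L);
        System.out.println("bound=" + xavierBound(10, 4) + ", wq[0]=" + wq[0]);
    }
}
```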
feedForwardMaskArrays
public Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
- Overrides:
feedForwardMaskArrays in class SameDiffVertex
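No description is given for feedForwardMaskArrays. A plausible behaviour for attention, sketched below under stated assumptions, is that the output mask equals the query mask, because each output step corresponds to one query step, while key masks only affect the attention weights inside the forward pass. The mask layout ([minibatch, timeSteps], 1 for real steps and 0 for padding), the input order (queries first), and the local MaskState enum are assumptions; Map.Entry stands in for DL4J's Pair.

```java
import java.util.AbstractMap;
import java.util.Map;

/** Hypothetical mask-propagation sketch; not the DL4J implementation. */
final class MaskPropagationSketch {
    /** Local stand-in for DL4J's MaskState. */
    enum MaskState { Active, Passthrough }

    /**
     * maskArrays[0] is assumed to be the query mask, shape [minibatch][tQ].
     * Since output step t corresponds to query step t, the output mask is taken to be the query mask.
     * Key/value masks are not propagated; they only zero out attention weights in the forward pass.
     */
    static Map.Entry<double[][], MaskState> feedForwardMaskArrays(double[][][] maskArrays,
                                                                   MaskState currentMaskState,
                                                                   int minibatchSize) {
        if (maskArrays == null || maskArrays[0] == null) {
            return new AbstractMap.SimpleEntry<>(null, currentMaskState); // queries unmasked: no output mask
        }
        double[][] queryMask = maskArrays[0];
        if (queryMask.length != minibatchSize) {
            throw new IllegalArgumentException("query mask has " + queryMask.length
                    + " rows, expected minibatch size " + minibatchSize);
        }
        return new AbstractMap.SimpleEntry<>(queryMask, currentMaskState);
    }

    public static void main(String[] args) {
        double[][] queryMask = {{1, 1, 0}, {1, 0, 0}};
        Map.Entry<double[][], MaskState> result =
                feedForwardMaskArrays(new double[][][] {queryMask, null, null}, MaskState.Active, 2);
        System.out.println(result.getKey() == queryMask); // true
    }
}
```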
defineVertex
public SDVariable defineVertex(SameDiff sameDiff, Map<String,SDVariable> layerInput, Map<String,SDVariable> paramTable, Map<String,SDVariable> maskVars)
Description copied from class: SameDiffVertex
Define the vertex.
- Specified by:
defineVertex in class SameDiffVertex
- Parameters:
sameDiff - SameDiff instance
layerInput - Input to the layer; keys as defined by SameDiffVertex.defineParametersAndInputs(SDVertexParams)
paramTable - Parameter table; keys as defined by SameDiffVertex.defineParametersAndInputs(SDVertexParams)
maskVars - Masks of input, if available; keys as defined by SameDiffVertex.defineParametersAndInputs(SDVertexParams)
- Returns:
The final layer variable corresponding to the activations/output from the forward pass
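The computation performed by defineVertex is not spelled out here. Assuming AttentionVertex implements scaled dot-product attention (an assumption in this sketch), the forward pass for one example and one head can be written in plain Java: scores are scaled dot products between each query and every key, masked keys are excluded, a softmax turns scores into weights, and each output step is the weighted sum of the value vectors. Arrays use a [features, timeSteps] layout; the class and method names are hypothetical, multi-head projection is omitted, and the real vertex builds the equivalent computation as a SameDiff graph instead.

```java
/**
 * Plain-Java sketch of masked, scaled dot-product attention for one example and one head.
 * Layout: queries [d][tQ], keys [d][tK], values [dV][tK]; output [dV][tQ].
 */
final class DotProductAttentionSketch {
    static double[][] attend(double[][] queries, double[][] keys, double[][] values, double[] keyMask) {
        int d = queries.length;
        if (keys.length != d) throw new IllegalArgumentException("queries and keys need the same feature size");
        int tQ = queries[0].length;
        int tK = keys[0].length;
        int dV = values.length;
        double scale = 1.0 / Math.sqrt(d); // assumed scaling by 1/sqrt(d)
        double[][] out = new double[dV][tQ];

        for (int q = 0; q < tQ; q++) {
            // 1. Scores: scaled dot product of query q with every unmasked key.
            double[] scores = new double[tK];
            double max = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < tK; k++) {
                if (keyMask != null && keyMask[k] == 0.0) {
                    scores[k] = Double.NEGATIVE_INFINITY; // masked key: receives zero weight
                    continue;
                }
                double dot = 0.0;
                for (int f = 0; f < d; f++) dot += queries[f][q] * keys[f][k];
                scores[k] = dot * scale;
                max = Math.max(max, scores[k]);
            }
            // 2. Softmax over keys, subtracting the max for numerical stability.
            double[] weights = new double[tK];
            double sum = 0.0;
            for (int k = 0; k < tK; k++) {
                weights[k] = scores[k] == Double.NEGATIVE_INFINITY ? 0.0 : Math.exp(scores[k] - max);
                sum += weights[k];
            }
            if (sum == 0.0) continue; // every key masked: leave this output column at zero
            // 3. Output column q: attention-weighted sum of the value columns.
            for (int k = 0; k < tK; k++) {
                double w = weights[k] / sum;
                for (int f = 0; f < dV; f++) out[f][q] += w * values[f][k];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] queries = {{0.0}};
        double[][] keys = {{1.0, 2.0}};
        double[][] values = {{2.0, 4.0}};
        System.out.println(attend(queries, keys, values, null)[0][0]);                   // 3.0
        System.out.println(attend(queries, keys, values, new double[] {1.0, 0.0})[0][0]); // 2.0
    }
}
```

With a zero query every unmasked key scores equally, so the unmasked output is the plain average of the values (3.0), while masking the second key leaves only the first value (2.0).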