models.attention
class AttentionTrain(keras.src.layers.layer.Layer):
Custom attention layer for training Transformer models.
This layer implements the masked (causal) multi-head attention mechanism used in the Transformer model during training. It computes attention scores between the query and key tensors, uses them to weight the value tensors, applies a triangular mask so that no position can attend to future positions, and applies dropout.
Attributes:
- num_heads (int): Number of attention heads.
- head_dims (int): Dimensionality of each attention head.
- k (keras.layers.Dense): Dense layer for computing the keys.
- q (keras.layers.Dense): Dense layer for computing the queries.
- v (keras.layers.Dense): Dense layer for computing the values.
- out (keras.layers.Dense): Dense layer for the output.
- q_norm (float): Normalization factor for queries.
- mask (tf.Tensor): Triangular mask tensor.
- dropout (keras.layers.Dropout): Dropout layer.
Formula:
- $\text{Attention}(Q, K, V)_{\text{head}} = \text{softmax}\left(\dfrac{QK^T}{\sqrt{d_k}}\right) V$ for each head
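The per-head formula above can be checked numerically with a small NumPy sketch. This is a single-head illustration of scaled dot-product attention only (no mask, no dropout, no learned projections); the helper names `softmax` and `scaled_dot_product_attention` are hypothetical and are not part of this module.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row maximum for numerical stability; the result is unchanged.
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def scaled_dot_product_attention(q, k, v):
    """Compute softmax(Q K^T / sqrt(d_k)) V for a single head.

    q, k: arrays of shape (seq_len, d_k); v: array of shape (seq_len, d_v).
    Returns the attended values and the attention weights.
    """
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)     # (seq_len, seq_len) similarity scores
    weights = softmax(scores, axis=-1)  # each row is a probability distribution
    return weights @ v, weights


rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))
k = rng.standard_normal((4, 8))
v = rng.standard_normal((4, 8))
out, w = scaled_dot_product_attention(q, k, v)
print(out.shape)  # (4, 8)
```

Dividing by $\sqrt{d_k}$ keeps the dot products from growing with the head dimension, which would otherwise push the softmax toward one-hot outputs.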
References:
- Vaswani, Ashish, et al. "Attention is all you need." Advances in Neural Information Processing Systems 30 (2017).
Example:
>>> attn = AttentionTrain(32, 40)
>>> print(attn(keras.ops.ones((1, 1, 1280))))
AttentionTrain(num_heads, head_dims, dropout=0.2, input_len=64)
Initializes the AttentionTrain layer.
Args:
- num_heads (int): Number of attention heads.
- head_dims (int): Dimensionality of each attention head.
- dropout (float): Dropout rate. Default is 0.2.
- input_len (int): Length of the input sequence. Default is 64.
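The docstring example feeds a last dimension of 1280 = 32 × 40 = `num_heads * head_dims`, which suggests that the projected queries, keys, and values have width `num_heads * head_dims` and are split into heads. The NumPy sketch below shows that split and the inverse merge under this assumption; `split_heads` and `merge_heads` are hypothetical helpers, and treating `q_norm` as $1/\sqrt{\text{head\_dims}}$ is also an assumption drawn from the formula rather than from the source.

```python
import numpy as np


def split_heads(x, num_heads, head_dims):
    # (batch, seq_len, num_heads * head_dims) -> (batch, num_heads, seq_len, head_dims)
    batch, seq_len, _ = x.shape
    x = x.reshape(batch, seq_len, num_heads, head_dims)
    return x.transpose(0, 2, 1, 3)


def merge_heads(x):
    # (batch, num_heads, seq_len, head_dims) -> (batch, seq_len, num_heads * head_dims)
    batch, num_heads, seq_len, head_dims = x.shape
    return x.transpose(0, 2, 1, 3).reshape(batch, seq_len, num_heads * head_dims)


num_heads, head_dims = 32, 40
x = np.ones((1, 1, num_heads * head_dims))  # same shape as the docstring example
h = split_heads(x, num_heads, head_dims)
print(h.shape)               # (1, 32, 1, 40)
print(merge_heads(h).shape)  # (1, 1, 1280)

# Assumed value of q_norm: the 1/sqrt(d_k) factor from the attention formula.
q_norm = 1.0 / np.sqrt(head_dims)
```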
def generate_mask(self, num_words):
Generates a triangular mask to be applied to attention scores to prevent attending to future positions.
Args:
- num_words (int): Number of words in the sequence.
Returns:
- tf.Tensor: Triangular mask tensor.
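A causal mask of this kind is commonly a lower-triangular matrix: row `i` keeps columns `j <= i` and blocks later ones. The NumPy sketch below illustrates that idea and one common way to apply it (replacing blocked scores with a large negative value before the softmax). It is not this module's code: the real `generate_mask` returns a `tf.Tensor`, and whether it uses 1 for "keep" and how it is applied are assumptions. `apply_mask` is a hypothetical helper.

```python
import numpy as np


def generate_mask(num_words):
    # Lower-triangular matrix: entry (i, j) is 1 when j <= i, so position i may
    # attend to itself and earlier positions, and 0 for future positions.
    return np.tril(np.ones((num_words, num_words), dtype=np.float32))


def apply_mask(scores, mask):
    # Replace blocked scores with a large negative number so that the softmax
    # assigns them (numerically) zero weight.
    return np.where(mask == 1, scores, -1e9)


mask = generate_mask(4)
print(mask)
# A mask built once for input_len positions can be sliced for shorter inputs,
# because the top-left block of a lower-triangular matrix is lower-triangular.
full = generate_mask(64)
assert np.array_equal(full[:4, :4], mask)
```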
def call(self, inputs):
Executes the forward pass of the AttentionTrain layer.
Args:
- inputs: Input tensor.
Returns:
- keras.Tensor: Output tensor.
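For readers who want to trace the whole forward pass, here is a NumPy sketch of causal multi-head attention: project to queries, keys, and values, split into heads, scale, mask, softmax, recombine, and project out. Plain weight matrices stand in for the `q`, `k`, `v`, and `out` Dense layers (biases omitted), dropout is left out, and the function name `causal_multi_head_attention` is hypothetical; the layer's actual `call` may order or implement these steps differently.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def causal_multi_head_attention(inputs, w_q, w_k, w_v, w_out, num_heads, head_dims):
    """Causal multi-head attention without dropout or biases.

    inputs: (batch, seq_len, d_model); w_q, w_k, w_v: (d_model, num_heads * head_dims);
    w_out: (num_heads * head_dims, d_model).
    """
    batch, seq_len, _ = inputs.shape

    def split(x):
        # (batch, seq_len, H * D) -> (batch, H, seq_len, D)
        return x.reshape(batch, seq_len, num_heads, head_dims).transpose(0, 2, 1, 3)

    q = split(inputs @ w_q) / np.sqrt(head_dims)  # scale queries by 1/sqrt(d_k)
    k = split(inputs @ w_k)
    v = split(inputs @ w_v)

    scores = q @ k.transpose(0, 1, 3, 2)          # (batch, H, seq_len, seq_len)
    mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    scores = np.where(mask, scores, -1e9)         # block future positions
    weights = softmax(scores, axis=-1)
    # During training, dropout would be applied to `weights` here.

    context = weights @ v                         # (batch, H, seq_len, D)
    context = context.transpose(0, 2, 1, 3).reshape(batch, seq_len, num_heads * head_dims)
    return context @ w_out, weights


rng = np.random.default_rng(0)
num_heads, head_dims, d_model = 4, 8, 32
x = rng.standard_normal((2, 5, d_model))
w_q, w_k, w_v = (rng.standard_normal((d_model, num_heads * head_dims)) for _ in range(3))
w_out = rng.standard_normal((num_heads * head_dims, d_model))
y, attn = causal_multi_head_attention(x, w_q, w_k, w_v, w_out, num_heads, head_dims)
print(y.shape)  # (2, 5, 32)
```

Scaling the queries before the dot product, as done here, is equivalent to scaling the scores afterwards.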
Inherited Members
- keras.src.layers.layer.Layer
  - get_build_config
  - build_from_config
  - add_variable
  - add_weight
  - trainable
  - variables
  - trainable_variables
  - non_trainable_variables
  - weights
  - trainable_weights
  - non_trainable_weights
  - metrics
  - metrics_variables
  - get_weights
  - set_weights
  - dtype
  - compute_dtype
  - variable_dtype
  - input_dtype
  - supports_masking
  - stateless_call
  - add_loss
  - losses
  - save_own_variables
  - load_own_variables
  - count_params
  - get_config
- keras.src.ops.operation.Operation
  - from_config
  - input
  - output