models.attention

class AttentionTrain(keras.src.layers.layer.Layer):

Custom attention layer for training Transformer models.

This layer implements the attention mechanism used in a Transformer model during training. It computes attention scores between query, key, and value tensors, applies a triangular (causal) mask so that positions cannot attend to future positions, and applies dropout.

Attributes:

  • num_heads (int): Number of attention heads.
  • head_dims (int): Dimensionality of each attention head.
  • k (keras.layers.Dense): Dense layer for computing the keys.
  • q (keras.layers.Dense): Dense layer for computing the queries.
  • v (keras.layers.Dense): Dense layer for computing the values.
  • out (keras.layers.Dense): Dense layer for the output.
  • q_norm (float): Normalization factor for queries.
  • mask (tf.Tensor): Triangular mask tensor.
  • dropout (keras.layers.Dropout): Dropout layer.

Formula:

  • $\text{Attention}(Q, K, V)_{\text{head}} = \text{softmax}\left(\dfrac{QK^T}{\sqrt{d_k}}\right)V$ for each head
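As a sketch, the per-head formula above can be written in plain NumPy (the function and variable names here are illustrative, not the layer's actual internals):

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v):
    # q, k, v: (..., seq_len, d_k); scores are scaled by sqrt(d_k)
    d_k = q.shape[-1]
    scores = q @ k.swapaxes(-2, -1) / np.sqrt(d_k)
    return softmax(scores) @ v
```

With identical constant inputs the attention weights are uniform, so the output is simply the average of the values.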

References:

  • Vaswani, Ashish, et al. "Attention is all you need." Advances in Neural Information Processing Systems 30 (2017).

Example:

>>> attn = AttentionTrain(32, 40)
>>> print(attn(keras.ops.ones((1, 1, 1280))))

AttentionTrain(num_heads, head_dims, dropout=0.2, input_len=64)

Initializes the AttentionTrain layer.

Args:

  • num_heads (int): Number of attention heads.
  • head_dims (int): Dimensionality of each attention head.
  • dropout (float): Dropout rate. Default is 0.2.
  • input_len (int): Length of the input sequence. Default is 64.
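Assuming the q/k/v projections span all heads at once (consistent with the example, where 32 heads × 40 dims matches the 1280-wide input), the constructor arguments relate to the layer's width and scaling like this:

```python
num_heads, head_dims = 32, 40

# Assumed relationship: one Dense projection covers all heads.
model_dims = num_heads * head_dims   # 1280, matching the example input width
q_norm = 1.0 / head_dims ** 0.5      # the 1/sqrt(d_k) scaling from the formula
```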
def generate_mask(self, num_words):

Generates a triangular mask to be applied to attention scores to prevent attending to future positions.

Args:

  • num_words (int): Number of words in the sequence.

Returns:

  • tf.Tensor: Triangular mask tensor.
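A minimal NumPy sketch of such a triangular mask, assuming the usual additive-mask convention (blocked positions get a large negative value before the softmax; the function name is illustrative):

```python
import numpy as np

def generate_causal_mask(num_words):
    # Lower-triangular: position i may attend to positions j <= i.
    allowed = np.tril(np.ones((num_words, num_words), dtype=bool))
    # Additive mask: 0 where allowed, a large negative value otherwise,
    # so the softmax drives blocked scores to ~0.
    return np.where(allowed, 0.0, -1e9)
```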
def call(self, inputs):

Executes the forward pass of the AttentionTrain layer.

Args:

  • inputs: Input tensor.

Returns:

  • keras.Tensor: Output tensor.
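Putting the pieces together, the forward pass can be sketched in NumPy under the same assumptions (per-head projections, causal mask, scaled softmax; dropout omitted, and all names are illustrative rather than the layer's actual internals):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention_forward(x, wq, wk, wv, wo, num_heads):
    # x: (batch, seq, model_dims); each weight matrix: (model_dims, model_dims)
    batch, seq, dims = x.shape
    head_dims = dims // num_heads

    def split_heads(t):
        # (batch, seq, dims) -> (batch, heads, seq, head_dims)
        return t.reshape(batch, seq, num_heads, head_dims).transpose(0, 2, 1, 3)

    q, k, v = (split_heads(x @ w) for w in (wq, wk, wv))
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dims)
    # Causal mask: block attention to future positions.
    scores = np.where(np.tril(np.ones((seq, seq), dtype=bool)), scores, -1e9)
    out = softmax(scores) @ v
    # Merge heads back and apply the output projection.
    out = out.transpose(0, 2, 1, 3).reshape(batch, seq, dims)
    return out @ wo
```

A useful property to check: because of the causal mask, perturbing a later token must leave the outputs at earlier positions unchanged.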
Inherited Members

keras.src.layers.layer.Layer:
get_build_config, build_from_config, add_variable, add_weight, trainable, variables, trainable_variables, non_trainable_variables, weights, trainable_weights, non_trainable_weights, metrics, metrics_variables, get_weights, set_weights, dtype, compute_dtype, variable_dtype, input_dtype, supports_masking, stateless_call, add_loss, losses, save_own_variables, load_own_variables, count_params, get_config

keras.src.ops.operation.Operation:
from_config, input, output