This question evaluates understanding of multi-head self-attention and the ability to implement a transformer attention module using learned Q/K/V projections, head-wise tensor reshaping, attention masking, and PyTorch tensor operations.
Implement a multi-head self-attention module in PyTorch without using torch.nn.MultiheadAttention.
Requirements:
- Input shape: (batch_size, seq_len, d_model)
- Split the projections into num_heads heads, where d_model % num_heads == 0 and head_dim = d_model // num_heads
- Compute scaled dot-product attention per head: Attention(Q, K, V) = softmax(Q K^T / sqrt(head_dim)) V
- Support an optional attention mask applied to the scores before the softmax
- Output shape: (batch_size, seq_len, d_model)
You may use standard PyTorch building blocks such as nn.Linear, torch.matmul, softmax, view, and transpose.
Explain any important tensor shape transformations and common implementation pitfalls.