Implement Scaled Dot-Product Attention in PyTorch (from scratch)
Context
You will implement a numerically stable, vectorized scaled dot-product attention module in PyTorch and validate it with unit tests. Assume the queries and the keys/values may have different sequence lengths (Lq vs. Lk) and that the inputs may optionally carry multiple heads.
Requirements
Implementation
- Create a PyTorch module that computes scaled dot-product attention:
  - Inputs: Q ∈ R^{B×H×Lq×d_k}, K ∈ R^{B×H×Lk×d_k}, V ∈ R^{B×H×Lk×d_v}. Allow H=1 when working without heads; also accept 3D inputs by treating them as H=1.
  - Compute scores S = Q K^T and scale them by 1/sqrt(d_k) (unless a custom scale is provided).
  - Support masking:
    - Causal mask: mask out the strictly upper-triangular part of the Lq×Lk score matrix so a query cannot attend to later keys.
    - Padding/attention masks broadcastable to B×H×Lq×Lk. Boolean masks should treat True as "masked out"; additive masks should be added to the scores (e.g., 0 to keep a position, -inf to disallow it).
  - Apply a numerically stable softmax to get the attention weights; apply optional dropout on the weights.
  - Return both the attention output (B×H×Lq×d_v) and the attention weights (B×H×Lq×Lk).
- Ensure fp16/bf16 inputs are handled by upcasting to float32 for the softmax and downcasting afterward.
- Ensure rows that are fully masked produce zero attention weights and zero outputs (no NaNs).
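
A minimal sketch of one way to satisfy these requirements. The class name, the argument names (attn_mask, is_causal, scale, dropout), and the mask conventions shown here are illustrative choices, not a prescribed API:

```python
import math

import torch
import torch.nn as nn


class ScaledDotProductAttention(nn.Module):
    """Sketch: scaled dot-product attention with masking and dtype handling."""

    def __init__(self, dropout: float = 0.0):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, attn_mask=None, is_causal=False, scale=None):
        # Accept 3D inputs (B, L, d) by treating them as a single head.
        added_head = q.dim() == 3
        if added_head:
            q, k, v = (t.unsqueeze(1) for t in (q, k, v))

        d_k = q.size(-1)
        scale = 1.0 / math.sqrt(d_k) if scale is None else scale

        # Scores (B, H, Lq, Lk), computed in float32 so fp16/bf16 inputs do not
        # overflow and the softmax is numerically stable.
        scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * scale

        if is_causal:
            lq, lk = scores.shape[-2:]
            future = torch.ones(lq, lk, dtype=torch.bool, device=scores.device).triu(1)
            scores = scores.masked_fill(future, float("-inf"))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Boolean convention: True means "masked out".
                scores = scores.masked_fill(attn_mask, float("-inf"))
            else:
                # Additive convention: 0 keeps a position, -inf disallows it.
                scores = scores + attn_mask.float()

        # Softmax over the keys; rows that are entirely -inf would become NaN,
        # so force their weights (and hence their outputs) to zero instead.
        weights = torch.softmax(scores, dim=-1)
        fully_masked = torch.isneginf(scores).all(dim=-1, keepdim=True)
        weights = weights.masked_fill(fully_masked, 0.0)
        weights = self.dropout(weights)

        # Downcast back after the softmax for the weighted sum and the outputs.
        out = torch.matmul(weights.to(v.dtype), v)
        weights = weights.to(q.dtype)

        if added_head:
            out, weights = out.squeeze(1), weights.squeeze(1)
        return out, weights
```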
Tests
- Shape correctness: multiple heads, different Lq and Lk, different d_k and d_v.
- Gradient flow: backprop from a scalar loss; gradients are finite and nonzero for Q, K, V.
- Numerical stability: very large-magnitude Q/K should not produce NaNs/Infs; rows that are fully masked should produce zero weights and zero output.
- Masking correctness: padding mask and causal mask both behave as expected.
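
A few illustrative pytest-style checks for these requirements, assuming the ScaledDotProductAttention sketch above is defined in the same file; the shapes and magnitudes are arbitrary:

```python
import torch


def test_shapes_and_gradients():
    attn = ScaledDotProductAttention()
    q = torch.randn(2, 4, 5, 8, requires_grad=True)   # B=2, H=4, Lq=5, d_k=8
    k = torch.randn(2, 4, 7, 8, requires_grad=True)   # Lk=7
    v = torch.randn(2, 4, 7, 16, requires_grad=True)  # d_v=16
    out, w = attn(q, k, v)
    assert out.shape == (2, 4, 5, 16) and w.shape == (2, 4, 5, 7)
    out.sum().backward()
    for t in (q, k, v):
        assert torch.isfinite(t.grad).all() and t.grad.abs().sum() > 0


def test_stability_and_fully_masked_rows():
    attn = ScaledDotProductAttention()
    q, k = torch.randn(1, 1, 3, 4) * 1e4, torch.randn(1, 1, 3, 4) * 1e4
    v = torch.randn(1, 1, 3, 4)
    mask = torch.zeros(1, 1, 3, 3, dtype=torch.bool)
    mask[..., 0, :] = True  # the first query attends to nothing
    out, w = attn(q, k, v, attn_mask=mask)
    assert torch.isfinite(out).all() and torch.isfinite(w).all()
    assert w[..., 0, :].abs().sum() == 0 and out[..., 0, :].abs().sum() == 0


def test_causal_mask_blocks_future_keys():
    attn = ScaledDotProductAttention()
    q = k = v = torch.randn(1, 1, 4, 4)
    _, w = attn(q, k, v, is_causal=True)
    assert (w.triu(1) == 0).all()  # no weight on future keys
```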
Explanations
- Explain why scaling by 1/sqrt(d_k) is used (variance control, keeping the softmax out of saturation), with a short derivation or intuition (see the short numerical check after this list).
- Clarify when and why you apply normalization:
  - Softmax for the attention weights (across the keys, for each query).
  - LayerNorm (or RMSNorm) around the attention block for training stability, not as a replacement for softmax.
- Analyze time and memory complexity.
- Propose optimizations for long sequences.
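
As a quick empirical companion to the 1/sqrt(d_k) item above: for queries and keys with i.i.d. zero-mean, unit-variance components, the dot product q·k has variance d_k, so unscaled scores have standard deviation about sqrt(d_k) and push the softmax toward saturation as d_k grows. The snippet below (sample sizes and dimensions chosen only for illustration) shows the scaled scores staying near unit standard deviation:

```python
import torch

torch.manual_seed(0)
for d_k in (16, 64, 256, 1024):
    q = torch.randn(100_000, d_k)
    k = torch.randn(100_000, d_k)
    scores = (q * k).sum(dim=-1)  # one dot product per (q, k) pair
    raw_std = scores.std().item()                    # grows roughly like sqrt(d_k)
    scaled_std = (scores / d_k ** 0.5).std().item()  # stays near 1
    print(f"d_k={d_k:5d}  std(raw)={raw_std:7.2f}  std(scaled)={scaled_std:.3f}")
```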