Implement Multi-Head Self-Attention (from scratch)
Context
You are given an input tensor X with shape (batch_size, seq_len, d_model). Implement a multi-head self-attention layer (forward pass) using PyTorch or NumPy that:
- Projects inputs into queries (Q), keys (K), and values (V).
- Splits into h heads with per-head dimension d_k = d_model / h.
- Computes scaled dot-product attention with optional padding and causal masks.
- Concatenates heads and applies an output projection.
Assume d_model is divisible by h.
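As a point of reference, below is one way the forward pass could be structured in PyTorch. This is a minimal sketch under stated assumptions, not a required implementation: the class name MultiHeadSelfAttention, the fused qkv_proj projection, and the mask convention (truthy entries mark allowed positions) are illustrative choices not prescribed by the task.

```python
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadSelfAttention(nn.Module):
    """Minimal multi-head self-attention forward pass (illustrative sketch)."""

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)  # fused Q, K, V projection
        self.out_proj = nn.Linear(d_model, d_model)      # output projection

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, _ = x.shape                                # x: (B, T, d_model)
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)      # each: (B, T, d_model)

        # Split into heads: (B, T, d_model) -> (B, num_heads, T, d_k)
        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(B, T, self.num_heads, self.d_k).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # Scaled dot-product attention; scores: (B, num_heads, T, T)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # mask broadcasts to (B, num_heads, T, T); falsy entries are disallowed
            scores = scores.masked_fill(~mask.bool(), float("-inf"))
        attn = F.softmax(scores, dim=-1)                 # (B, num_heads, T, T)

        out = attn @ v                                   # (B, num_heads, T, d_k)
        # Merge heads: (B, num_heads, T, d_k) -> (B, T, d_model)
        out = out.transpose(1, 2).contiguous().view(B, T, self.num_heads * self.d_k)
        return self.out_proj(out)                        # (B, T, d_model)


# Quick shape check (values are random):
x = torch.randn(2, 10, 64)                               # (batch_size=2, seq_len=10, d_model=64)
layer = MultiHeadSelfAttention(d_model=64, num_heads=8)
y = layer(x)                                             # (2, 10, 64)
```

The same computation can also be written with separate W_Q, W_K, W_V linear layers; the fused projection is only a convenience and does not change the shapes involved.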
Requirements
- Implement the forward pass with correct tensor shapes and transpositions.
- Support optional masks (see the mask-building sketch after this list):
  - Padding mask (e.g., shape (batch_size, seq_len) or broadcastable variants).
  - Causal mask (prevent attending to future positions).
- Explain the shape of each intermediate tensor.
- Analyze time and memory complexity.
- Discuss numerical stability (e.g., scaling, masking, softmax stability, mixed precision).
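The two optional masks can be combined into a single boolean tensor that broadcasts against the (batch_size, num_heads, seq_len, seq_len) score tensor. The helper below is a sketch that assumes the mask convention used in the sketch above (True = attention allowed); the name build_attention_mask is an illustrative choice, not part of the task statement.

```python
from typing import Optional

import torch


def build_attention_mask(
    padding_mask: Optional[torch.Tensor],  # (batch_size, seq_len), 1 = real token, 0 = padding
    seq_len: int,
    causal: bool,
    device: Optional[torch.device] = None,
) -> Optional[torch.Tensor]:
    """Combine optional padding and causal masks into one boolean mask
    broadcastable to (batch_size, num_heads, seq_len, seq_len); True = allowed."""
    mask = None
    if padding_mask is not None:
        # Mask out padded key positions for every query: (B, T) -> (B, 1, 1, T)
        mask = padding_mask.bool()[:, None, None, :]
        device = padding_mask.device
    if causal:
        # Lower-triangular matrix: query i may only attend to keys j <= i.
        causal_mask = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
        )[None, None]                        # (1, 1, T, T), broadcasts over batch and heads
        mask = causal_mask if mask is None else mask & causal_mask
    return mask


# Example: batch of 2, seq_len of 4, second sequence has one padded trailing token.
padding = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
mask = build_attention_mask(padding, seq_len=4, causal=True)
print(mask.shape)                            # torch.Size([2, 1, 4, 4])
```

Two points worth noting when addressing the complexity and stability requirements: the score and attention tensors of shape (batch_size, num_heads, seq_len, seq_len) are what give attention its quadratic cost in sequence length, and while masking with float("-inf") before softmax is safe in float32, in mixed precision it is common to substitute a large finite negative value (e.g., torch.finfo(scores.dtype).min) and to ensure no query row is fully masked, since an all-masked row would produce NaN after softmax.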