Implement Multi‑Head Attention and Nucleus (Top‑p) Sampling
Context
You are building core components used in Transformer-based language models. Implement multi-head attention (MHA) and nucleus (top-p) sampling from scratch. Assume a deep learning framework (e.g., PyTorch) is available. Handle tensor shapes, masking, and numerical stability explicitly.
Tasks
- Multi‑Head Attention Implementation (a reference sketch follows this task)
  - Inputs:
    - Query Q, Key K, Value V: tensors of shape (batch_size, seq_len, d_model)
    - Number of heads h (assume d_model % h == 0)
    - Optional attention mask (broadcastable to the attention scores shape)
  - Requirements:
    - Linear projections to head dimensions d_k = d_v = d_model / h
    - Scaled dot‑product attention with softmax
    - Support masking (e.g., padding or causal masks)
    - Apply dropout to the attention weights
    - Concatenate the heads and apply a final output projection
    - Output shape: (batch_size, seq_len, d_model)

  Follow‑up A: Why is the dot‑product attention scaled by sqrt(d_k)?
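To make the expected shapes, masking behavior, and dropout placement concrete, here is a minimal PyTorch sketch of the task above. The class name `MultiHeadAttention`, the dropout default of 0.1, and the mask convention (nonzero entries mean "attend", zeros are blocked) are illustrative assumptions, not part of the task statement.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-head attention over inputs of shape (batch, seq_len, d_model)."""

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads          # per-head dimension (d_k == d_v)
        self.num_heads = num_heads
        self.w_q = nn.Linear(d_model, d_model)   # query projection
        self.w_k = nn.Linear(d_model, d_model)   # key projection
        self.w_v = nn.Linear(d_model, d_model)   # value projection
        self.w_o = nn.Linear(d_model, d_model)   # final output projection
        self.dropout = nn.Dropout(dropout)       # applied to the attention weights

    def forward(self, q, k, v, mask=None):
        # q, k, v: (batch, seq_len, d_model).
        # mask (assumed convention): broadcastable to (batch, heads, q_len, k_len),
        # nonzero = attend, zero = block.
        batch_size = q.size(0)

        def split_heads(x, proj):
            # (batch, seq, d_model) -> (batch, heads, seq, d_k)
            x = proj(x)
            return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        q = split_heads(q, self.w_q)
        k = split_heads(k, self.w_k)
        v = split_heads(v, self.w_v)

        # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # A large negative value (rather than -inf) keeps softmax finite even
            # if a row happens to be fully masked.
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = self.dropout(F.softmax(scores, dim=-1))

        # Concatenate heads: (batch, heads, seq, d_k) -> (batch, seq, d_model)
        out = torch.matmul(attn, v).transpose(1, 2).contiguous()
        out = out.view(batch_size, -1, self.num_heads * self.d_k)
        return self.w_o(out)
```

Under this mask convention, a causal mask can be built with `torch.tril(torch.ones(seq_len, seq_len))` and broadcast over the batch and head dimensions.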
- Nucleus (Top‑p) Sampling Implementation (a reference sketch follows this task)
  - Inputs:
    - Logits over a vocabulary (shape (vocab_size,) or (batch_size, vocab_size))
    - Threshold p in (0, 1]
  - Requirements:
    - Convert logits to probabilities
    - Select the smallest set of tokens whose cumulative probability ≥ p
    - Renormalize the selected probabilities
    - Sample the next token from this set

  Follow‑up B: What are the advantages and disadvantages of top‑p sampling compared with top‑k sampling?
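As a point of reference, one way to implement nucleus sampling in PyTorch is sketched below. The function name `top_p_sample`, the optional `generator` argument, and the handling of 1‑D input by temporarily adding a batch dimension are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

def top_p_sample(logits: torch.Tensor, p: float, generator=None) -> torch.Tensor:
    """Nucleus (top-p) sampling.

    logits: (vocab_size,) or (batch_size, vocab_size). Returns sampled token ids.
    """
    if logits.dim() == 1:
        # Treat a single distribution as a batch of one.
        return top_p_sample(logits.unsqueeze(0), p, generator).squeeze(0)

    # 1. Logits -> probabilities (softmax subtracts the row max internally,
    #    which keeps this numerically stable).
    probs = F.softmax(logits, dim=-1)

    # 2. Sort descending and take the cumulative sum.
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)

    # 3. Keep the smallest prefix whose cumulative probability reaches p.
    #    A token is dropped only if the mass *before* it already reaches p,
    #    so the token that crosses the threshold is itself kept.
    drop = (cumulative - sorted_probs) >= p
    sorted_probs = sorted_probs.masked_fill(drop, 0.0)

    # 4. Renormalize over the kept set and sample from it.
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    sample_in_sorted = torch.multinomial(sorted_probs, num_samples=1, generator=generator)

    # Map the sampled position back to the original vocabulary index.
    return sorted_idx.gather(-1, sample_in_sorted).squeeze(-1)
```

For example, `top_p_sample(logits, p=0.9)` samples only from the highest-probability tokens whose probabilities sum to at least 0.9, while `p=1.0` reduces to sampling from the full distribution.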