Implement multi-head attention and LLM sampling
Company: Scale AI
Role: Machine Learning Engineer
Category: Coding & Algorithms
Difficulty: easy
Interview Round: Onsite
## Task A: Multi-head attention (forward pass)
You are implementing a Transformer attention layer.
Given:
- A sequence length `L` and model dimension `d_model`.
- Number of heads `h`, where `d_head = d_model / h` (assume divisible).
- Input matrices `Q`, `K`, `V` each of shape `(L, d_model)`.
- Projection weights `Wq`, `Wk`, `Wv` each of shape `(d_model, d_model)` and output projection `Wo` of shape `(d_model, d_model)`.
- Optional attention mask `mask` of shape `(L, L)` where `mask[i][j] = 1` means position `i` may attend to `j`, and `0` means it must not.
Compute the multi-head scaled dot-product attention output `O` of shape `(L, d_model)`:
1. Project `Q' = QWq`, `K' = KWk`, `V' = VWv`.
2. Split each into `h` heads (reshape to `(h, L, d_head)`).
3. For each head compute attention weights using scaled dot product and masking:
- `scores = (Q_head @ K_head^T) / sqrt(d_head)` producing `(L, L)`.
- Apply the mask by setting disallowed positions to `-inf` (or a very negative number) before softmax.
- `weights = softmax(scores)` over the last dimension.
- `head_out = weights @ V_head` producing `(L, d_head)`.
4. Concatenate heads back to `(L, d_model)` and apply `Wo`.
Clarify in your explanation:
- How you handle the mask numerically.
- The time and space complexity in terms of `L`, `d_model`, and `h`.
- At least 2 edge cases (e.g., fully masked row, `L=1`).
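The four steps above can be sketched in NumPy (no batching, a single sequence; function and variable names are illustrative, not prescribed by the task). The mask is handled numerically by overwriting disallowed scores with a large negative constant before the softmax, and the softmax itself subtracts the row maximum for stability:

```python
import numpy as np

def multi_head_attention(Q, K, V, Wq, Wk, Wv, Wo, h, mask=None):
    """Sketch of the forward pass: project, split into heads, attend, merge."""
    L, d_model = Q.shape
    d_head = d_model // h

    # 1. Linear projections.
    Qp, Kp, Vp = Q @ Wq, K @ Wk, V @ Wv

    # 2. Split into heads: (L, d_model) -> (h, L, d_head).
    def split(x):
        return x.reshape(L, h, d_head).transpose(1, 0, 2)
    Qh, Kh, Vh = split(Qp), split(Kp), split(Vp)

    # 3. Scaled dot-product scores per head: (h, L, L).
    scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(d_head)
    if mask is not None:
        # Disallowed positions get a very negative score, so softmax -> ~0.
        scores = np.where(mask[None, :, :] == 1, scores, -1e9)

    # Numerically stable softmax over the last axis.
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)

    head_out = weights @ Vh  # (h, L, d_head)

    # 4. Concatenate heads back to (L, d_model) and apply Wo.
    concat = head_out.transpose(1, 0, 2).reshape(L, d_model)
    return concat @ Wo
```

On complexity: the projections cost O(L · d_model²) time, the score and weighting steps cost O(L² · d_model) across all heads, and the attention weights dominate space at O(h · L²). A fully masked row yields a uniform (meaningless) distribution with the `-1e9` trick, which is one reason to detect and special-case it; `L = 1` degenerates to a softmax over a single score, which is always 1.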
## Task B: Token sampling from logits
You are implementing a next-token sampler for an LLM.
Given:
- A vector of logits `logits` of length `V` (vocabulary size) for the next token.
- Parameters: `temperature > 0`, optional `top_k` (integer), optional `top_p` (0 < p ≤ 1), and `seed` for reproducibility.
Implement a function that returns a sampled token id.
Requirements:
- Apply temperature scaling.
- Support **top-k** filtering and/or **top-p (nucleus)** filtering.
- Convert to a probability distribution with softmax after filtering.
- Sample one token from the resulting categorical distribution.
Discuss how you would handle:
- `temperature → 0` behavior.
- Ties in logits.
- Interaction when both `top_k` and `top_p` are provided.
- Numerical stability for softmax.
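One reasonable sketch of the sampler (NumPy; the function name and the choice to apply top-k before top-p are illustrative assumptions, not requirements). Temperature scaling happens first, filtering masks logits to `-inf`, and the softmax subtracts the max logit for numerical stability. As `temperature → 0` the distribution concentrates on the argmax, so the behavior approaches greedy decoding; ties in logits survive filtering together here because the top-k cutoff uses `>=`:

```python
import numpy as np

def sample_token(logits, temperature=1.0, top_k=None, top_p=None, seed=None):
    """Sketch: temperature scaling, optional top-k/top-p filtering, sampling."""
    logits = np.asarray(logits, dtype=np.float64)
    rng = np.random.default_rng(seed)

    # Temperature scaling (temperature -> 0 approaches greedy argmax).
    logits = logits / temperature

    # Top-k: keep the k largest logits (ties at the cutoff are all kept).
    if top_k is not None:
        kth = np.sort(logits)[-top_k]
        logits = np.where(logits >= kth, logits, -np.inf)

    # Numerically stable softmax: subtract the max before exponentiating.
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()

    # Top-p (nucleus): keep the smallest set of tokens whose cumulative
    # probability reaches p. When both filters are given, top-p is applied
    # to the distribution that survives top-k.
    if top_p is not None:
        order = np.argsort(probs)[::-1]
        cum = np.cumsum(probs[order])
        cutoff = np.searchsorted(cum, top_p) + 1  # always keep at least 1 token
        keep = order[:cutoff]
        masked = np.zeros_like(probs)
        masked[keep] = probs[keep]
        probs = masked / masked.sum()

    return int(rng.choice(len(probs), p=probs))
```

Passing a fixed `seed` makes the draw reproducible via the dedicated `Generator` instance, which avoids mutating global random state.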
Quick Answer: This question tests both implementation and conceptual understanding of Transformer multi-head scaled dot-product attention and next-token sampling for LLMs: attention-mask semantics, softmax numerical stability, time and space complexity, and temperature-based top-k/top-p sampling.
## Part 1: Implement Multi-Head Attention
You are given already-projected query, key, and value matrices for self-attention. Each matrix has shape `n x d_model`, and `d_model` is divisible by `num_heads`.
1. Split every row into `num_heads` contiguous chunks of size `d_head = d_model // num_heads`.
2. For each head, compute scaled dot-product attention: `softmax(Q_h K_h^T / sqrt(d_head)) V_h`.
3. Concatenate the outputs from all heads in head order.
Do not apply masks, dropout, learned projection matrices, or residual connections. Return the final `n x d_model` matrix, rounding every value to 6 decimal places.
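Because this simplified variant skips the learned projections and masking, the sketch reduces to slicing contiguous column chunks, attending per head, and stacking the results (NumPy; the function name is an assumption):

```python
import numpy as np

def mha_no_projections(Q, K, V, num_heads):
    """Part 1 sketch: contiguous column chunks as heads, no projections or mask."""
    n, d_model = Q.shape
    d_head = d_model // num_heads
    outs = []
    for h in range(num_heads):
        cols = slice(h * d_head, (h + 1) * d_head)
        Qh, Kh, Vh = Q[:, cols], K[:, cols], V[:, cols]
        scores = Qh @ Kh.T / np.sqrt(d_head)
        scores -= scores.max(axis=-1, keepdims=True)  # stable softmax
        w = np.exp(scores)
        w /= w.sum(axis=-1, keepdims=True)
        outs.append(w @ Vh)  # (n, d_head)
    # Concatenate heads in order and round to 6 decimal places as required.
    return np.round(np.hstack(outs), 6)
```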