Implement multi-head attention and LLM sampling
Company: Scale AI
Role: Machine Learning Engineer
Category: Coding & Algorithms
Difficulty: easy
Interview Round: Onsite
## Task A: Multi-head attention (forward pass)
You are implementing a Transformer attention layer.
Given:
- A sequence length `L` and model dimension `d_model`.
- Number of heads `h`, where `d_head = d_model / h` (assume divisible).
- Input matrices `Q`, `K`, `V` each of shape `(L, d_model)`.
- Projection weights `Wq`, `Wk`, `Wv` each of shape `(d_model, d_model)` and output projection `Wo` of shape `(d_model, d_model)`.
- Optional attention mask `mask` of shape `(L, L)` where `mask[i][j] = 1` means position `i` may attend to `j`, and `0` means it must not.
Compute the multi-head scaled dot-product attention output `O` of shape `(L, d_model)`:
1. Project `Q' = QWq`, `K' = KWk`, `V' = VWv`.
2. Split each into `h` heads (reshape to `(h, L, d_head)`).
3. For each head compute attention weights using scaled dot product and masking:
   - `scores = (Q_head @ K_head^T) / sqrt(d_head)` producing `(L, L)`.
   - Apply the mask by setting disallowed positions to `-inf` (or a very negative number) before softmax.
   - `weights = softmax(scores)` over the last dimension.
   - `head_out = weights @ V_head` producing `(L, d_head)`.
4. Concatenate heads back to `(L, d_model)` and apply `Wo`.
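The four steps above can be sketched in plain Python as follows. This is a minimal illustration, not a reference solution: the names `mha_forward`, `matmul`, and `softmax_row` are hypothetical, matrices are assumed to be lists of lists, and returning all-zero weights for a fully masked row is one assumed policy among several.

```python
import math

def matmul(a, b):
    """Multiply an (n, m) list-of-lists by an (m, p) list-of-lists."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def softmax_row(scores):
    """Stable softmax over one row; an all -inf row yields zeros (assumed policy)."""
    m = max(scores)
    if m == float("-inf"):
        return [0.0] * len(scores)
    exps = [math.exp(s - m) for s in scores]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

def mha_forward(Q, K, V, Wq, Wk, Wv, Wo, h, mask=None):
    L, d_model = len(Q), len(Wq)
    d_head = d_model // h
    # Step 1: linear projections.
    Qp, Kp, Vp = matmul(Q, Wq), matmul(K, Wk), matmul(V, Wv)
    concat = [[] for _ in range(L)]
    for head in range(h):
        # Step 2: each head owns a contiguous slice of the feature columns.
        lo, hi = head * d_head, (head + 1) * d_head
        for i in range(L):
            # Step 3: scaled scores, with disallowed positions set to -inf.
            scores = []
            for j in range(L):
                if mask is not None and mask[i][j] == 0:
                    scores.append(float("-inf"))
                else:
                    dot = sum(q * k for q, k in zip(Qp[i][lo:hi], Kp[j][lo:hi]))
                    scores.append(dot / math.sqrt(d_head))
            w = softmax_row(scores)
            # Appending head outputs in head order performs the concatenation.
            concat[i].extend(sum(w[j] * Vp[j][c] for j in range(L)) for c in range(lo, hi))
    # Step 4: output projection.
    return matmul(concat, Wo)
```

For example, with identity projection matrices and the causal mask `[[1, 0], [1, 1]]`, the first position can attend only to itself, so its output row equals its own value vector.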
Clarify in your explanation:
- How you handle the mask numerically.
- The time and space complexity in terms of `L`, `d_model`, and `h`.
- At least 2 edge cases (e.g., fully masked row, `L=1`).
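To make the mask and edge-case discussion concrete, the sketch below compares two fill values on a fully masked row. The helper names `masked_softmax` and `safe_masked_softmax` are hypothetical, and returning zeros for a fully masked row is an assumed policy; other choices (such as raising an error) are equally defensible.

```python
import math

def masked_softmax(scores, allowed, fill=float("-inf")):
    """Unguarded masked softmax over one row of scores."""
    masked = [s if a else fill for s, a in zip(scores, allowed)]
    m = max(masked)
    exps = [math.exp(s - m) for s in masked]  # -inf - (-inf) is NaN
    total = sum(exps)
    return [e / total for e in exps]

def safe_masked_softmax(scores, allowed):
    """Guarded variant: a row with no allowed position returns zeros (assumed policy)."""
    if not any(allowed):
        return [0.0] * len(scores)
    return masked_softmax(scores, allowed)

# Fully masked row with -inf fill: every weight becomes NaN.
nan_row = masked_softmax([1.0, 2.0], [0, 0])
# Fully masked row with a large negative fill: weights are silently uniform.
uniform_row = masked_softmax([1.0, 2.0], [0, 0], fill=-1e9)
# L = 1 with the single position allowed: the weight is exactly 1.
single = safe_masked_softmax([3.0], [1])
```

With `-inf` the failure is loud (NaN), whereas a large negative fill quietly attends to positions that were supposed to be hidden. On complexity, the projections cost `O(L * d_model^2)` time and the attention itself costs `O(L^2 * d_model)` summed over all heads; storing the weights for every head at once takes `O(h * L^2)` space on top of `O(L * d_model)` for activations.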
## Task B: Token sampling from logits
You are implementing a next-token sampler for an LLM.
Given:
- A vector of logits `logits` of length `V` (vocabulary size) for the next token.
- Parameters: `temperature > 0`, optional `top_k` (positive integer), optional `top_p` (`0 < top_p ≤ 1`), and `seed` for reproducibility.
Implement a function that returns a sampled token id.
Requirements:
- Apply temperature scaling.
- Support **top-k** filtering and/or **top-p (nucleus)** filtering.
- Convert to a probability distribution with softmax after filtering.
- Sample one token from the resulting categorical distribution.
Discuss how you would handle:
- `temperature → 0` behavior.
- Ties in logits.
- Interaction when both `top_k` and `top_p` are provided.
- Numerical stability for softmax.
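One possible shape for such a sampler is sketched below. Several choices here are assumptions rather than requirements of the task: the name `sample_token`, the `1e-6` cutoff below which sampling falls back to greedy argmax, breaking ties toward the lower token id, and applying top-k before top-p when both are given.

```python
import math
import random

def sample_token(logits, temperature=1.0, top_k=None, top_p=None, seed=None):
    # Assumed cutoff: a near-zero temperature is treated as greedy decoding,
    # with ties resolved toward the lower token id.
    if temperature < 1e-6:
        return max(range(len(logits)), key=lambda i: (logits[i], -i))
    rng = random.Random(seed)
    # Sort ids by logit descending; ties go to the lower id.
    order = sorted(range(len(logits)), key=lambda i: (-logits[i], i))
    # Top-k first (assumed order of operations when both filters are set).
    if top_k is not None and top_k > 0:
        order = order[:top_k]
    scaled = [logits[i] / temperature for i in order]
    m = scaled[0]  # the list is sorted, so the first entry is the maximum
    exps = [math.exp(s - m) for s in scaled]  # stable: exponents are <= 0
    total = sum(exps)
    probs = [e / total for e in exps]
    # Top-p: keep the shortest prefix whose mass reaches top_p.
    if top_p is not None and top_p < 1.0:
        cum, keep = 0.0, 0
        for p in probs:
            cum += p
            keep += 1
            if cum >= top_p:
                break
        order, probs = order[:keep], probs[:keep]
    # random.choices normalizes the weights, so no explicit renormalization.
    return rng.choices(order, weights=probs, k=1)[0]
```

Creating a fresh `random.Random(seed)` per call makes a single call reproducible; in a generation loop you would more likely pass one shared generator so that successive tokens do not reuse the same random draw.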
Quick Answer: This question evaluates implementation and conceptual understanding of Transformer multi-head scaled dot-product attention and next-token sampling for LLMs, including attention masking semantics, softmax numerical stability, time and space complexity, and top-k/top-p with temperature-based sampling.
## Part 1: Implement Multi-Head Attention
You are given already-projected query, key, and value matrices for self-attention. Each matrix has shape n x d_model, and d_model is divisible by num_heads. Split every row into num_heads contiguous chunks of size d_head = d_model // num_heads. For each head, compute scaled dot-product attention: softmax(Q_h K_h^T / sqrt(d_head)) V_h. Concatenate the outputs from all heads in head order. Do not apply masks, dropout, learned projection matrices, or residual connections. Return the final n x d_model matrix, rounding every value to 6 decimal places.
Constraints
- 0 <= n <= 50
- 1 <= d_model <= 64 when n > 0
- 1 <= num_heads <= d_model
- d_model % num_heads == 0
- query, key, and value have the same shape
- Each element is in the range [-1000, 1000]
Examples
Input: ([], [], [], 1)
Expected Output: []
Explanation: Edge case: no tokens, so the attention output is empty.
Input: ([[1, 2]], [[3, 4]], [[5, 6]], 2)
Expected Output: [[5.0, 6.0]]
Explanation: With a single token, every head attends only to that token, so the output equals the value vector.
Input: ([[0, 0], [0, 0]], [[1, 2], [3, 4]], [[10, 0], [0, 20]], 2)
Expected Output: [[5.0, 10.0], [5.0, 10.0]]
Explanation: All query vectors are zero, so every score is 0 and each head gives uniform attention over the two tokens.
Input: ([[1, 0], [0, 1]], [[1, 0], [0, 1]], [[1, 2], [3, 4]], 1)
Expected Output: [[1.660477, 2.660477], [2.339523, 3.339523]]
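Explanation: With one head, d_head = 2, so the scaled scores are [[1, 0], [0, 1]] / sqrt(2). Each query puts about 0.670 of its attention on its matching token and about 0.330 on the other, and the output rows are the corresponding weighted averages of the value rows.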
Hints
- Handle each head independently by slicing contiguous ranges of length d_head from every token vector.
- Use a numerically stable softmax by subtracting the maximum score in each row before exponentiating.
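Following the two hints, a Part 1 solution might look like the sketch below. The function name and argument order are assumptions based on the example inputs, which pass `(query, key, value, num_heads)`.

```python
import math

def multi_head_attention(query, key, value, num_heads):
    n = len(query)
    if n == 0:
        return []  # no tokens, so the output is empty
    d_model = len(query[0])
    d_head = d_model // num_heads
    scale = math.sqrt(d_head)
    out = [[0.0] * d_model for _ in range(n)]
    for h in range(num_heads):
        # Each head works on a contiguous column slice [lo, hi).
        lo, hi = h * d_head, (h + 1) * d_head
        for i in range(n):
            scores = [
                sum(query[i][c] * key[j][c] for c in range(lo, hi)) / scale
                for j in range(n)
            ]
            # Stable softmax: subtract the row maximum before exponentiating.
            m = max(scores)
            exps = [math.exp(s - m) for s in scores]
            total = sum(exps)
            # Writing into columns [lo, hi) concatenates heads in head order.
            for c in range(lo, hi):
                weighted = sum(e * value[j][c] for j, e in enumerate(exps))
                out[i][c] = round(weighted / total, 6)
    return out
```

Subtracting the row maximum matters under the stated constraints: with elements up to 1000 in magnitude and d_model up to 64, raw scores can be far too large for `math.exp` to handle without overflow.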
## Part 2: Implement Deterministic LLM Sampling
Implement deterministic token sampling for an LLM. You are given a list of logits for a vocabulary, a temperature, a top_k value, a top_p value, and a random number r in [0, 1). If logits is empty, return -1. Otherwise, compute the sample as follows:
1. Divide each logit by temperature.
2. Apply softmax to get probabilities.
3. Sort tokens by probability descending, breaking ties by smaller index.
4. If top_k > 0, keep only the first top_k tokens.
5. If top_p < 1, keep the smallest prefix of the remaining sorted tokens whose cumulative probability is at least top_p.
6. Renormalize the kept probabilities.
7. Return the token index selected by cumulative sampling in that sorted order using r: the first token whose cumulative probability is strictly greater than r.
Constraints
- 0 <= len(logits) <= 100000
- -1000 <= logits[i] <= 1000
- 0 < temperature <= 100
- 0 <= top_k <= len(logits)
- 0 < top_p <= 1
- 0 <= r < 1
Examples
Input: ([], 1.0, 0, 1.0, 0.5)
Expected Output: -1
Explanation: Edge case: there are no tokens to sample from.
Input: ([5.0], 1.0, 0, 1.0, 0.3)
Expected Output: 0
Explanation: With a single token, it must always be selected.
Input: ([1.0, 0.0], 0.5, 0, 1.0, 0.9)
Expected Output: 1
Explanation: Temperature 0.5 sharpens the distribution to about 0.881 and 0.119. Since r=0.9 exceeds 0.881, it falls in the remaining tail and the second token is chosen.
Input: ([1.0, 2.0, 3.0, 4.0], 1.0, 2, 1.0, 0.8)
Expected Output: 2
Explanation: After top-k, only the tokens with logits 4 and 3 remain. Their renormalized probabilities are about 0.731 and 0.269, so r=0.8 picks the second one.
Input: ([2.0, 1.0, 0.0], 1.0, 0, 0.7, 0.8)
Expected Output: 1
Explanation: The probabilities are about 0.665, 0.245, and 0.090. Top-p keeps the smallest prefix with cumulative probability at least 0.7, which is the first two tokens (0.910). After renormalization to about 0.731 and 0.269, r=0.8 selects the second token in that filtered set.
Hints
- Use a numerically stable softmax by subtracting the maximum scaled logit before exponentiating.
- Apply filtering on tokens sorted by probability, not by original index, and renormalize after filtering.
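Putting the hints together, the seven steps in the Part 2 statement can be sketched as below. Two details are assumptions, since the statement leaves them open: the top-p prefix is measured on the original softmax probabilities (before any renormalization after top-k), and the last kept token is returned if floating-point rounding leaves the cumulative sum at or below r.

```python
import math

def sample_token(logits, temperature, top_k, top_p, r):
    if not logits:
        return -1
    # Steps 1-2: temperature scaling, then a stable softmax.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Step 3: sort by probability descending, ties by smaller index.
    order = sorted(range(len(probs)), key=lambda i: (-probs[i], i))
    # Step 4: top-k filter (0 disables it).
    if top_k > 0:
        order = order[:top_k]
    # Step 5: top-p filter on the unrenormalized probabilities (assumption).
    if top_p < 1:
        cum, keep = 0.0, 0
        for i in order:
            cum += probs[i]
            keep += 1
            if cum >= top_p:
                break
        order = order[:keep]
    # Steps 6-7: renormalize, then walk the cumulative distribution.
    kept_total = sum(probs[i] for i in order)
    cum = 0.0
    for i in order:
        cum += probs[i] / kept_total
        if cum > r:
            return i
    return order[-1]  # assumed guard against floating-point shortfall
```

Sorting dominates the running time at `O(V log V)` for a vocabulary of size `V`, which is comfortably within the 100000-element limit.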