Implement Grouped-Query Attention (GQA)
Company: Datadog
Role: Machine Learning Engineer
Category: Machine Learning
Difficulty: hard
Interview Round: Technical Screen
# Implement Grouped-Query Attention (GQA)
During autoregressive decoding, modern decoder-only transformers are limited largely by the **key/value (KV) cache**: every generated token requires reading the whole cache from GPU memory, and the cache grows with batch size and sequence length. Multi-Head Attention (MHA) keeps one K and one V per query head, which is expensive; Multi-Query Attention (MQA) shares a single K/V across *all* heads, which is cheap but can hurt quality. **Grouped-Query Attention (GQA)** is the middle ground: the query heads are partitioned into $G$ groups, and all query heads within a group share one K/V head.
Implement the forward pass of a GQA module from scratch using only basic tensor ops (linear layers + matmul + softmax; no high-level attention helper). Given:
- input hidden states $x$ of shape $(B, T, d_{\text{model}})$,
- `num_query_heads` $H$ and `num_kv_heads` $G$, with $H$ divisible by $G$,
- per-head dimension $d_{\text{head}} = d_{\text{model}} / H$,
your module should:
1. Project $x$ to queries $Q$ ($H$ heads) and keys/values $K, V$ ($G$ heads each).
2. Reshape into heads and **share each KV head across $H/G$ query heads**.
3. Compute **causal** scaled dot-product attention.
4. Concatenate heads and apply the output projection back to $d_{\text{model}}$.
```hint The crux — sharing KV heads
After projecting and reshaping into heads, $Q$ has shape $(B, H, T, d_{\text{head}})$ while $K$ and $V$ have only $(B, G, T, d_{\text{head}})$, yet every query head needs to read from one of the $G$ KV heads. Before you can run ordinary multi-head attention, you have to bring $K, V$ up to $H$ heads in a way that preserves which group each query head belongs to. Getting the head ordering wrong silently scrambles which queries attend to which keys without raising any error.
```
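The ordering pitfall can be made concrete with a small NumPy sketch. It assumes the convention that query head $h$ reads KV head $\lfloor h / (H/G) \rfloor$, i.e. consecutive query heads form a group; the problem statement does not fix this convention, so treat it as an assumption. The KV heads are tagged with their own index so the result of each expansion is visible.

```python
import numpy as np

B, H, G, T, d_head = 1, 8, 2, 3, 4
group_size = H // G  # query heads served by each KV head

# Tag every KV head with its own index so we can see where it ends up.
k = np.zeros((B, G, T, d_head), dtype=np.float32)
for g in range(G):
    k[:, g] = g

# np.repeat duplicates each head in place: heads become [0,0,0,0,1,1,1,1].
k_repeat = np.repeat(k, group_size, axis=1)
# np.tile cycles through the heads instead: [0,1,0,1,0,1,0,1].
k_tile = np.tile(k, (1, group_size, 1, 1))

repeat_order = k_repeat[0, :, 0, 0].astype(int).tolist()
tile_order = k_tile[0, :, 0, 0].astype(int).tolist()
# Assumed grouping: query head h belongs to group h // group_size.
expected = [h // group_size for h in range(H)]

print(repeat_order)  # [0, 0, 0, 0, 1, 1, 1, 1]
print(tile_order)    # [0, 1, 0, 1, 0, 1, 0, 1]
```

Both results have the same shape, so nothing downstream fails; only the `np.repeat` version matches the assumed grouping. In PyTorch the corresponding pair is `torch.repeat_interleave` (element-wise repeat) versus `Tensor.repeat` (tiling).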
```hint Projection sizes are asymmetric
The Q projection is $d_{\text{model}} \to H \cdot d_{\text{head}}\ (= d_{\text{model}})$, but the K and V projections are $d_{\text{model}} \to G \cdot d_{\text{head}}$, which is **smaller**. That asymmetry is the whole point — fewer KV parameters and a smaller KV cache at decode time.
```
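To put numbers on the asymmetry, here is a plain-Python sketch that lists the weight shapes of the four projections. The helper names `projection_shapes` and `num_params` are hypothetical, the projections are assumed to be bias-free, and the sizes ($d_{\text{model}} = 512$, $H = 8$, $G = 2$) are chosen only for illustration.

```python
def projection_shapes(d_model, num_query_heads, num_kv_heads):
    """Weight shapes (in_features, out_features) for bias-free Q/K/V/O projections."""
    assert d_model % num_query_heads == 0
    assert num_query_heads % num_kv_heads == 0
    d_head = d_model // num_query_heads
    return {
        "q": (d_model, num_query_heads * d_head),
        "k": (d_model, num_kv_heads * d_head),   # smaller than Q when G < H
        "v": (d_model, num_kv_heads * d_head),
        "o": (num_query_heads * d_head, d_model),
    }

def num_params(shapes):
    """Total number of weights across all projections."""
    return sum(rows * cols for rows, cols in shapes.values())

gqa = projection_shapes(d_model=512, num_query_heads=8, num_kv_heads=2)
mha = projection_shapes(d_model=512, num_query_heads=8, num_kv_heads=8)
print(gqa["q"], gqa["k"])                # (512, 512) (512, 128)
print(num_params(gqa), num_params(mha))  # 655360 1048576
```

Setting `num_kv_heads` equal to `num_query_heads` recovers the MHA shapes, which is a quick sanity check on the formula.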
```hint Attention core
Once $Q$, $K$, $V$ all have $H$ heads, this is ordinary causal scaled dot-product attention. Make sure the causal mask is applied **before** the softmax, not after, and double-check how a $T \times T$ mask broadcasts across the batch and head axes.
```
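The masking order and the broadcasting behaviour can be checked in isolation. The following NumPy sketch uses random scores in place of real $QK^\top$ products (an illustrative stand-in) and compares masking before the softmax with masking after it.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

B, H, T = 2, 4, 5
rng = np.random.default_rng(0)
scores = rng.standard_normal((B, H, T, T)).astype(np.float32)

# Lower-triangular (T, T) mask: True where query position t may see key position s <= t.
causal = np.tril(np.ones((T, T), dtype=bool))

# Correct: mask before the softmax. The (T, T) mask broadcasts over (B, H).
masked = np.where(causal, scores, -np.inf)
weights = softmax(masked)

# Incorrect: masking after the softmax zeroes future positions
# but leaves rows that no longer sum to 1.
wrong = softmax(scores) * causal

print(np.allclose(weights.sum(-1), 1.0))  # True
print(np.allclose(wrong.sum(-1), 1.0))    # False
```

Every row keeps at least its diagonal entry, so the row maximum is finite and the $-\infty$ fill never produces NaNs here.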
### Constraints & Assumptions
- Shapes: $x$ is $(B, T, d_{\text{model}})$; the output is $(B, T, d_{\text{model}})$.
- $H \bmod G = 0$ (e.g. $H=8,\ G=2$ → each KV head shared by 4 query heads). $G = H$ is MHA; $G = 1$ is MQA.
- $d_{\text{model}} \bmod H = 0$; $d_{\text{head}} = d_{\text{model}} / H$.
- Scale scores by $1/\sqrt{d_{\text{head}}}$.
- Decoder **self-attention**: apply a causal mask so position $t$ attends only to positions $\le t$.
- Single-precision floats; fill masked scores with $-\infty$ (or a large negative value representable in the dtype) so those positions receive zero weight after the softmax.
- Dropout and rotary embeddings (RoPE) may be ignored in the core implementation (mention them as extensions).
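
Putting the four steps and the constraints above together, one possible forward pass looks like the sketch below. It is written in NumPy rather than a deep-learning framework so it stays dependency-light. The class name, the random weight initialization, the bias-free projections, and the grouping convention (consecutive query heads share a KV head) are all assumptions made for illustration, and dropout and RoPE are omitted.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # stabilize before exp
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

class GroupedQueryAttention:
    """Causal GQA forward pass with bias-free projections (an assumption)."""

    def __init__(self, d_model, num_query_heads, num_kv_heads, seed=0):
        assert d_model % num_query_heads == 0
        assert num_query_heads % num_kv_heads == 0
        self.H, self.G = num_query_heads, num_kv_heads
        self.d_head = d_model // num_query_heads
        rng = np.random.default_rng(seed)

        def init(n_in, n_out):
            # Illustrative random init; a real module would use trained weights.
            w = rng.standard_normal((n_in, n_out)) / np.sqrt(n_in)
            return w.astype(np.float32)

        self.w_q = init(d_model, self.H * self.d_head)
        self.w_k = init(d_model, self.G * self.d_head)  # smaller than w_q
        self.w_v = init(d_model, self.G * self.d_head)
        self.w_o = init(self.H * self.d_head, d_model)

    def __call__(self, x):
        B, T, _ = x.shape
        H, G, d = self.H, self.G, self.d_head
        # 1. Project, then split into heads: (B, T, n*d) -> (B, n, T, d).
        q = (x @ self.w_q).reshape(B, T, H, d).transpose(0, 2, 1, 3)
        k = (x @ self.w_k).reshape(B, T, G, d).transpose(0, 2, 1, 3)
        v = (x @ self.w_v).reshape(B, T, G, d).transpose(0, 2, 1, 3)
        # 2. Share each KV head with H // G consecutive query heads.
        k = np.repeat(k, H // G, axis=1)
        v = np.repeat(v, H // G, axis=1)
        # 3. Causal scaled dot-product attention; scores are (B, H, T, T).
        scores = (q @ k.transpose(0, 1, 3, 2)) * (d ** -0.5)
        causal = np.tril(np.ones((T, T), dtype=bool))
        scores = np.where(causal, scores, -np.inf)  # mask before softmax
        out = softmax(scores) @ v                   # (B, H, T, d)
        # 4. Merge heads and apply the output projection.
        out = out.transpose(0, 2, 1, 3).reshape(B, T, H * d)
        return out @ self.w_o

x = np.random.default_rng(1).standard_normal((2, 5, 32)).astype(np.float32)
gqa = GroupedQueryAttention(d_model=32, num_query_heads=8, num_kv_heads=2)
y = gqa(x)
print(y.shape)  # (2, 5, 32)
```

Materializing the repeated K/V with `np.repeat` keeps the sketch readable but copies memory. An alternative is to reshape $Q$ to $(B, G, H/G, T, d_{\text{head}})$ and broadcast against $K, V$ of shape $(B, G, 1, T, d_{\text{head}})$, which avoids the copy. A useful test for any implementation is causality: perturbing the last token must leave the outputs at all earlier positions unchanged.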
### Clarifying Questions to Ask
- Is this causal self-attention in a decoder, or cross-attention? I will assume causal self-attention.
- Should I implement only the full-sequence forward (training / prefill), or also the **incremental single-token decode** with a KV cache?
- Are positional encodings (e.g. RoPE) applied to $Q/K$ inside the module, or handled outside?
- Besides the causal mask, is there a padding mask for variable-length sequences in the batch?
- Do the Q/K/V/O linear projections include bias terms?
### Follow-up Questions
- During autoregressive decoding with a KV cache, how much memory does GQA save versus MHA, and why is memory **bandwidth** (not FLOPs) the decode-time bottleneck?
- How would you "uptrain" an existing MHA checkpoint into GQA — how do you initialize the $G$ KV heads from the original $H$?
- Where do rotary position embeddings (RoPE) get applied, and does GQA change that?
- Implement the **incremental decode step**: given cached $K, V$ of length $t$ and one new token, produce the next output and update the cache. Which shapes change?
- How does FlashAttention interact with GQA, and what changes about the memory-access pattern?
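
For the first follow-up, a back-of-the-envelope calculation helps. The sketch below counts KV-cache bytes for a hypothetical model configuration (32 layers, $H = 32$, $G = 8$, $d_{\text{head}} = 128$, a 16-bit cache, one sequence of 4096 tokens). None of these numbers come from the question; they are assumptions chosen to give round figures, and `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(batch, seq_len, num_layers, num_kv_heads, d_head, bytes_per_elem):
    # Factor 2 covers the separate K and V tensors stored in every layer.
    return 2 * batch * seq_len * num_layers * num_kv_heads * d_head * bytes_per_elem

# Hypothetical configuration, for illustration only.
cfg = dict(batch=1, seq_len=4096, num_layers=32, d_head=128, bytes_per_elem=2)
mha_bytes = kv_cache_bytes(num_kv_heads=32, **cfg)  # G = H (MHA)
gqa_bytes = kv_cache_bytes(num_kv_heads=8, **cfg)   # G = 8 (GQA)

print(mha_bytes / 2**20, gqa_bytes / 2**20)  # 2048.0 512.0  (MiB)
print(mha_bytes // gqa_bytes)                # 4
```

The cache shrinks by a factor of $H/G$. Because each decode step reads the full cache to produce a single token, the bytes moved per step shrink by roughly the same factor, which is why the saving shows up as decode throughput and not only as capacity.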
Quick Answer: This question evaluates a candidate's understanding of transformer attention mechanisms, specifically how grouped-query attention balances the memory efficiency of multi-query attention with the quality of full multi-head attention. It tests practical implementation skill with tensor reshaping, head grouping, and causal masking, a common way to probe machine learning engineering depth in system-level model design interviews.