# Explain KV cache in Transformer inference
Company: OpenAI
Role: Machine Learning Engineer
Category: Software Engineering Fundamentals
Difficulty: medium
Interview Round: Onsite
## Question
In Transformer-based language model inference, what is a **key-value (KV) cache**?
Explain:
- What gets cached (tensors, shapes at a high level) and at which layers.
- Why KV caching improves autoregressive decoding latency.
- The difference between **prefill** (processing the prompt) and **decode** (generating tokens) phases.
- Tradeoffs and pitfalls: memory growth, batch/sequence management, multi-head attention, and long-context handling.
- At least two practical optimizations used in production (e.g., paged attention, quantized KV cache, sliding window).
**Quick Answer:** This question evaluates understanding of KV cache mechanisms in Transformer inference, including attention-state caching, memory and latency trade-offs, and engineering optimizations for autoregressive decoding.
## Solution
### What KV cache is
In a decoder-only Transformer doing autoregressive generation, each layer computes self-attention over all previously seen tokens.
- For each token position *t*, each attention layer produces:
- **K** (keys) and **V** (values) vectors (per head) from the hidden states.
- During generation, when you append a new token, you *do not want to recompute* K and V for all past tokens at every step.
A **KV cache** stores the past tokens’ **K and V tensors** for each Transformer layer (and for each attention head) so that at the next decoding step you only compute K/V for the **new token** and reuse the cached ones.
High-level shape intuition (one layer; a minimal sketch follows after this list):
- Without cache, at step *t* you build K,V for all positions 1..t.
- With cache, you maintain:
- `K_cache: [batch, heads, seq_len_so_far, head_dim]`
- `V_cache: [batch, heads, seq_len_so_far, head_dim]`
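A minimal PyTorch-style sketch of this cache bookkeeping, assuming illustrative sizes and a hypothetical `append_kv` helper (real implementations usually preallocate a buffer up to a maximum sequence length instead of concatenating):

```python
import torch

# Illustrative sizes (assumptions, not taken from the question).
batch, heads, head_dim = 2, 8, 64

# Empty per-layer cache; the sequence axis grows as tokens are generated.
k_cache = torch.empty(batch, heads, 0, head_dim)
v_cache = torch.empty(batch, heads, 0, head_dim)

def append_kv(k_cache, v_cache, k_new, v_new):
    """Append K/V for the newly computed position(s) along the sequence axis.

    k_new, v_new: [batch, heads, new_tokens, head_dim]
    """
    k_cache = torch.cat([k_cache, k_new], dim=2)
    v_cache = torch.cat([v_cache, v_new], dim=2)
    return k_cache, v_cache

# One decode step adds exactly one position per sequence.
k_new = torch.randn(batch, heads, 1, head_dim)
v_new = torch.randn(batch, heads, 1, head_dim)
k_cache, v_cache = append_kv(k_cache, v_cache, k_new, v_new)
print(k_cache.shape)  # torch.Size([2, 8, 1, 64])
```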
---
### Why it speeds up decoding
Autoregressive decoding generates one token at a time. Naively, each new token would require recomputing attention projections for the entire prefix, which is wasted work.
With KV cache:
- Each step computes Q,K,V only for the new token (plus the MLP, etc.).
- Attention for the new token uses cached K/V for previous positions.
Complexity intuition per layer:
- **No cache**: at step *t*, recompute K/V projections for all O(t) prefix tokens → total projection work over T generated tokens grows roughly quadratically.
- **With cache**: projections for O(1) tokens (just the new one) each step → roughly linear total projection work; attention still attends over O(t) cached keys, but you avoid redoing the K/V projections for the past.
This usually reduces per-token latency substantially, especially for large prompts and many decoding steps.
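A rough back-of-the-envelope comparison of the projection work alone (all sizes below are assumptions for illustration; it ignores attention-score computation, the MLP, and memory traffic):

```python
# Projection cost for generating T tokens after a prompt of length P,
# counting only Q/K/V projections (~2 * d * d multiply-adds each) in one layer.
d = 4096          # hidden size (assumed)
P, T = 1000, 500  # prompt length and generated tokens (assumed)
per_token_proj = 3 * 2 * d * d  # Q, K, V projections for one token

# Without a cache: step t re-projects the entire prefix of length P + t.
no_cache = sum(per_token_proj * (P + t) for t in range(1, T + 1))

# With a cache: project the prompt once (prefill), then one token per step.
with_cache = per_token_proj * (P + T)

print(f"projection FLOPs, no cache / with cache: {no_cache / with_cache:.1f}x")
```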
---
### Prefill vs decode
**Prefill (a.k.a. prompt processing):**
- Input: the full prompt of length P.
- You run the model on the whole prompt to produce the initial hidden states.
- You also populate the KV cache for all P positions at all layers.
- This phase is often compute-bound (large matrix-matrix multiplications over all prompt positions) and benefits from batching.
**Decode (token-by-token generation):**
- Each step adds 1 token.
- You compute K/V for just that new position and append to the cache.
- This phase is latency-sensitive and typically memory-bandwidth bound: each step does small matrix-vector work but must stream the model weights and the growing KV cache. A sketch of both phases with an explicit cache follows below.
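A hedged sketch of the two phases using a Hugging Face `transformers`-style interface (the `gpt2` checkpoint and greedy decoding are placeholder choices, and the exact `past_key_values` container varies by library version):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder checkpoint
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

prompt_ids = tok("The KV cache stores", return_tensors="pt").input_ids

with torch.no_grad():
    # Prefill: one pass over the whole prompt populates the cache for every layer.
    out = model(prompt_ids, use_cache=True)
    past = out.past_key_values
    next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)

    generated = [next_id]
    for _ in range(20):
        # Decode: feed only the newest token; reuse cached K/V for the prefix.
        out = model(next_id, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
        generated.append(next_id)

print(tok.decode(torch.cat(generated, dim=1)[0]))
```

Prefill is one large batched pass over all prompt positions; each decode call then touches only a single new position per sequence, which is why decode benefits most from the cache.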
---
### Tradeoffs and pitfalls
1. **Memory growth**
- KV cache size scales with `layers × batch × seq_len × heads × head_dim × 2 (K and V) × bytes per element`; see the sizing sketch after this list.
- Long context or large batch sizes can make KV memory, rather than the weights, the bottleneck.
2. **Batching complications**
- Different sequences have different lengths and may finish at different times.
- You need careful bookkeeping to avoid wasting cache space for finished sequences.
3. **Beam search / sampling variants**
- Beam search multiplies the cache by beam width unless you share and branch carefully.
4. **Position encoding interactions**
- Absolute vs rotary/relative encodings: caching itself is straightforward, but positional information must be applied consistently to cached positions; with rotary embeddings, for example, keys are typically cached after the rotation, so position indices must line up across prefill and decode.
5. **Attention masking**
- Causal mask must ensure tokens only attend to earlier positions.
- For packed batches (multiple sequences in one tensor), you need per-sequence masks.
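A quick sizing sketch for the memory-growth point above, with assumed, roughly 7B-class numbers:

```python
# KV bytes = layers * batch * seq_len * heads * head_dim * 2 (K and V) * bytes/elem
layers, heads, head_dim = 32, 32, 128   # assumed 7B-class configuration
batch, seq_len = 8, 4096                # assumed serving scenario
bytes_per_elem = 2                      # FP16/BF16

kv_bytes = layers * batch * seq_len * heads * head_dim * 2 * bytes_per_elem
print(f"KV cache: {kv_bytes / 2**30:.1f} GiB")  # 16.0 GiB for this configuration
```

That is on the same order as the FP16 weights of a 7B model, which is why long contexts and large batches can turn the KV cache into the dominant memory consumer.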
---
### Practical optimizations in production
1. **Paged attention / block-wise KV management**
- Store KV cache in fixed-size memory blocks (“pages”) and map sequences to blocks.
- Reduces fragmentation and makes variable-length batching easier.
- Enables efficient memory reuse when sequences finish.
2. **Quantized KV cache**
- Store K/V in lower precision (e.g., FP8/INT8) while keeping compute in FP16/BF16.
- Saves memory bandwidth and capacity; can improve throughput.
- Requires calibration and careful handling to avoid quality regression.
3. **Sliding window / capped context**
- For very long sequences, keep only the most recent W tokens in cache (or use special long-context methods).
- Trades some long-range dependency for bounded memory; a simplified sketch appears after this list.
4. **KV cache offloading (CPU/NVMe) or multi-GPU partitioning**
- Useful when context is huge; increases latency, so typically paired with chunked attention or retrieval.
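As one concrete illustration of the sliding-window idea, a simplified single-layer sketch (production systems typically use preallocated ring buffers or paged blocks rather than repeated `torch.cat`):

```python
import torch

class SlidingWindowKVCache:
    """Keeps only the most recent `window` positions of K/V for one layer."""

    def __init__(self, window: int):
        self.window = window
        self.k = None  # [batch, heads, seq, head_dim]
        self.v = None

    def append(self, k_new: torch.Tensor, v_new: torch.Tensor):
        if self.k is None:
            self.k, self.v = k_new, v_new
        else:
            self.k = torch.cat([self.k, k_new], dim=2)
            self.v = torch.cat([self.v, v_new], dim=2)
        # Evict the oldest positions once the window is exceeded.
        if self.k.size(2) > self.window:
            self.k = self.k[:, :, -self.window:]
            self.v = self.v[:, :, -self.window:]
        return self.k, self.v

cache = SlidingWindowKVCache(window=4)
for _ in range(6):
    k, v = cache.append(torch.randn(1, 8, 1, 64), torch.randn(1, 8, 1, 64))
print(k.shape)  # torch.Size([1, 8, 4, 64]) -- bounded no matter how many steps run
```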
---
### How to communicate this in an interview
- Define KV cache precisely (cached K/V per layer per position).
- Explain prefill vs decode and why decode benefits most.
- Call out the central tradeoff: **latency vs memory**.
- Mention at least one scaling technique (paged attention, quantization, sliding window) and the operational constraints (variable-length batches, fragmentation, masking).