Debug Transformer and Add KV Cache
Company: OpenAI
Role: Machine Learning Engineer
Category: Machine Learning
Difficulty: medium
Interview Round: Onsite
You are given a small decoder-only transformer (GPT-style) implemented in PyTorch for autoregressive (next-token) language modeling. The starter code includes token/positional embeddings, a stack of masked self-attention blocks, an LM head, a training loop, and a naive `generate` function. A separate `KVCache` class for storing attention keys and values is also provided.
This is a two-part hands-on coding exercise: first make training correct, then add an inference-time KV cache and verify that it produces the same outputs as the uncached path.
### Constraints & Assumptions
- Decoder-only (GPT-style) architecture: causal multi-head self-attention + MLP blocks with residual connections and LayerNorm, learned (not sinusoidal) token and positional embeddings, and a final vocabulary head.
- Training objective is next-token prediction (causal LM) with cross-entropy loss.
- You may run small training/eval loops to inspect behavior; assume CPU or a single GPU and a tiny toy dataset (e.g. character-level) so you can iterate in seconds.
- The four Part 1 bugs are real correctness bugs in the *given* code — not architectural choices. After fixing, the model must be able to overfit a tiny batch.
- For Part 2, the goal is an *optimization that does not change outputs*: with dropout disabled and deterministic (greedy) decoding, cached and uncached generation must emit the identical token sequence.
### Clarifying Questions to Ask
- Does the model's `forward` return logits for the full input length `[B, T, V]`, and is the cross-entropy loss expected to be computed inside or outside `forward`?
- Are the positional embeddings a learned `nn.Embedding`/`nn.Parameter` of shape `[max_seq_len, d_model]`, and what is `max_seq_len`?
- Is the attention mask additive (added to scores pre-softmax) or boolean, and what is its expected shape/broadcast convention (`[B, H, T, T]`)?
- Does `generate` decode greedily or with sampling, and is there a fixed seed I should hold constant when comparing the cached vs. uncached paths?
- Is the `KVCache` per-layer (a list of caches) or a single shared object, and what is its exact append/read API?
### Part 1: Debug the training code
The training code contains **four** bugs. Fix them so training loss decreases and generated text becomes plausible. The bug categories are:
1. **Label shift** — the LM loss uses the wrong target alignment (input/target slicing).
2. **Positional embedding initialization** — the positional embeddings are initialized incorrectly.
3. **Attention mask** — the causal mask is wrong.
4. **A typo-related bug** — one additional silent bug somewhere in the code (e.g. a misspelled field, a swapped tensor, an off-by-one, or a dropped transform).
After your fixes, the model should train normally, loss should go down, and sampled outputs should look reasonable.
```hint Where to start
Don't read top-to-bottom hunting for the typo. First confirm the model can **overfit a single tiny batch** — if it can't drive loss toward zero on a handful of tokens, the bug is structural (loss/mask/embedding), not the typo. Then bisect: print tensor shapes and a few values at each stage (embeddings → attention scores → logits → loss).
```
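A minimal sketch of the tiny-batch overfit check. The `TinyLM` bigram model and the `overfit_one_batch` helper are illustrative stand-ins, not part of the starter code; in the exercise you would pass the real transformer instead.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in model (an assumption): a bigram LM, just enough to exercise the check.
class TinyLM(nn.Module):
    def __init__(self, vocab_size, d_model=32):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, idx):                 # idx: [B, T] token ids
        return self.head(self.emb(idx))     # logits: [B, T, V]

def overfit_one_batch(model, batch, steps=300, lr=1e-2):
    """Train on one fixed batch and return (first_loss, last_loss)."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    x, y = batch[:, :-1], batch[:, 1:]      # inputs and next-token targets
    losses = []
    for _ in range(steps):
        logits = model(x)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses[0], losses[-1]

# Deterministic toy batch: token t is always followed by (t + 1) % V
V = 16
batch = (torch.arange(8).unsqueeze(0) + torch.arange(4).unsqueeze(1)) % V   # [4, 8]
first, last = overfit_one_batch(TinyLM(V), batch)
print(f"loss {first:.3f} -> {last:.3f}")
```

If the loss on a handful of tokens like this does not fall far below its starting value, suspect the loss, mask, or embeddings before hunting for the typo.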
```hint Bugs 1-3
Next-token LM means position $t$ predicts token $t{+}1$: score `logits[:, :-1]` against `tokens[:, 1:]` (or, equivalently, feed `input=x[:, :-1]` and use `target=x[:, 1:]`), but do not shift twice. A loss that scores a position against *itself* lets the model cheat with an identity map. The causal mask must block exactly the strictly upper-triangular entries ($j>i$): set them to a large negative value / $-\infty$ **before** softmax (not to $0$), leave the diagonal visible, and make sure the mask broadcasts over batch and heads. For positional embeddings, check the shape `[max_seq_len, d_model]`, that they are trainable, and that the init magnitude is small; a huge init drowns the token embeddings.
```
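The three fixes named in this hint can be sketched as follows. The helper names (`shifted_lm_loss`, `causal_blocked`) are hypothetical, and the `std=0.02` init is a common GPT-style convention assumed here, not something the starter code is known to use.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def shifted_lm_loss(logits, tokens):
    """Next-token loss: logits at position t are scored against token t+1.
    logits: [B, T, V], tokens: [B, T]."""
    pred = logits[:, :-1, :]       # last position has no target
    target = tokens[:, 1:]         # first token is never predicted
    return F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1))

def causal_blocked(T):
    """Boolean [1, 1, T, T] mask, True where attention is NOT allowed (j > i)."""
    blocked = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    return blocked[None, None]     # broadcasts over batch and heads

torch.manual_seed(0)
B, H, T, V = 2, 3, 5, 11
tokens = torch.randint(0, V, (B, T))

# Logits that put all mass on the *next* token give ~0 loss under the shifted objective
oracle = 50.0 * F.one_hot(torch.roll(tokens, shifts=-1, dims=1), V).float()
loss_oracle = shifted_lm_loss(oracle, tokens).item()

# Blocked entries are filled with -inf BEFORE softmax, so they get exactly zero weight
scores = torch.randn(B, H, T, T)
attn = torch.softmax(scores.masked_fill(causal_blocked(T), float("-inf")), dim=-1)

# Learned positional embeddings with a small init (assumed std; max_seq_len=32, d_model=16)
pos_emb = nn.Embedding(32, 16)
nn.init.normal_(pos_emb.weight, std=0.02)

print(loss_oracle, torch.triu(attn, diagonal=1).abs().sum().item())
```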
```hint The typo bug
Silent typos pass the wrong-but-valid tensor or skip a transform without crashing: a residual adding the sublayer *input* instead of its output, reusing `q` where `k`/`v` was meant, a dropped LayerNorm, or scaling by $1/\sqrt{d_{model}}$ instead of $1/\sqrt{d_{head}}$. After fixing bugs 1-3, if the tiny-batch overfit test still fails or is suspiciously slow, that residual failure points you at the typo.
```
#### Clarifying Questions for this Part
- Are dropout and weight tying (between the input embedding and the LM head) part of the intended design, so I don't "fix" something that is actually correct?
- Should the attention scores be scaled by $1/\sqrt{d_{head}}$ in the given code, or is that handled elsewhere?
#### What This Part Should Cover
- Correctly identifies and fixes all four bugs, and *explains why each is wrong* (not just patches it): the identity-mapping failure from a missing shift, future-information leakage from a bad mask, training instability from bad positional init.
- A disciplined debugging method — overfit-a-tiny-batch as the oracle, shape/value printing, bisection — rather than guess-and-check.
- A correctness signal after fixing: loss drops on a toy dataset and generated text stops being random noise.
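One way to make the future-leakage explanation concrete is a perturbation test: change the inputs after position `t` and check that outputs up to `t` do not move. This is a sketch on a toy single-head attention with `q = k = v = x`; `attend` and `leaks_future` are illustrative names, not part of the starter code.

```python
import torch

def attend(x, causal=True):
    """Toy single-head self-attention over x: [B, T, D], with q = k = v = x."""
    T, D = x.size(1), x.size(2)
    scores = x @ x.transpose(-2, -1) / D ** 0.5             # [B, T, T]
    if causal:
        blocked = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(blocked, float("-inf"))
    return torch.softmax(scores, dim=-1) @ x

def leaks_future(fn, x, t):
    """True if changing inputs at positions > t changes outputs at positions <= t."""
    x2 = x.clone()
    x2[:, t + 1:] = torch.randn_like(x2[:, t + 1:])
    return not torch.allclose(fn(x)[:, : t + 1], fn(x2)[:, : t + 1], atol=1e-6)

torch.manual_seed(0)
x = torch.randn(2, 6, 8)
leak_with_mask = leaks_future(lambda z: attend(z, causal=True), x, t=2)
leak_without_mask = leaks_future(lambda z: attend(z, causal=False), x, t=2)
print(leak_with_mask, leak_without_mask)   # False True
```

The same check can be pointed at the full model's logits to confirm the mask fix end to end.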
### Part 2: Add and validate a KV cache
A `KVCache` class for attention keys and values is provided. Extend the model to support **cached autoregressive decoding** so each new step does **not** recompute keys/values for the whole prefix. Update the implementation so that:
1. The attention module **reads from and appends to** the cache correctly (use the cached `K`/`V` for the new query, then store the new `K`/`V`).
2. **Causal masking is only applied when there is no cache or the cache is empty** — a single new token attending to an all-past cache needs no full causal mask.
3. **Positional encodings use the correct offset** based on the current cache length (the new token is at position `cache_len`, not `0`).
4. The **`generate` function performs incremental decoding** with the cache (pre-fill the prompt, then feed one token at a time).
Finally, verify that **deterministic generation with and without the KV cache produces the same token sequence**.
```hint Cache append & shapes
The cache holds already-computed keys/values per layer — shapes like `k_cache, v_cache: [B, H, T_past, D_head]`. Each step, compute $q,k,v$ for only the **new** token(s) (`q: [B, H, T_new, D_head]`), append the new $k,v$ along the time axis to get `[B, H, T_past + T_new, D_head]`, then attend the new query against the **full** concatenated K/V. Keep one cache per attention layer.
```
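The append-then-attend step might look like the sketch below. `SimpleKVCache` is a hypothetical stand-in, since the API of the provided `KVCache` class is not shown; adapt the method names to whatever it actually exposes.

```python
import torch

class SimpleKVCache:
    """Hypothetical stand-in for the provided KVCache: one instance per attention layer."""
    def __init__(self):
        self.k = None    # [B, H, T_past, D_head] once populated
        self.v = None

    def __len__(self):
        return 0 if self.k is None else self.k.size(2)

    def append(self, k_new, v_new):
        """Append [B, H, T_new, D_head] along the time axis; return the full K, V."""
        if self.k is None:
            self.k, self.v = k_new, v_new
        else:
            self.k = torch.cat([self.k, k_new], dim=2)
            self.v = torch.cat([self.v, v_new], dim=2)
        return self.k, self.v

B, H, D = 2, 4, 8
cache = SimpleKVCache()
k, v = cache.append(torch.randn(B, H, 5, D), torch.randn(B, H, 5, D))   # pre-fill: 5 prompt tokens
k, v = cache.append(torch.randn(B, H, 1, D), torch.randn(B, H, 1, D))   # one decode step

q = torch.randn(B, H, 1, D)                                             # query for the new token only
out = torch.softmax(q @ k.transpose(-2, -1) / D ** 0.5, dim=-1) @ v     # attends over all 6 keys
print(len(cache), tuple(k.shape), tuple(out.shape))   # 6 (2, 4, 6, 8) (2, 4, 1, 8)
```

Concatenating on every step is the simplest form; preallocating `[B, H, max_seq_len, D_head]` and writing into a slice is a common alternative that avoids repeated copies.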
```hint Masking & position offset
When you decode one token against a cache of past keys, the query sits at the last position, so no key lies in its future and a full causal mask is unnecessary for the single-new-token case. You still need the causal mask on the pre-fill pass (cache empty / multiple tokens at once), and if you ever feed more than one new token per step you need a causal mask shifted by `cache_len`, of shape `[T_new, T_past + T_new]`. For positions, use `positions = arange(T_new) + cache_len`; forgetting the offset reuses early position indices and silently diverges from the uncached path.
```
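Both rules in this hint reduce to comparing absolute positions. The sketch below uses hypothetical helpers `positions_for` and `blocked_mask`; the single-token and pre-fill cases fall out as special cases of one rectangular mask.

```python
import torch

def positions_for(T_new, cache_len):
    """Absolute positions of the new tokens: cache_len, ..., cache_len + T_new - 1."""
    return torch.arange(T_new) + cache_len

def blocked_mask(T_new, cache_len):
    """Boolean [T_new, cache_len + T_new] mask, True where attention is NOT allowed.
    Query i sits at absolute position cache_len + i and may see keys 0..cache_len + i."""
    q_pos = positions_for(T_new, cache_len).unsqueeze(1)     # [T_new, 1]
    k_pos = torch.arange(cache_len + T_new).unsqueeze(0)     # [1, cache_len + T_new]
    return k_pos > q_pos

print(positions_for(1, cache_len=5))    # tensor([5])
print(blocked_mask(1, cache_len=5))     # one row, all False: nothing to mask
print(blocked_mask(3, cache_len=0))     # ordinary 3x3 causal mask (pre-fill)
print(blocked_mask(2, cache_len=3))     # two new tokens on top of three cached ones
```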
```hint Proving equivalence
Turn dropout off (`model.eval()`), decode greedily (argmax) with the same prompt and seed on both paths, and compare token IDs one-by-one. They must match exactly. If they diverge at step $k$, inspect the position indices and the mask at exactly that step first, then the order of cache appends.
```
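An end-to-end equivalence check can be sketched on a toy model. `TinyDecoder` below is an illustrative one-layer, single-head decoder (no MLP, no LayerNorm) with a plain dict as the cache; none of this is the exercise's actual code. It applies the position-shifted mask on every call, which is all-False for a single new token and therefore equivalent to skipping the mask in that case.

```python
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Illustrative one-layer, single-head decoder (an assumption, not the starter model)."""
    def __init__(self, vocab=20, d=16, max_len=32):
        super().__init__()
        self.tok = nn.Embedding(vocab, d)
        self.pos = nn.Embedding(max_len, d)
        self.qkv = nn.Linear(d, 3 * d)
        self.head = nn.Linear(d, vocab)
        self.d = d

    def forward(self, idx, cache=None):
        # cache: None (uncached) or a dict holding "k"/"v" of shape [B, T_past, d]
        T_new = idx.size(1)
        past = cache["k"].size(1) if cache else 0
        pos = torch.arange(T_new) + past                       # position offset
        x = self.tok(idx) + self.pos(pos)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        if cache:                                              # non-empty cache: prepend past K/V
            k = torch.cat([cache["k"], k], dim=1)
            v = torch.cat([cache["v"], v], dim=1)
        if cache is not None:                                  # persist for the next step
            cache["k"], cache["v"] = k, v
        scores = q @ k.transpose(-2, -1) / self.d ** 0.5       # [B, T_new, past + T_new]
        blocked = torch.arange(past + T_new)[None, :] > pos[:, None]
        att = torch.softmax(scores.masked_fill(blocked, float("-inf")), dim=-1)
        return self.head(x + att @ v)                          # logits: [B, T_new, vocab]

@torch.no_grad()
def generate(model, prompt, n_new, use_cache):
    model.eval()                                               # dropout off (none here, but good habit)
    out = prompt.clone()
    cache = {} if use_cache else None
    step_input = out                                           # first call pre-fills the whole prompt
    for _ in range(n_new):
        logits = model(step_input, cache)
        nxt = logits[:, -1].argmax(dim=-1, keepdim=True)       # greedy decoding
        out = torch.cat([out, nxt], dim=1)
        step_input = nxt if use_cache else out                 # cached path feeds one token
    return out

torch.manual_seed(0)
model = TinyDecoder()
prompt = torch.randint(0, 20, (2, 5))
uncached = generate(model, prompt, 10, use_cache=False)
cached = generate(model, prompt, 10, use_cache=True)
print(torch.equal(uncached, cached))
```

Logits on the two paths can differ in the last few floating-point bits, so compare token IDs exactly and logits with a tolerance such as `torch.allclose`.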
#### Clarifying Questions for this Part
- Does the existing `forward` need to optionally accept and return per-layer cache state, or should I subclass/wrap the model to thread it through?
- Should `generate` pre-fill the prompt in a single forward (then loop one token), or feed the prompt token-by-token to build the cache?
#### What This Part Should Cover
- A correct, layer-wise cache: new K/V appended in time order, query attends against the full prefix, cache returned/persisted across steps.
- Right masking discipline: no full causal mask for single-token-with-cache, but a causal mask retained for multi-token pre-fill.
- Correct position offset (`+ cache_len`) and a clear explanation of why omitting it breaks equivalence.
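- An explicit equivalence test (eval mode + greedy + fixed seed) and an articulated complexity argument: without the cache, step $t$ recomputes attention over the whole prefix at $O(t^2)$ cost, which is $O(T^3)$ over a run of $T$ tokens; with the cache, each new token costs $O(t)$, which is $O(T^2)$ over the run.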
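The complexity argument can be checked by counting query-key dot products over a run. This is a back-of-the-envelope sketch for one layer and one head; it counts the full score matrix, ignores the head dimension, and the function name is illustrative.

```python
def attention_scores_computed(prompt_len, n_new, use_cache):
    """Count query-key dot products across a whole generation run."""
    total = 0
    for step in range(n_new):
        ctx = prompt_len + step            # tokens visible when predicting the next one
        if use_cache and step > 0:
            total += ctx                   # one new query against ctx keys
        else:
            total += ctx * ctx             # ctx queries against ctx keys (pre-fill, or no cache)
    return total

print(attention_scores_computed(8, 4, use_cache=False))   # 64 + 81 + 100 + 121 = 366
print(attention_scores_computed(8, 4, use_cache=True))    # 64 + 9 + 10 + 11 = 94
```

The uncached total grows cubically in the sequence length, while the cached total grows quadratically.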
### What a Strong Answer Covers
These dimensions span both parts:
- Treats the KV cache as an **inference-only optimization that must not change outputs** — never letting it alter training or numerical results.
- Uses **reproducible validation** throughout: overfit-a-batch for Part 1, exact token-sequence equality for Part 2, with dropout off and decoding held deterministic.
- Communicates the *why* behind each fix and each cache decision, not just the code edit.
### Follow-up Questions
- What is the time and memory complexity of generating $T$ tokens with vs. without the KV cache? Estimate cache memory for $L$ layers, $H$ heads, head dim $D$, sequence length $T$, batch size $B$, and discuss when memory (not compute) becomes the bottleneck for long contexts.
- How would you extend the cache to **batched decoding** with prompts of different lengths (left-padding, ragged caches, per-sequence position offsets), and what breaks if the per-sequence offsets are wrong?
- Sketch how this changes for **rotary position embeddings (RoPE)** — where does the offset get applied, and does the cache store pre- or post-rotation keys?
- How do **multi-query attention (MQA)** or **grouped-query attention (GQA)** reduce the KV-cache footprint, and what is the quality trade-off?
- If equivalence fails by *one* token deep into a long generation, how would you localize the cause — what would you log, and at which layer/step?
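For the memory follow-ups, a rough estimate multiplies out the cache shape. The sketch below assumes 2-byte (fp16/bf16) elements and uses a hypothetical configuration; `n_kv_heads` equals $H$ for standard multi-head attention and is smaller for GQA or MQA.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch, bytes_per_elem=2):
    """Keys and values: 2 tensors of [batch, n_kv_heads, seq_len, head_dim] per layer."""
    return 2 * n_layers * batch * n_kv_heads * seq_len * head_dim * bytes_per_elem

# Hypothetical config: 32 layers, head_dim 128, 4096-token context, batch 1, fp16
mha = kv_cache_bytes(32, 32, 128, seq_len=4096, batch=1)   # 32 KV heads (multi-head)
gqa = kv_cache_bytes(32, 8, 128, seq_len=4096, batch=1)    # 8 KV heads (grouped-query)
mqa = kv_cache_bytes(32, 1, 128, seq_len=4096, batch=1)    # 1 KV head (multi-query)
print(mha / 2**30, gqa / 2**30, mqa / 2**30)               # 2.0 0.5 0.0625 (GiB)
```

The footprint scales linearly in $L$, $B$, $T$, and the number of KV heads, which is why long contexts and large batches make cache memory the limiting factor before compute.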
Quick Answer: This question evaluates debugging and implementation skills for transformer-based autoregressive language models, focusing on attention mechanics, positional embeddings, causal masking, and integrating a key-value (KV) cache for incremental decoding.