
Debug Transformer and Add KV Cache

Last updated: Jun 21, 2026

Quick Overview

This question evaluates debugging and implementation skills for transformer-based autoregressive language models, focusing on attention mechanics, positional embeddings, causal masking, and integrating a key-value (KV) cache for incremental decoding.

Company: OpenAI

Role: Machine Learning Engineer

Category: Machine Learning

Difficulty: medium

Interview Round: Onsite

Hints

Hint (Part 1): Where to start

Don't read top-to-bottom hunting for the typo. First confirm the model can overfit a single tiny batch: if it can't drive loss toward zero on a handful of tokens, the bug is structural (loss/mask/embedding), not the typo. Then bisect: print tensor shapes and a few values at each stage (embeddings → attention scores → logits → loss).

Hint (Part 1): Bugs 1-3

Next-token LM means position $t$ predicts token $t+1$: align `logits[:, :-1]` with `targets[:, 1:]` (or `input=x[:, :-1]`, `target=x[:, 1:]`). A loss that scores a position against itself lets the model cheat with an identity map. The causal mask must be strictly upper-triangular (entries with $j>i$ set to a large negative value or $-\infty$ before softmax, not $0$) and must broadcast over batch and heads. For positional embeddings, check the shape `[max_seq_len, d_model]`, that they are trainable, and that the init magnitude is small; a huge init drowns the token embeddings.

Hint (Part 1): The typo bug

Silent typos pass the wrong-but-valid tensor or skip a transform without crashing: a residual adding the sublayer input instead of its output, reusing `q` where `k`/`v` was meant, a dropped LayerNorm, or scaling by $1/\sqrt{d_{model}}$ instead of $1/\sqrt{d_{head}}$. After fixing bugs 1-3, if the tiny-batch overfit test still fails or is suspiciously slow, that residual failure points you at the typo.

Hint (Part 2): Cache append & shapes

The cache holds already-computed keys/values per layer, with shapes like `k_cache, v_cache: [B, H, T_past, D_head]`. Each step, compute $q,k,v$ for only the new token(s) (`q: [B, H, T_new, D_head]`), append the new $k,v$ along the time axis to get `[B, H, T_past + T_new, D_head]`, then attend the new query against the full concatenated K/V. Keep one cache per attention layer.

Hint (Part 2): Masking & position offset

When you decode one token against a cache of only past keys, the query cannot see a future position, so a full causal mask is unnecessary for the single-new-token case. You still need the causal mask on the pre-fill pass (cache empty / multiple tokens at once), and a local mask if you ever feed more than one new token per step. For positions, use `positions = arange(T_new) + cache_len`; forgetting the offset reuses early position indices and silently diverges from the uncached path.

Hint (Part 2): Proving equivalence

Turn dropout off (`model.eval()`), decode greedily (argmax) with the same prompt and seed on both paths, and compare token IDs one-by-one. They must match exactly. If they diverge at step $k$, inspect the position indices and the mask at exactly that step first, then the order of cache appends.

Feb 1, 2026, 12:00 AM

You are given a small decoder-only transformer (GPT-style) implemented in PyTorch for autoregressive (next-token) language modeling. The starter code includes token/positional embeddings, a stack of masked self-attention blocks, an LM head, a training loop, and a naive `generate` function. A separate `KVCache` class for storing attention keys and values is also provided.

This is a two-part hands-on coding exercise: first make training correct, then add an inference-time KV cache and prove it is numerically equivalent to the uncached path.

Constraints & Assumptions

  • Decoder-only (GPT-style) architecture: causal multi-head self-attention + MLP blocks with residual connections and LayerNorm, learned (not sinusoidal) token and positional embeddings, and a final vocabulary head.
  • Training objective is next-token prediction (causal LM) with cross-entropy loss.
  • You may run small training/eval loops to inspect behavior; assume CPU or a single GPU and a tiny toy dataset (e.g. character-level) so you can iterate in seconds.
  • The four Part 1 bugs are real correctness bugs in the given code — not architectural choices. After fixing, the model must be able to overfit a tiny batch.
  • For Part 2, the goal is an optimization that does not change outputs: with dropout disabled and deterministic (greedy) decoding, cached and uncached generation must emit the identical token sequence.

Clarifying Questions to Ask

  • Does the model's `forward` return logits for the full input length `[B, T, V]`, and is the cross-entropy loss expected to be computed inside or outside `forward`?
  • Are the positional embeddings a learned `nn.Embedding`/`nn.Parameter` of shape `[max_seq_len, d_model]`, and what is `max_seq_len`?
  • Is the attention mask additive (added to scores pre-softmax) or boolean, and what is its expected shape/broadcast convention (`[B, H, T, T]`)?
  • Does `generate` decode greedily or with sampling, and is there a fixed seed I should hold constant when comparing the cached vs. uncached paths?
  • Is the `KVCache` per-layer (a list of caches) or a single shared object, and what is its exact append/read API?

Part 1: Debug the training code

The training code contains four bugs. Fix them so training loss decreases and generated text becomes plausible. The bug categories are:

  1. Label shift — the LM loss uses the wrong target alignment (input/target slicing).
  2. Positional embedding initialization — the positional embeddings are initialized incorrectly.
  3. Attention mask — the causal mask is wrong.
  4. A typo-related bug — one additional silent bug somewhere in the code (e.g. a misspelled field, a swapped tensor, an off-by-one, or a dropped transform).

After your fixes, the model should train normally, loss should go down, and sampled outputs should look reasonable.
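The label-shift fix in item 1 can be sketched as follows. This is a minimal illustration, not the starter code: the helper name `next_token_loss` and the toy tensors are assumptions, and it presumes `forward` returns logits of shape `[B, T, V]` with the loss computed outside the model.

```python
import torch
import torch.nn.functional as F

def next_token_loss(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Cross-entropy where position t is scored against token t+1.

    logits: [B, T, V] outputs for the full input; tokens: [B, T] input ids.
    (Hypothetical helper; the starter code may slice inputs/targets elsewhere.)
    """
    shifted_logits = logits[:, :-1, :]   # predictions made at positions 0..T-2
    shifted_targets = tokens[:, 1:]      # the tokens that actually follow them
    return F.cross_entropy(
        shifted_logits.reshape(-1, shifted_logits.size(-1)),
        shifted_targets.reshape(-1),
    )

# Toy check: logits that predict the NEXT token give near-zero loss,
# while logits that merely echo the CURRENT token are penalized.
V = 4
tokens = torch.tensor([[0, 1, 2, 3]])
next_ids = torch.roll(tokens, shifts=-1, dims=1)        # [[1, 2, 3, 0]]
good_logits = 20.0 * F.one_hot(next_ids, V).float()
echo_logits = 20.0 * F.one_hot(tokens, V).float()
loss_good = next_token_loss(good_logits, tokens).item()
loss_echo = next_token_loss(echo_logits, tokens).item()
print(f"next-token logits: {loss_good:.4f}, echo logits: {loss_echo:.4f}")
```

With an unshifted loss the ranking of these two cases flips, which is the identity-map shortcut the bug allows.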

Clarifying Questions for this Part

  • Are dropout and weight tying (between the input embedding and the LM head) part of the intended design, so I don't "fix" something that is actually correct?
  • Should the attention scores be scaled by $1/\sqrt{d_{head}}$ in the given code, or is that handled elsewhere?
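As a reference point for the mask and scaling questions, here is a small sketch of causal scaled dot-product attention. It assumes a boolean mask applied with `masked_fill` and scaling by the head dimension; the function name `causal_attention` is hypothetical, and the starter code may use an additive mask instead.

```python
import math
import torch

def causal_attention(q, k, v):
    """Scaled dot-product attention with a causal mask.

    q, k, v: [B, H, T, D_head]. Returns (output, attention weights).
    Assumption: boolean mask + masked_fill; an additive -inf mask is equivalent.
    """
    T, d_head = q.size(-2), q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_head)      # [B, H, T, T]
    idx = torch.arange(T, device=q.device)
    future = idx[None, :] > idx[:, None]                      # True where key j > query i
    scores = scores.masked_fill(future, float("-inf"))        # broadcasts over B and H
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights

torch.manual_seed(0)
q, k, v = (torch.randn(2, 3, 5, 8) for _ in range(3))
out, weights = causal_attention(q, k, v)

# Leakage check: perturbing the LAST token's key/value must leave
# the outputs at all earlier positions unchanged.
k2, v2 = k.clone(), v.clone()
k2[:, :, -1] += 1.0
v2[:, :, -1] += 1.0
out2, _ = causal_attention(q, k2, v2)
leak = (out[:, :, :-1] - out2[:, :, :-1]).abs().max().item()
print(f"max change at earlier positions: {leak:.2e}")
```

The same perturbation test can be pointed at the given attention module to detect a mask that leaks future tokens.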

What This Part Should Cover

  • Correctly identifies and fixes all four bugs, and explains why each is wrong (not just patches it): the identity-mapping failure from a missing shift, future-information leakage from a bad mask, training instability from bad positional init.
  • A disciplined debugging method — overfit-a-tiny-batch as the oracle, shape/value printing, bisection — rather than guess-and-check.
  • A correctness signal after fixing: loss drops on a toy dataset and generated text stops being random noise.
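The overfit-a-tiny-batch oracle mentioned above might look like the sketch below. The function `overfit_one_batch` is an illustrative helper, and the bigram stand-in model is an assumption used only to make the example runnable; in the interview the same loop would be pointed at the given transformer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def overfit_one_batch(model: nn.Module, tokens: torch.Tensor,
                      steps: int = 200, lr: float = 0.05):
    """Train repeatedly on one fixed batch and return the loss at every step."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    model.train()
    for _ in range(steps):
        logits = model(tokens)                                # assumed shape [B, T, V]
        loss = F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)),      # position t ...
            tokens[:, 1:].reshape(-1),                        # ... predicts token t+1
        )
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses

torch.manual_seed(0)
vocab_size = 8
# Hypothetical stand-in (a bigram model), NOT the interview's transformer.
stand_in = nn.Sequential(nn.Embedding(vocab_size, 32), nn.Linear(32, vocab_size))
batch = torch.arange(vocab_size).unsqueeze(0)                 # [[0, 1, ..., 7]]
losses = overfit_one_batch(stand_in, batch)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

If the real model cannot push this loss toward zero on a handful of tokens, the remaining bug is likely structural rather than a hyperparameter issue.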

Part 2: Add and validate a KV cache

A `KVCache` class for attention keys and values is provided. Extend the model to support cached autoregressive decoding so each new step does not recompute keys/values for the whole prefix. Update the implementation so that:

  1. The attention module reads from and appends to the cache correctly: the new token's `K`/`V` are appended to the cache, and the new query attends over the cached `K`/`V` together with its own.
  2. Causal masking is only applied when there is no cache or the cache is empty; a single new token attending to an all-past cache needs no full causal mask.
  3. Positional encodings use the correct offset based on the current cache length (the new token is at position `cache_len`, not `0`).
  4. The `generate` function performs incremental decoding with the cache (pre-fill the prompt, then feed one token at a time).

Finally, verify that deterministic generation with and without the KV cache produces the same token sequence.
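The cache read/append and masking rules above can be sketched for a single attention layer as follows. `ToyKVCache` and `CachedSelfAttention` are hypothetical stand-ins, since the API of the provided `KVCache` is not shown; this version masks whenever more than one new token is fed, which covers the pre-fill case.

```python
import math
import torch
import torch.nn as nn

class ToyKVCache:
    """Hypothetical stand-in for the provided KVCache; the real API may differ."""
    def __init__(self):
        self.k = None   # [B, H, T_past, D_head] once filled
        self.v = None

    def __len__(self):
        return 0 if self.k is None else self.k.size(2)

    def append(self, k_new, v_new):
        self.k = k_new if self.k is None else torch.cat([self.k, k_new], dim=2)
        self.v = v_new if self.v is None else torch.cat([self.v, v_new], dim=2)
        return self.k, self.v

class CachedSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, cache=None):
        B, T_new, C = x.shape
        qkv = self.qkv(x).view(B, T_new, 3, self.n_heads, self.d_head)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)               # each [B, H, T_new, D_head]
        T_past = len(cache) if cache is not None else 0
        if cache is not None:
            k, v = cache.append(k, v)                      # [B, H, T_past + T_new, D_head]
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)
        if T_new > 1:
            # Query i sits at absolute position T_past + i; hide keys after it.
            q_pos = torch.arange(T_new, device=x.device).unsqueeze(1) + T_past
            k_pos = torch.arange(T_past + T_new, device=x.device).unsqueeze(0)
            scores = scores.masked_fill(k_pos > q_pos, float("-inf"))
        out = torch.softmax(scores, dim=-1) @ v            # [B, H, T_new, D_head]
        return self.proj(out.transpose(1, 2).reshape(B, T_new, C))

torch.manual_seed(0)
attn = CachedSelfAttention(d_model=16, n_heads=4).eval()
x = torch.randn(1, 6, 16)
with torch.no_grad():
    full = attn(x)                                         # uncached, causal over 6 tokens
    cache = ToyKVCache()
    prefill = attn(x[:, :4], cache)                        # 4-token prompt in one pass
    steps = [attn(x[:, t:t + 1], cache) for t in range(4, 6)]
    incremental = torch.cat([prefill] + steps, dim=1)
max_diff = (full - incremental).abs().max().item()
print(f"max |full - incremental| = {max_diff:.2e}")
```

In a full model there would be one such cache per layer, and the comparison would be made on generated token IDs rather than on one layer's activations.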

Clarifying Questions for this Part

  • Does the existing `forward` need to optionally accept and return per-layer cache state, or should I subclass/wrap the model to thread it through?
  • Should `generate` pre-fill the prompt in a single forward (then loop one token), or feed the prompt token-by-token to build the cache?

What This Part Should Cover

  • A correct, layer-wise cache: new K/V appended in time order, query attends against the full prefix, cache returned/persisted across steps.
  • Right masking discipline: no full causal mask for single-token-with-cache, but a causal mask retained for multi-token pre-fill.
  • Correct position offset (`+ cache_len`) and a clear explanation of why omitting it breaks equivalence.
  • An explicit equivalence test (eval mode + greedy + fixed seed) and an articulated complexity argument: without the cache, step $t$ recomputes attention over the whole prefix at $O(t^2)$ cost, or $O(T^3)$ over a run of $T$ tokens; with the cache, each new token costs $O(t)$, or $O(T^2)$ over the run.
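To make the position-offset point concrete, the sketch below compares learned positional embeddings looked up with and without the offset. The helper `positions_for` and the sizes are assumptions for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
max_seq_len, d_model = 16, 8                     # illustrative sizes, not from the exercise
pos_emb = nn.Embedding(max_seq_len, d_model)     # learned positions, as in the exercise

def positions_for(t_new: int, cache_len: int) -> torch.Tensor:
    """Absolute position ids for t_new tokens appended after cache_len cached ones."""
    if cache_len + t_new > max_seq_len:
        raise ValueError("sequence would exceed max_seq_len")
    return torch.arange(t_new) + cache_len

with torch.no_grad():
    full = pos_emb(positions_for(6, 0))          # uncached path: positions 0..5
    step = pos_emb(positions_for(1, 5))          # cached path: sixth token, cache holds 5
    wrong = pos_emb(positions_for(1, 0))         # bug: offset forgotten, position 0 reused

print(torch.equal(step[0], full[5]), torch.equal(wrong[0], full[5]))
```

Without the offset, every decoded token receives the embedding of position 0, so the cached path sees different inputs from the uncached one and the two token sequences can drift apart.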

What a Strong Answer Covers

These dimensions span both parts:

  • Treats the KV cache as an inference-only optimization that must not change outputs — never letting it alter training or numerical results.
  • Uses reproducible validation throughout: overfit-a-batch for Part 1, exact token-sequence equality for Part 2, with dropout off and decoding held deterministic.
  • Communicates the why behind each fix and each cache decision, not just the code edit.
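For the exact token-sequence comparison, a small helper that reports the first mismatching step can narrow down where to inspect positions and masks. The function name `first_divergence` and the token IDs are made up for illustration.

```python
from typing import Optional, Sequence

def first_divergence(uncached: Sequence[int], cached: Sequence[int]) -> Optional[int]:
    """Return the first index where two token sequences differ, or None if identical."""
    for step, (a, b) in enumerate(zip(uncached, cached)):
        if a != b:
            return step
    if len(uncached) != len(cached):
        return min(len(uncached), len(cached))   # one sequence stopped early
    return None

# Made-up token ids purely for illustration.
reference = [5, 9, 2, 7, 7, 1]
matching = [5, 9, 2, 7, 7, 1]
diverging = [5, 9, 2, 4, 0, 3]

print(first_divergence(reference, matching))     # None
print(first_divergence(reference, diverging))    # 3
```

Everything after the first divergent step is conditioned on a different prefix, so only that step is worth debugging.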

Follow-up Questions

  • What is the time and memory complexity of generating $T$ tokens with vs. without the KV cache? Estimate cache memory for $L$ layers, $H$ heads, head dim $D$, sequence length $T$, batch size $B$, and discuss when memory (not compute) becomes the bottleneck for long contexts.
  • How would you extend the cache to batched decoding with prompts of different lengths (left-padding, ragged caches, per-sequence position offsets), and what breaks if the per-sequence offsets are wrong?
  • Sketch how this changes for rotary position embeddings (RoPE) — where does the offset get applied, and does the cache store pre- or post-rotation keys?
  • How do multi-query attention (MQA) or grouped-query attention (GQA) reduce the KV-cache footprint, and what is the quality trade-off?
  • If equivalence fails by one token deep into a long generation, how would you localize the cause — what would you log, and at which layer/step?
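For the memory follow-up, a back-of-the-envelope estimate counts two tensors (K and V) of shape `[B, H, T, D]` per layer. The sketch below assumes 2-byte (fp16/bf16) elements, and the configuration numbers are hypothetical, chosen only to make the arithmetic easy to follow.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, d_head: int,
                   seq_len: int, batch_size: int, bytes_per_elem: int = 2) -> int:
    """Estimate KV-cache size: 2 tensors (K and V) of shape [B, H, T, D] per layer."""
    return 2 * n_layers * batch_size * n_kv_heads * seq_len * d_head * bytes_per_elem

# Hypothetical configuration, fp16 elements.
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, d_head=128, seq_len=4096, batch_size=1)
# Same model with grouped-query attention sharing 8 KV heads across the query heads.
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8, d_head=128, seq_len=4096, batch_size=1)

print(f"MHA cache: {mha / 2**30:.2f} GiB")   # 2.00 GiB
print(f"GQA cache: {gqa / 2**30:.2f} GiB")   # 0.50 GiB
```

The estimate is linear in both $T$ and $B$, which is why long contexts and large batches tend to hit memory limits before compute limits; MQA/GQA shrink it by the ratio of query heads to KV heads.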
