PracHub

Debug and fix a PyTorch Transformer training loop

Last updated: Jun 21, 2026

Quick Overview

This question evaluates a candidate's skills in implementing and debugging a Transformer-based causal language model in PyTorch, including model architecture correctness, attention masking, numerical stability, training-loop hygiene, unit tests, and analysis of attention time/memory complexity.

Company: OpenAI

Role: Data Scientist

Category: Machine Learning

Difficulty: hard

Interview Round: Onsite

You’re given a minimal causal language model in PyTorch that “trains”, but the loss never improves and sometimes becomes NaN. Identify every bug and provide a corrected implementation. For each bug, explain: (a) the symptom it causes, (b) the root cause, and (c) the minimal code fix. Then: (1) write 3 unit tests that would have caught these issues before training, (2) derive the time and memory complexity of scaled dot-product attention in terms of batch B, heads H, sequence length S, and head dimension d_k, and (3) propose two changes that reduce memory without degrading perplexity too much, explaining trade-offs. The faulty code is listed in full under "Faulty Code" below.

Deliverables:

  • (A) A bullet list of at least 12 distinct defects across architecture, masking, numerics, and the training loop.
  • (B) Corrected, runnable code that trains to decreasing loss on random next-token data (use a proper causal mask, shift targets, correct tensor shapes/dtypes/devices, and avoid redundant softmax).
  • (C) Three PyTest-style unit tests (e.g., shape/contiguity checks, mask correctness, gradient non-None assertions) that fail on the buggy code and pass on your fix.
  • (D) Complexity derivation for attention and two memory-reduction strategies (e.g., FlashAttention, sequence chunking, low-rank KV cache) with pros/cons.

Oct 13, 2025, 9:49 PM

Minimal Causal LM Debugging and Optimization

You are given a tiny causal decoder-only language model implemented in PyTorch. It appears to "train" but the loss does not improve and sometimes becomes NaN. The provided code contains multiple issues spanning model architecture, attention masking, numerics, and the training loop.

Your task is to identify all defects, explain their impact, and provide a corrected, runnable implementation that demonstrates a decreasing loss on a simple next-token prediction task. You will also write unit tests that would have caught the issues early and analyze the time/memory complexity of scaled dot-product attention, plus memory-reduction strategies.

This is a debugging-under-time-pressure exercise (think 60 minutes). It tests your knowledge of Transformer architecture, your ability to trace a symptom back to its root cause, and your attention to detail. Treat the inline # BUG: comments as hints, not as an exhaustive list — read the code critically and find issues for yourself.

Faulty Code

"""
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TinyDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=64, n_heads=4):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.tok = nn.Embedding(vocab_size, d_model)      # BUG: no padding_idx set
        self.pos = nn.Parameter(torch.zeros(d_model))     # BUG: single vector for all positions
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)
        self.drop = nn.Dropout(0.2)

    def forward(self, x, attn_mask=None):
        B, S = x.shape
        h = self.tok(x) + self.pos                        # BUG: broadcasts same pos to every step
        qkv = self.qkv(h)
        q, k, v = qkv.chunk(3, dim=-1)
        H = self.n_heads
        q = q.view(B, S, H, -1)
        k = k.view(B, S, H, -1)
        v = v.view(B, S, H, -1)
        attn = torch.matmul(q, k.transpose(-2, -1))       # BUG: missing 1/sqrt(d_k) scale
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == 0, -1e9) # BUG: mask shape/dtype likely wrong for heads
        w = F.softmax(attn, dim=0)                        # BUG: softmax on dim 0, should be last dim
        z = torch.matmul(w, v).view(B, S, -1)
        h2 = self.proj(z)
        h3 = self.ln(h + self.drop(h2))                   # residual then LN is ok but training mode matters
        return self.out(h3).softmax(-1)                   # BUG: softmax before CE loss


def train():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = TinyDecoder(vocab_size=100).to('cpu')         # BUG: device mismatch
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model.eval()                                          # BUG: wrong mode; disables dropout, affects LN
    for step in range(200):
        x = torch.randint(1, 100, (32, 16), device=device)
        y = x                                             # BUG: targets not next-token shifted
        logits = model(x)                                 # BUG: no causal mask passed
        loss = F.cross_entropy(F.log_softmax(logits, 2),  # BUG: applying log_softmax before CE
                               y.float())                 # BUG: targets are float; CE expects Long
        with torch.no_grad():
            loss.backward()                               # BUG: backward under no_grad prevents grads
        opt.step()                                        # BUG: missing optimizer.zero_grad()
        if step % 50 == 0:
            print(step, loss.item())

train()
"""

Constraints & Assumptions

  • Framework: Python 3 + PyTorch 2.x (torch, torch.nn, torch.nn.functional). You may assume a recent PyTorch is available.
  • Model scale (toy): vocab_size = 100, d_model = 64, n_heads = 4, batch B = 32, sequence length S = 16. Keep the architecture decoder-only and small enough to run on CPU.
  • Objective: standard autoregressive next-token prediction with cross-entropy loss.
  • Determinism: assume torch.manual_seed(0) is set so a fixed run is reproducible.
  • Goal of "working": the corrected loop must show a clearly decreasing loss curve on a task that actually contains learnable signal; it should run on CPU in seconds, with no NaN / inf.
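
The "learnable signal" constraint matters because tokens drawn i.i.d. from torch.randint(1, 100, ...) carry no next-token information: even a bug-free model cannot beat the entropy of a uniform distribution over 99 symbols. As a minimal sketch, one possible toy task is a counting sequence in which each token is the previous token plus one, modulo the vocabulary size. The helper name make_batch and the task itself are assumptions for illustration, not part of the original question.

```python
import math
import torch

def make_batch(batch_size=32, seq_len=16, vocab_size=100, device="cpu"):
    """Hypothetical toy task: each token is the previous token + 1 (mod vocab_size)."""
    start = torch.randint(0, vocab_size, (batch_size, 1), device=device)
    offsets = torch.arange(seq_len + 1, device=device)   # (seq_len + 1,)
    seq = (start + offsets) % vocab_size                 # (batch, seq_len + 1)
    x, y = seq[:, :-1], seq[:, 1:]                       # inputs and next-token targets
    return x, y

torch.manual_seed(0)
x, y = make_batch()
print(x.shape, y.shape, y.dtype)

# Best achievable cross-entropy on uniform i.i.d. tokens over 99 symbols,
# i.e. the floor the original random data would impose:
print(round(math.log(99), 3))   # 4.595
```

Because the targets are a deterministic function of the inputs, the achievable loss on this task approaches zero, so a flat curve points to a remaining bug rather than to the data.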

Clarifying Questions to Ask

  • Should I treat the inline # BUG: comments as the complete, authoritative list, or are there additional defects that are not annotated?
  • Is the goal a minimal fix per bug (preserving the existing structure), or am I free to refactor the module for clarity?
  • Is it acceptable to invent a small deterministic toy task to demonstrate a falling loss, or must I train on the original i.i.d. random data?
  • Are CPU-only results acceptable for the demonstration, or do you want it to run on GPU?
  • For the complexity analysis, do you want just the attention block, or the full forward pass including the linear / FFN projections?
  • How rigorous should the unit tests be — illustrative assertions, or genuinely runnable PyTest functions?

Part A — Defect inventory (find at least 12)

Produce a bullet list of at least 12 distinct defects, grouped by subsystem (model architecture / forward math, attention masking & numerics, and the training loop). For each defect give: (a) the symptom it causes, (b) the root cause, and (c) the minimal code fix.

What This Part Should Cover

  • Coverage & grouping: finds comfortably past the minimum count and organizes defects by subsystem rather than as a flat dump.
  • Symptom → root cause → fix discipline: for each defect, ties the observed behavior to the mechanism and gives a minimal, correct change.
  • The non-crashing bugs: catches the silent ones (wrong attention axis, wrong softmax dimension, double softmax) — not just the ones that throw.
  • Two failure classes: separates bugs that block learning from bugs that corrupt numerics, and recognizes both must be fixed.
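
The silent bugs are easiest to see by checking shapes and normalization directly. The sketch below uses random tensors with small, assumed dimensions (not the model itself) to show what the faulty head layout and the dim=0 softmax actually compute.

```python
import torch

torch.manual_seed(0)
B, S, H, d_k = 2, 5, 4, 16
q = torch.randn(B, S, H, d_k)
k = torch.randn(B, S, H, d_k)

# Layout used in the faulty forward(): heads sit on axis 2, so matmul contracts
# d_k and leaves an (H, H) matrix per position. Heads attend to heads, not
# positions to positions.
wrong = torch.matmul(q, k.transpose(-2, -1))
print(wrong.shape)   # torch.Size([2, 5, 4, 4])

# Moving heads next to the batch axis gives the intended (S, S) score matrix.
qh, kh = q.transpose(1, 2), k.transpose(1, 2)        # (B, H, S, d_k)
right = torch.matmul(qh, kh.transpose(-2, -1))
print(right.shape)   # torch.Size([2, 4, 5, 5])

# softmax over dim=0 normalizes across the batch, so each query row is no
# longer a probability distribution over keys.
w_bad = torch.softmax(right, dim=0)
w_good = torch.softmax(right, dim=-1)
rows_ok_good = torch.allclose(w_good.sum(-1), torch.ones(B, H, S))
rows_ok_bad = torch.allclose(w_bad.sum(-1), torch.ones(B, H, S))
print(rows_ok_good, rows_ok_bad)
```

Neither mistake raises an exception, which is why both survive a quick smoke test and only show up as a loss that refuses to move.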

Part B — Corrected, runnable implementation

Provide corrected, runnable code that trains to a decreasing loss on a simple next-token task. It must use a proper causal mask, correct next-token shifting, correct tensor shapes / dtypes / devices, and no redundant softmax. Show (or describe) a loss curve that clearly goes down.

What This Part Should Cover

  • Correct attention mechanics: head layout (heads on the batch axis), $1/\sqrt{d_k}$ scaling, softmax axis, causal masking, and shape bookkeeping all correct in the rewrite.
  • A learnable demonstration: understands that a flat loss can mean either a bug or an unlearnable task, and constructs data that actually proves the fixes took.
  • End-to-end wiring: correct device placement, train() mode, zero_grad, next-token shift, and raw-logit cross-entropy with long targets.
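
One possible shape for the rewrite is sketched below. It is a sketch under stated assumptions rather than a reference solution: the class name TinyDecoderFixed, the max_len parameter, the learned positional embedding, the dropout rate, and the counting task (next token = current token + 1, modulo the vocabulary) are all choices made for illustration.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyDecoderFixed(nn.Module):
    """Illustrative corrected layout; names and max_len are assumptions."""
    def __init__(self, vocab_size, d_model=64, n_heads=4, max_len=64, p_drop=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads, self.d_k = n_heads, d_model // n_heads
        self.tok = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)        # one vector per position
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)
        self.drop = nn.Dropout(p_drop)

    def forward(self, x):
        B, S = x.shape
        h = self.tok(x) + self.pos(torch.arange(S, device=x.device))
        q, k, v = self.qkv(h).chunk(3, dim=-1)
        # (B, S, D) -> (B, H, S, d_k): heads sit next to the batch axis
        q, k, v = (t.view(B, S, self.n_heads, self.d_k).transpose(1, 2)
                   for t in (q, k, v))
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)   # (B, H, S, S)
        causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=x.device))
        scores = scores.masked_fill(~causal, float("-inf"))
        w = F.softmax(scores, dim=-1)                    # normalize over keys
        z = (w @ v).transpose(1, 2).reshape(B, S, -1)    # merge heads back to (B, S, D)
        h = self.ln(h + self.drop(self.proj(z)))
        return self.out(h)                               # raw logits, no softmax

def train(steps=300, vocab_size=100, seq_len=16, batch_size=32):
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = TinyDecoderFixed(vocab_size).to(device)      # model and data share a device
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model.train()
    losses = []
    for _ in range(steps):
        # Assumed toy task: counting sequences, so the next token is predictable.
        start = torch.randint(0, vocab_size, (batch_size, 1), device=device)
        seq = (start + torch.arange(seq_len + 1, device=device)) % vocab_size
        x, y = seq[:, :-1], seq[:, 1:]                   # next-token shift
        logits = model(x)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses

losses = train()
print(f"first={losses[0]:.3f} last={losses[-1]:.3f}")
```

The causal mask is built inside forward() so the caller cannot forget it, and every query row keeps at least its own position unmasked, which is what makes filling with -inf safe here.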

Part C — Unit tests

Write three PyTest-style unit tests that would fail on the buggy code and pass on your fix. Good targets include: attention shape / contiguity, softmax-axis and mask/causality correctness, raw-logits-vs-probabilities, and gradient-non-None (finite) assertions.

What This Part Should Cover

  • Invariant-driven tests: each test asserts a defining property of correctness (rows sum to 1, future mass is zero, grads are finite) rather than a hard-coded expected value.
  • Discriminating power: the tests genuinely fail on the buggy code and pass on the fix — the candidate can say which bug each test traps.
  • Determinism awareness: controls dropout / module mode so probability-sum assertions are reproducible.
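
A sketch of what invariant-driven tests could look like follows. To stay self-contained, the tests target a small hypothetical helper, causal_attention, that operates on (B, H, S, d_k) tensors; in a real answer the same assertions would be pointed at the model's attention weights and parameters. No dropout is involved, so the probability sums are deterministic.

```python
import math
import torch
import torch.nn.functional as F

def causal_attention(q, k, v):
    """Hypothetical helper under test. q, k, v: (B, H, S, d_k). Returns (output, weights)."""
    S, d_k = q.shape[-2], q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    keep = torch.tril(torch.ones(S, S, dtype=torch.bool, device=q.device))
    w = F.softmax(scores.masked_fill(~keep, float("-inf")), dim=-1)
    return w @ v, w

def _qkv(B=2, H=4, S=6, d_k=8, requires_grad=False):
    torch.manual_seed(0)
    return tuple(torch.randn(B, H, S, d_k, requires_grad=requires_grad)
                 for _ in range(3))

def test_rows_sum_to_one():
    # Traps a softmax over the wrong axis and a wrong head layout:
    # weights must be (B, H, S, S) and every query row must be a distribution.
    q, k, v = _qkv()
    _, w = causal_attention(q, k, v)
    assert w.shape == (2, 4, 6, 6)
    assert torch.allclose(w.sum(-1), torch.ones(2, 4, 6), atol=1e-6)

def test_no_attention_to_future():
    # Traps a missing or mis-shaped causal mask: the strictly upper
    # triangle of every weight matrix must carry zero mass.
    q, k, v = _qkv()
    _, w = causal_attention(q, k, v)
    future = torch.triu(torch.ones(6, 6, dtype=torch.bool), diagonal=1)
    assert (w * future).abs().max().item() == 0.0

def test_gradients_exist_and_are_finite():
    # Traps a detached graph and NaN/inf-producing numerics.
    q, k, v = _qkv(requires_grad=True)
    out, _ = causal_attention(q, k, v)
    out.sum().backward()
    for t in (q, k, v):
        assert t.grad is not None
        assert torch.isfinite(t.grad).all()
```

Each test names the invariant it protects, which makes it straightforward to state which defect it would have caught.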

Part D — Complexity analysis & memory reduction

Derive the time and memory complexity of scaled dot-product attention in terms of batch $B$, heads $H$, sequence length $S$, and head dimension $d_k$. Then propose two changes that reduce memory without degrading perplexity too much, explaining the trade-offs of each.
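
As a rough worked example of the counts involved: for each batch item and head, the product of Q and the transpose of K multiplies an (S, d_k) matrix by a (d_k, S) matrix, and the weighted sum multiplies an (S, S) matrix by an (S, d_k) matrix. The sketch below tallies these for the toy configuration and for a sequence 100 times longer. The function name attention_cost, the convention of counting a multiply-add as 2 FLOPs, and the 4-byte element size are assumptions.

```python
def attention_cost(B, H, S, d_k, bytes_per_elem=4):
    """Rough counts for one naive scaled dot-product attention forward pass."""
    flops_qk = 2 * B * H * S * S * d_k      # Q @ K^T, multiply-add counted as 2 FLOPs
    flops_av = 2 * B * H * S * S * d_k      # softmax(scores) @ V
    score_elems = B * H * S * S             # one (S, S) matrix per batch item and head
    qkvo_elems = 4 * B * H * S * d_k        # Q, K, V and the attention output
    return {
        "matmul_flops": flops_qk + flops_av,
        "score_bytes": score_elems * bytes_per_elem,
        "qkvo_bytes": qkvo_elems * bytes_per_elem,
    }

toy = attention_cost(B=32, H=4, S=16, d_k=16)
long = attention_cost(B=32, H=4, S=1600, d_k=16)
print(toy)
print(long["score_bytes"] / toy["score_bytes"])   # 10000.0: quadratic in S
print(long["qkvo_bytes"] / toy["qkvo_bytes"])     # 100.0: linear in S
```

At the toy scale the score tensor is smaller than Q, K, V and the output combined; the quadratic term only dominates once S grows past roughly 4 d_k in this accounting.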

What This Part Should Cover

  • Correct asymptotics: identifies the $S \times S$ score tensor as the dominant term and states the $\Theta(B\,H\,S^2\,d_k)$ time / $\Theta(B\,H\,S^2)$ memory scaling.
  • Targeted reductions: each proposed change is aimed at a named cost, with the memory win separated from any quality cost, and is correctly labelled exact vs. approximate.
  • Trade-off fluency: credible discussion of where each technique applies (training vs. inference / decoding) and its limitations.
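
As an example of an exact reduction, PyTorch 2.x exposes torch.nn.functional.scaled_dot_product_attention, which computes the same function as the naive formulation and, on supported hardware and dtypes, may dispatch to a fused FlashAttention-style kernel that avoids materializing the full score tensor. The sketch below only checks numerical agreement on CPU with small, assumed dimensions; it does not measure memory.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, d_k = 2, 4, 16, 16
q, k, v = (torch.randn(B, H, S, d_k) for _ in range(3))

# Naive path: materializes the full (B, H, S, S) score tensor.
scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
keep = torch.tril(torch.ones(S, S, dtype=torch.bool))
naive = F.softmax(scores.masked_fill(~keep, float("-inf")), dim=-1) @ v

# Fused path: same math, default scale 1/sqrt(d_k), causal mask applied internally.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

max_diff = (naive - fused).abs().max().item()
print(max_diff)
```

Because the result matches up to floating-point rounding, perplexity is not expected to change; the trade-off is reduced flexibility (custom score modifications are harder to express) and dependence on backend support. Approximate methods, by contrast, trade some quality for memory and need to be validated against perplexity.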

What a Strong Answer Covers

These dimensions span all four parts:

  • Two-failure-class framing throughout: consistently distinguishes learning-blocking bugs from numerics-corrupting bugs when reasoning across Parts A–C.
  • Prioritization: surfaces the single highest-impact bug first (the one that makes the loss flat regardless of everything else) rather than treating all defects as equal.
  • Internal consistency: the corrected code (Part B), the tests (Part C), and the complexity claims (Part D) all agree on the same attention layout and shapes.
  • Communication under time pressure: explains reasoning clearly and concisely, naming the mechanism behind each fix rather than only the patch.

Follow-up Questions

  • Which single bug, fixed alone, would move the loss the most — and which would change nothing on its own?
  • How would you adapt this model and training loop to handle variable-length, padded sequences in a batch (mask construction, loss masking, padding_idx)?
  • At 100x the sequence length, what breaks first, and what is the first change you would make?
  • How would your answer to Part D change for inference / decoding (KV cache) versus training?
  • How would you detect and guard against NaN / inf in a long training run (mixed precision, loss scaling, gradient clipping)?
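
For the last question, one common pattern is to check the loss and the gradient norm for finiteness before each update and to clip the gradient norm. The sketch below shows only that part; it does not cover mixed precision or loss scaling. The helper name guarded_step and the stand-in model are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def guarded_step(model, opt, x, y, max_grad_norm=1.0):
    """Hypothetical helper: one update that is skipped on a non-finite loss or gradient."""
    logits = model(x)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
    if not torch.isfinite(loss):
        return None                      # skip the update; the caller can log and continue
    opt.zero_grad()
    loss.backward()
    # clip_grad_norm_ returns the total gradient norm measured before clipping
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    if not torch.isfinite(total_norm):
        opt.zero_grad()                  # discard the bad gradients
        return None
    opt.step()
    return loss.item()

torch.manual_seed(0)
model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 10))   # stand-in model
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
x = torch.randint(0, 10, (4, 5))
y = torch.randint(0, 10, (4, 5))
result = guarded_step(model, opt, x, y)
print(result is not None)
```

Skipping a step hides the symptom rather than the cause, so in practice the skipped steps would be counted and logged, and a run with frequent skips would be investigated.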