
Debug and fix a PyTorch Transformer training loop

Last updated: May 13, 2026

Quick Overview

This question evaluates a candidate's skills in implementing and debugging a Transformer-based causal language model in PyTorch, including model architecture correctness, attention masking, numerical stability, training-loop hygiene, unit tests, and analysis of attention time/memory complexity.

  • hard
  • OpenAI
  • Machine Learning
  • Data Scientist

Debug and fix a PyTorch Transformer training loop

Company: OpenAI

Role: Data Scientist

Category: Machine Learning

Difficulty: hard

Interview Round: Onsite

You’re given a minimal causal language model in PyTorch that “trains” but the loss never improves and sometimes becomes NaN. Identify every bug and provide a corrected implementation. For each bug, explain: (a) the symptom it causes, (b) the root cause, and (c) the minimal code fix. Then: (1) write 3 unit tests that would have caught these issues before training, (2) derive the time and memory complexity of scaled dot-product attention in terms of batch B, heads H, sequence length S, and head dimension d_k, and (3) propose two changes that reduce memory without degrading perplexity too much, explaining trade-offs.

Faulty code:

"""
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TinyDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=64, n_heads=4):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.tok = nn.Embedding(vocab_size, d_model)   # BUG: no padding_idx set
        self.pos = nn.Parameter(torch.zeros(d_model))  # BUG: single vector used for all positions
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)
        self.drop = nn.Dropout(0.2)

    def forward(self, x, attn_mask=None):
        B, S = x.shape
        h = self.tok(x) + self.pos                     # BUG: broadcasts same pos to every time step
        qkv = self.qkv(h)
        q, k, v = qkv.chunk(3, dim=-1)
        H = self.n_heads
        q = q.view(B, S, H, -1)
        k = k.view(B, S, H, -1)
        v = v.view(B, S, H, -1)
        attn = torch.matmul(q, k.transpose(-2, -1))    # BUG: missing 1/sqrt(d_k) scale
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == 0, -1e9)  # BUG: mask shape/dtype likely wrong for heads
        w = F.softmax(attn, dim=0)                     # BUG: softmax on dim 0, should be last dim
        z = torch.matmul(w, v).view(B, S, -1)
        h2 = self.proj(z)
        h3 = self.ln(h + self.drop(h2))                # residual then LN is ok but training mode matters
        return self.out(h3).softmax(-1)                # BUG: softmax before CE loss

def train():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = TinyDecoder(vocab_size=100).to('cpu')      # BUG: device mismatch
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model.eval()                                       # BUG: wrong mode; disables dropout, affects LN stats
    for step in range(200):
        x = torch.randint(1, 100, (32, 16), device=device)
        y = x                                          # BUG: targets not next-token shifted
        logits = model(x)                              # BUG: no causal mask passed
        loss = F.cross_entropy(F.log_softmax(logits, 2),  # BUG: applying log_softmax before CE
                               y.float())                 # BUG: targets are float; CE expects Long
        with torch.no_grad():
            loss.backward()                            # BUG: backward under no_grad prevents grads
        opt.step()                                     # BUG: missing optimizer.zero_grad()
        if step % 50 == 0:
            print(step, loss.item())

train()
"""

Deliverables:

  • (A) A bullet list of at least 12 distinct defects across architecture, masking, numerics, and the training loop.
  • (B) Corrected, runnable code that trains to decreasing loss on random next-token data (use a proper causal mask, shift targets, correct tensor shapes/dtypes/devices, and avoid redundant softmax); a hint sketch for the target shift follows this list.
  • (C) Three PyTest-style unit tests (e.g., shape/contiguity checks, mask correctness, gradient non-None assertions) that fail on the buggy code and pass on your fix.
  • (D) Complexity derivation for attention and two memory-reduction strategies (e.g., FlashAttention, sequence chunking, low-rank KV cache) with pros/cons.
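As a hint for the target-shifting defect flagged in the training loop, one common correction is to condition on tokens up to position t and predict the token at t+1. This is a minimal sketch only; the slicing convention is one of several valid choices, and `model` is assumed to return raw logits:

# x: (B, S) batch of token ids
inputs = x[:, :-1]                     # (B, S-1) tokens the model conditions on
targets = x[:, 1:]                     # (B, S-1) next-token labels, dtype torch.long
logits = model(inputs)                 # (B, S-1, vocab_size) raw logits, no softmax
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))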


Related Interview Questions

  • Implement Backprop for a Tiny Network - OpenAI (hard)
  • Filter Bad Human Annotations - OpenAI (medium)
  • Compute Matrix Prefix Products And Gradients - OpenAI (hard)
  • Improve Training With Noisy Annotators - OpenAI (hard)
  • Debug a Broken Transformer - OpenAI (medium)
OpenAI · Data Scientist · Onsite · Machine Learning · Posted Oct 13, 2025, 9:49 PM

Minimal Causal LM Debugging and Optimization

Context

You are given a tiny causal decoder-only language model implemented in PyTorch. It appears to "train" but the loss does not improve and sometimes becomes NaN. The provided code contains multiple issues spanning model architecture, attention masking, numerics, and the training loop.

Your task is to identify all defects, explain their impact, and provide a corrected, runnable implementation that demonstrates decreasing loss on a simple next-token prediction task. You should also write unit tests that would have caught the issues early, analyze the time and memory complexity of scaled dot-product attention, and propose memory-reduction strategies.
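For concreteness, the causal mask a corrected implementation needs can be built once per sequence length and broadcast across batch and heads. The sketch below is illustrative; the helper name `causal_mask` is not part of the original code:

import torch

def causal_mask(S, device=None):
    # Boolean (S, S) lower-triangular mask: position i may attend to positions j <= i.
    return torch.tril(torch.ones(S, S, dtype=torch.bool, device=device))

# Typical use on attention scores of shape (B, H, S, S):
#   scores = scores.masked_fill(~causal_mask(S, scores.device), float('-inf'))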

Faulty Code

"""
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TinyDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=64, n_heads=4):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.tok = nn.Embedding(vocab_size, d_model)   # BUG: no padding_idx set
        self.pos = nn.Parameter(torch.zeros(d_model))  # BUG: single vector used for all positions
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)
        self.drop = nn.Dropout(0.2)

    def forward(self, x, attn_mask=None):
        B, S = x.shape
        h = self.tok(x) + self.pos                     # BUG: broadcasts same pos to every time step
        qkv = self.qkv(h)
        q, k, v = qkv.chunk(3, dim=-1)
        H = self.n_heads
        q = q.view(B, S, H, -1)
        k = k.view(B, S, H, -1)
        v = v.view(B, S, H, -1)
        attn = torch.matmul(q, k.transpose(-2, -1))    # BUG: missing 1/sqrt(d_k) scale
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == 0, -1e9)  # BUG: mask shape/dtype likely wrong for heads
        w = F.softmax(attn, dim=0)                     # BUG: softmax on dim 0, should be last dim
        z = torch.matmul(w, v).view(B, S, -1)
        h2 = self.proj(z)
        h3 = self.ln(h + self.drop(h2))                # residual then LN is ok but training mode matters
        return self.out(h3).softmax(-1)                # BUG: softmax before CE loss

def train():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = TinyDecoder(vocab_size=100).to('cpu')      # BUG: device mismatch
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model.eval()                                       # BUG: wrong mode; disables dropout, affects LN stats
    for step in range(200):
        x = torch.randint(1, 100, (32, 16), device=device)
        y = x                                          # BUG: targets not next-token shifted
        logits = model(x)                              # BUG: no causal mask passed
        loss = F.cross_entropy(F.log_softmax(logits, 2),  # BUG: applying log_softmax before CE
                               y.float())                 # BUG: targets are float; CE expects Long
        with torch.no_grad():
            loss.backward()                            # BUG: backward under no_grad prevents grads
        opt.step()                                     # BUG: missing optimizer.zero_grad()
        if step % 50 == 0:
            print(step, loss.item())

train()
"""
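As a reference point for deliverable (B), the attention core might be corrected along the following lines. This is a sketch, not the full fix: the method name `attention` and the `causal_mask` argument are illustrative additions, and it addresses only the head reshape, the missing 1/sqrt(d_k) scale, the masking, and the softmax dimension:

def attention(self, h, causal_mask):
    # h: (B, S, d_model); causal_mask: (S, S) boolean, True where attention is allowed.
    B, S, _ = h.shape
    H, d_k = self.n_heads, self.d_model // self.n_heads
    q, k, v = self.qkv(h).chunk(3, dim=-1)
    # Reshape to (B, H, S, d_k) so the matmul contracts over d_k, not over heads.
    q = q.view(B, S, H, d_k).transpose(1, 2)
    k = k.view(B, S, H, d_k).transpose(1, 2)
    v = v.view(B, S, H, d_k).transpose(1, 2)
    scores = (q @ k.transpose(-2, -1)) / (d_k ** 0.5)         # (B, H, S, S)
    scores = scores.masked_fill(~causal_mask, float('-inf'))  # block future positions
    w = torch.softmax(scores, dim=-1)                         # softmax over keys (last dim)
    z = (w @ v).transpose(1, 2).contiguous().view(B, S, -1)   # back to (B, S, d_model)
    return self.proj(z)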

Deliverables

  • (A) A bullet list of at least 12 distinct defects with: (a) symptom, (b) root cause, (c) minimal code fix.
  • (B) Corrected, runnable code that trains to decreasing loss on a simple next-token task, using a proper causal mask, correct shifting, shapes/dtypes/devices, and no redundant softmax.
  • (C) Three PyTest-style unit tests that would fail on the buggy code and pass on your fix (a sketch of such tests follows this list).
  • (D) Derive the time and memory complexity of scaled dot-product attention in terms of batch B, heads H, sequence length S, and head dimension d_k, and propose two memory-reduction changes with trade-offs (a reference derivation follows below).
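For deliverable (C), tests along the following lines target the redundant softmax, the broken attention mixing and masking, and gradient flow. This is a sketch: it assumes the corrected `TinyDecoder` builds its causal mask internally and returns raw logits.

import torch
import torch.nn.functional as F

def test_returns_raw_logits():
    model = TinyDecoder(vocab_size=100)
    x = torch.randint(0, 100, (2, 8))
    logits = model(x)
    assert logits.shape == (2, 8, 100)
    # Raw logits should not already sum to 1 over the vocabulary.
    assert not torch.allclose(logits.sum(-1), torch.ones(2, 8), atol=1e-3)

def test_causal_information_flow():
    model = TinyDecoder(vocab_size=100).eval()
    x = torch.randint(0, 100, (1, 8))
    later, earlier = x.clone(), x.clone()
    later[0, -1] = (later[0, -1] + 1) % 100     # perturb only the last token
    earlier[0, 0] = (earlier[0, 0] + 1) % 100   # perturb only the first token
    with torch.no_grad():
        base, out_later, out_earlier = model(x), model(later), model(earlier)
    # Changing the last token must not affect predictions at earlier positions ...
    assert torch.allclose(base[0, :-1], out_later[0, :-1], atol=1e-5)
    # ... but changing the first token should influence later positions.
    assert not torch.allclose(base[0, 1:], out_earlier[0, 1:], atol=1e-5)

def test_gradients_flow():
    model = TinyDecoder(vocab_size=100)
    x = torch.randint(0, 100, (4, 8))
    logits = model(x[:, :-1])
    loss = F.cross_entropy(logits.reshape(-1, 100), x[:, 1:].reshape(-1))
    loss.backward()
    assert all(p.grad is not None for p in model.parameters())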
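For deliverable (D), the usual accounting (sketched here for reference) counts the two batched matrix products: QK^T multiplies an (S × d_k) matrix by a (d_k × S) matrix per batch element and head, the product with V costs the same order, and the softmax weights form a B × H × S × S tensor that dominates memory for long sequences:

% Scores QK^T: (S x d_k)(d_k x S) per head -> S^2 d_k multiply-adds; same order for the product with V.
\text{Time} \;=\; O\!\left(B\,H\,S^{2}\,d_k\right),
\qquad
% The S x S attention matrix per (batch, head) dominates; Q, K, V add O(B H S d_k).
\text{Memory} \;=\; O\!\left(B\,H\,S^{2} \;+\; B\,H\,S\,d_k\right).

Two exact memory reductions worth discussing: fused streaming kernels in the FlashAttention family, exposed in PyTorch 2.0+ via torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True), which avoid materializing the S × S matrix at the cost of recomputing it in the backward pass; and chunking the query dimension so only a (chunk × S) slice of scores is live at a time, which is simple and exact but adds loop overhead. Both leave perplexity unchanged, whereas approximate schemes such as low-rank KV compression save more memory but can cost quality.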

