Debug and fix a PyTorch Transformer training loop
Company: OpenAI
Role: Data Scientist
Category: Machine Learning
Difficulty: hard
Interview Round: Onsite
You’re given a minimal causal language model in PyTorch that “trains” but the loss never improves and sometimes becomes NaN. Identify every bug and provide a corrected implementation. For each bug, explain: (a) the symptom it causes, (b) the root cause, and (c) the minimal code fix. Then: (1) write three unit tests that would have caught these issues before training, (2) derive the time and memory complexity of scaled dot-product attention in terms of batch size B, number of heads H, sequence length S, and head dimension d_k, and (3) propose two changes that reduce memory without materially degrading perplexity, explaining the trade-offs. Faulty code:
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
class TinyDecoder(nn.Module):
def __init__(self, vocab_size, d_model=64, n_heads=4):
super().__init__()
self.vocab_size = vocab_size
self.d_model = d_model
self.n_heads = n_heads
self.tok = nn.Embedding(vocab_size, d_model) # BUG: no padding_idx set
self.pos = nn.Parameter(torch.zeros(d_model)) # BUG: single vector used for all positions
self.qkv = nn.Linear(d_model, 3 * d_model)
self.proj = nn.Linear(d_model, d_model)
self.ln = nn.LayerNorm(d_model)
self.out = nn.Linear(d_model, vocab_size)
self.drop = nn.Dropout(0.2)
def forward(self, x, attn_mask=None):
B, S = x.shape
h = self.tok(x) + self.pos # BUG: broadcasts same pos to every time step
qkv = self.qkv(h)
q, k, v = qkv.chunk(3, dim=-1)
H = self.n_heads
q = q.view(B, S, H, -1)
k = k.view(B, S, H, -1)
v = v.view(B, S, H, -1)
attn = torch.matmul(q, k.transpose(-2, -1)) # BUG: missing 1/sqrt(d_k) scale
if attn_mask is not None:
attn = attn.masked_fill(attn_mask == 0, -1e9) # BUG: mask shape/dtype likely wrong for heads
w = F.softmax(attn, dim=0) # BUG: softmax on dim 0, should be last dim
z = torch.matmul(w, v).view(B, S, -1)
h2 = self.proj(z)
h3 = self.ln(h + self.drop(h2)) # residual then LN is ok but training mode matters
return self.out(h3).softmax(-1) # BUG: softmax before CE loss
def train():
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = TinyDecoder(vocab_size=100).to('cpu') # BUG: device mismatch
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
model.eval() # BUG: wrong mode; disables dropout, affects LN stats
for step in range(200):
x = torch.randint(1, 100, (32, 16), device=device)
y = x # BUG: targets not next-token shifted
logits = model(x) # BUG: no causal mask passed
loss = F.cross_entropy(F.log_softmax(logits, 2), # BUG: applying log_softmax before CE
y.float()) # BUG: targets are float; CE expects Long
with torch.no_grad():
loss.backward() # BUG: backward under no_grad prevents grads
opt.step() # BUG: missing optimizer.zero_grad()
if step % 50 == 0:
print(step, loss.item())
train()
"""
Deliverables: (A) A bullet list of at least 12 distinct defects across architecture, masking, numerics, and the training loop; (B) corrected, runnable code that trains to decreasing loss on random next-token data (use a proper causal mask, shift targets, correct tensor shapes/dtypes/devices, and avoid redundant softmax); (C) three PyTest-style unit tests (e.g., shape/contiguity checks, mask correctness, gradient non-None assertions) that fail on the buggy code and pass on your fix; (D) complexity derivation for attention and two memory-reduction strategies (e.g., FlashAttention, sequence chunking, low-rank KV cache) with pros/cons.
Quick Answer: This question evaluates whether a candidate can implement and debug a Transformer-based causal language model in PyTorch end to end: model-architecture correctness, attention masking, numerical stability, training-loop hygiene, unit testing, and analysis of attention's time/memory complexity.
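Answer sketch for (A), grounded in the # BUG annotations above plus one defect they do not flag (the missing head transpose):
- self.pos is a single d_model vector broadcast to every time step, so the model carries no positional information at all.
- q/k/v are viewed as (B, S, H, d_k) but never transposed to (B, H, S, d_k); the matmul then builds an H-by-H head-mixing matrix at each position, so tokens never attend to other tokens (unflagged in the code).
- The attention logits are missing the 1/sqrt(d_k) scale, which saturates softmax and destabilizes gradients as d_k grows.
- softmax is taken over dim=0 (the batch dimension) rather than the key dimension (dim=-1), so the weights are not a distribution over positions.
- No causal mask is ever constructed or passed, and the masked_fill branch assumes a mask shape/dtype that may not broadcast across heads.
- forward returns softmax probabilities, the loop applies log_softmax, and F.cross_entropy applies log-softmax again internally; the triple normalization flattens gradients.
- Targets are y = x rather than the next-token shift, so the objective degenerates to copying the input.
- y.float(): with class-index targets, cross_entropy expects Long; float targets are interpreted as class probabilities and here would also have the wrong shape.
- Device mismatch: the model is pinned to CPU while batches are created on device, crashing whenever CUDA is available.
- model.eval() during training disables dropout (LayerNorm, unlike BatchNorm, keeps no running statistics, so dropout is the real casualty).
- loss.backward() and opt.step() are wrapped in torch.no_grad(); strictly, no_grad gates graph recording for ops run inside it rather than backward of a graph built outside, but the pattern is fragile (any forward computation moved into the block silently stops requiring grad) and should go.
- opt.zero_grad() is never called, so gradients accumulate across steps.
- padding_idx is never set on the embedding; latent with this random data, but any real padded batch would train on pad positions.
- The -1e9 fill value overflows half precision (a robustness nit); float('-inf') or torch.finfo(dtype).min is safer.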
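Answer sketch for (B). One possible corrected implementation (a sketch, not the only valid fix; max_len and pad_idx are added assumptions, and F.scaled_dot_product_attention would be an equally reasonable replacement for the hand-rolled attention):
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TinyDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=64, n_heads=4, max_len=512, pad_idx=0):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.tok = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)  # fix: padding_idx set
        self.pos = nn.Parameter(torch.zeros(max_len, d_model))  # fix: one vector per position
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)
        self.drop = nn.Dropout(0.2)

    def forward(self, x):
        B, S = x.shape
        h = self.tok(x) + self.pos[:S]  # fix: position-dependent embeddings
        q, k, v = self.qkv(h).chunk(3, dim=-1)
        # fix: (B, H, S, d_k) layout so attention mixes positions, not heads
        q = q.view(B, S, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(B, S, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(B, S, self.n_heads, self.d_k).transpose(1, 2)
        attn = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)  # fix: 1/sqrt(d_k) scale
        causal = torch.triu(torch.ones(S, S, dtype=torch.bool, device=x.device), diagonal=1)
        attn = attn.masked_fill(causal, float('-inf'))  # fix: boolean causal mask, broadcasts over B and H
        w = F.softmax(attn, dim=-1)  # fix: softmax over keys (last dim)
        z = (w @ v).transpose(1, 2).contiguous().view(B, S, -1)
        h = self.ln(h + self.drop(self.proj(z)))
        return self.out(h)  # fix: raw logits; cross_entropy applies log-softmax itself

def train():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = TinyDecoder(vocab_size=100).to(device)  # fix: model and data on the same device
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model.train()  # fix: enable dropout during training
    for step in range(200):
        x = torch.randint(1, 100, (32, 16), device=device)
        logits = model(x[:, :-1])  # fix: predict token t+1 from the prefix up to t
        y = x[:, 1:]  # fix: next-token-shifted Long targets
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        opt.zero_grad()  # fix: clear stale gradients each step
        loss.backward()  # fix: no torch.no_grad() around backward/step
        opt.step()
        if step % 50 == 0:
            print(step, loss.item())

if __name__ == '__main__':
    train()
"""
Note that on uniform random tokens 1..99 the best achievable loss is the data entropy, ln 99 ≈ 4.60, so expect the loss to fall from its initial value toward that floor rather than toward zero.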
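Answer sketch for (C). Three PyTest-style tests, assuming TinyDecoder refers to the corrected class above and is importable in the test module; the first two fail on the buggy model as written, and the third pins down the gradient-flow contract the fixed loop must satisfy:
"""
import pytest
import torch
import torch.nn.functional as F

@pytest.fixture
def model():
    torch.manual_seed(0)
    m = TinyDecoder(vocab_size=100)  # the corrected class from the sketch above
    m.eval()  # dropout off so outputs are deterministic
    return m

def test_returns_logits_not_probabilities(model):
    x = torch.randint(1, 100, (4, 8))
    out = model(x)
    assert out.shape == (4, 8, 100)
    sums = out.sum(dim=-1)
    # the buggy model returns softmax output, whose rows sum to exactly 1
    assert not torch.allclose(sums, torch.ones_like(sums), atol=1e-4)

def test_causal_mask_uses_past_not_future(model):
    x = torch.randint(1, 100, (1, 8))
    past = x.clone(); past[0, 0] = (past[0, 0] % 99) + 1  # perturb an early token
    fut = x.clone(); fut[0, -1] = (fut[0, -1] % 99) + 1   # perturb the last token
    with torch.no_grad():
        base, p, f = model(x), model(past), model(fut)
    # later positions must see earlier tokens (fails if heads, not positions, attend)
    assert not torch.allclose(base[0, -1], p[0, -1])
    # but no position may see a future token (fails without a causal mask)
    assert torch.allclose(base[0, :-1], f[0, :-1], atol=1e-5)

def test_gradients_flow(model):
    model.train()
    x = torch.randint(1, 100, (2, 8))
    logits = model(x[:, :-1])
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), x[:, 1:].reshape(-1))
    loss.backward()
    grads = [p.grad for p in model.parameters() if p.requires_grad]
    assert all(g is not None for g in grads)
    assert any(g.abs().sum() > 0 for g in grads)
"""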
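Answer sketch for (D). Per (batch, head) pair, scores = QK^T multiplies an (S × d_k) matrix by a (d_k × S) matrix, costing Θ(S² · d_k) multiply-adds; the softmax touches each of the S² scores once, Θ(S²); and the weighted sum softmax(scores) · V multiplies (S × S) by (S × d_k) for another Θ(S² · d_k). Summed over all B · H pairs, time is Θ(B · H · S² · d_k). Peak memory is the materialized attention matrix, Θ(B · H · S²), on top of Θ(B · S · H · d_k) = Θ(B · S · d_model) for the Q, K, V activations; since H · d_k = d_model is fixed, the S² term dominates for long sequences. Two memory reductions that keep perplexity intact: (1) FlashAttention-style fused kernels tile the computation and never write the S² matrix to high-bandwidth memory, recomputing tiles during the backward pass; attention stays exact (zero perplexity cost) at the price of extra recomputation FLOPs and dependence on kernel support for your hardware and dtype. (2) Gradient checkpointing over layers or sequence chunks drops intermediate activations and recomputes them in backward, cutting peak activation memory roughly in proportion to the chunking; again exact, but each step pays roughly one extra forward pass. More aggressive options such as a low-rank or quantized KV cache save further memory at inference, but can measurably hurt perplexity if compressed too far.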