This question evaluates a candidate's skills in implementing and debugging a Transformer-based causal language model in PyTorch, including model architecture correctness, attention masking, numerical stability, training-loop hygiene, unit tests, and analysis of attention time/memory complexity.
You’re given a minimal causal language model in PyTorch that “trains” but the loss never improves and sometimes becomes NaN. Identify every bug and provide a corrected implementation. For each bug, explain: (a) the symptom it causes, (b) the root cause, and (c) the minimal code fix. Then: (1) write 3 unit tests that would have caught these issues before training, (2) derive the time and memory complexity of scaled dot-product attention in terms of batch B, heads H, sequence length S, and head dimension d_k, and (3) propose two changes that reduce memory without degrading perplexity too much, explaining trade-offs. Faulty code:
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TinyDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=64, n_heads=4):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.tok = nn.Embedding(vocab_size, d_model)  # BUG: no padding_idx set
        self.pos = nn.Parameter(torch.zeros(d_model))  # BUG: single vector used for all positions
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)
        self.drop = nn.Dropout(0.2)

    def forward(self, x, attn_mask=None):
        B, S = x.shape
        h = self.tok(x) + self.pos  # BUG: broadcasts same pos to every time step
        qkv = self.qkv(h)
        q, k, v = qkv.chunk(3, dim=-1)
        H = self.n_heads
        q = q.view(B, S, H, -1)
        k = k.view(B, S, H, -1)
        v = v.view(B, S, H, -1)
        attn = torch.matmul(q, k.transpose(-2, -1))  # BUG: missing 1/sqrt(d_k) scale
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == 0, -1e9)  # BUG: mask shape/dtype likely wrong for heads
        w = F.softmax(attn, dim=0)  # BUG: softmax on dim 0, should be last dim
        z = torch.matmul(w, v).view(B, S, -1)
        h2 = self.proj(z)
        h3 = self.ln(h + self.drop(h2))  # residual then LN is ok but training mode matters
        return self.out(h3).softmax(-1)  # BUG: softmax before CE loss

def train():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = TinyDecoder(vocab_size=100).to('cpu')  # BUG: device mismatch
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model.eval()  # BUG: wrong mode; disables dropout, affects LN stats
    for step in range(200):
        x = torch.randint(1, 100, (32, 16), device=device)
        y = x  # BUG: targets not next-token shifted
        logits = model(x)  # BUG: no causal mask passed
        loss = F.cross_entropy(F.log_softmax(logits, 2),  # BUG: applying log_softmax before CE
                               y.float())  # BUG: targets are float; CE expects Long
        with torch.no_grad():
            loss.backward()  # BUG: backward under no_grad prevents grads
        opt.step()  # BUG: missing optimizer.zero_grad()
        if step % 50 == 0:
            print(step, loss.item())

train()
"""
Deliverables: (A) A bullet list of at least 12 distinct defects across architecture, masking, numerics, and the training loop; (B) corrected, runnable code that trains to decreasing loss on random next-token data (use a proper causal mask, shift targets, correct tensor shapes/dtypes/devices, and avoid redundant softmax); (C) three PyTest-style unit tests (e.g., shape/contiguity checks, mask correctness, gradient non-None assertions) that fail on the buggy code and pass on your fix; (D) complexity derivation for attention and two memory-reduction strategies (e.g., FlashAttention, sequence chunking, low-rank KV cache) with pros/cons.
You are given a tiny causal decoder-only language model implemented in PyTorch. It appears to "train" but the loss does not improve and sometimes becomes NaN. The provided code contains multiple issues spanning model architecture, attention masking, numerics, and the training loop.
Your task is to identify all defects, explain their impact, and provide a corrected, runnable implementation that demonstrates a decreasing loss on a simple next-token prediction task. You will also write unit tests that would have caught the issues early and analyze the time/memory complexity of scaled dot-product attention, plus memory-reduction strategies.
This is a debugging-under-time-pressure exercise (think 60 minutes). It tests your knowledge of Transformer architecture, your ability to trace a symptom back to its root cause, and your attention to detail. Treat the inline # BUG: comments as hints, not as an exhaustive list — read the code critically and find issues for yourself.
Faulty Code
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TinyDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=64, n_heads=4):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.tok = nn.Embedding(vocab_size, d_model)  # BUG: no padding_idx set
        self.pos = nn.Parameter(torch.zeros(d_model))  # BUG: single vector for all positions
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)
        self.drop = nn.Dropout(0.2)

    def forward(self, x, attn_mask=None):
        B, S = x.shape
        h = self.tok(x) + self.pos  # BUG: broadcasts same pos to every step
        qkv = self.qkv(h)
        q, k, v = qkv.chunk(3, dim=-1)
        H = self.n_heads
        q = q.view(B, S, H, -1)
        k = k.view(B, S, H, -1)
        v = v.view(B, S, H, -1)
        attn = torch.matmul(q, k.transpose(-2, -1))  # BUG: missing 1/sqrt(d_k) scale
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask == 0, -1e9)  # BUG: mask shape/dtype likely wrong for heads
        w = F.softmax(attn, dim=0)  # BUG: softmax on dim 0, should be last dim
        z = torch.matmul(w, v).view(B, S, -1)
        h2 = self.proj(z)
        h3 = self.ln(h + self.drop(h2))  # residual then LN is ok but training mode matters
        return self.out(h3).softmax(-1)  # BUG: softmax before CE loss

def train():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = TinyDecoder(vocab_size=100).to('cpu')  # BUG: device mismatch
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model.eval()  # BUG: wrong mode; disables dropout, affects LN
    for step in range(200):
        x = torch.randint(1, 100, (32, 16), device=device)
        y = x  # BUG: targets not next-token shifted
        logits = model(x)  # BUG: no causal mask passed
        loss = F.cross_entropy(F.log_softmax(logits, 2),  # BUG: applying log_softmax before CE
                               y.float())  # BUG: targets are float; CE expects Long
        with torch.no_grad():
            loss.backward()  # BUG: backward under no_grad prevents grads
        opt.step()  # BUG: missing optimizer.zero_grad()
        if step % 50 == 0:
            print(step, loss.item())

train()
"""
Constraints & Assumptions
Framework: Python 3 + PyTorch 2.x (torch, torch.nn, torch.nn.functional). You may assume a recent PyTorch is available.
Model scale (toy): vocab_size = 100, d_model = 64, n_heads = 4, batch B = 32, sequence length S = 16. Keep the architecture decoder-only and small enough to run on CPU.
Objective: standard autoregressive next-token prediction with cross-entropy loss.
Determinism: assume torch.manual_seed(0) is set so a fixed run is reproducible.
Goal of "working": the corrected loop must show a clearly decreasing loss curve on a task that actually contains learnable signal; it should run on CPU in seconds, with no NaN / inf.
Clarifying Questions to Ask
Should I treat the inline # BUG: comments as the complete, authoritative list, or are there additional defects that are not annotated?
Is the goal a minimal fix per bug (preserving the existing structure), or am I free to refactor the module for clarity?
Is it acceptable to invent a small deterministic toy task to demonstrate a falling loss, or must I train on the original i.i.d. random data?
Are CPU-only results acceptable for the demonstration, or do you want it to run on GPU?
For the complexity analysis, do you want just the attention block, or the full forward pass including the linear / FFN projections?
How rigorous should the unit tests be — illustrative assertions, or genuinely runnable PyTest functions?
Part A — Defect inventory (find at least 12)
Produce a bullet list of at least 12 distinct defects, grouped by subsystem (model architecture / forward math, attention masking & numerics, and the training loop). For each defect give: (a) the symptom it causes, (b) the root cause, and (c) the minimal code fix.
What This Part Should Cover
Coverage & grouping: finds comfortably past the minimum count and organizes defects by subsystem rather than as a flat dump.
Symptom → root cause → fix discipline: for each defect, ties the observed behavior to the mechanism and gives a minimal, correct change.
The non-crashing bugs: catches the silent ones (wrong attention axis, wrong softmax dimension, double softmax), not just the ones that throw.
Two failure classes: separates bugs that block learning from bugs that corrupt numerics, and recognizes both must be fixed.
Part B — Corrected, runnable implementation
Provide corrected, runnable code that trains to a decreasing loss on a simple next-token task. It must use a proper causal mask, correct next-token shifting, correct tensor shapes / dtypes / devices, and no redundant softmax. Show (or describe) a loss curve that clearly goes down.
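As a minimal sketch of the data and loss wiring this part asks for (not a full solution), the snippet below builds a toy task with learnable signal and applies the next-token shift with raw-logit cross-entropy. The helper names `make_counting_batch` and `next_token_loss`, and the "previous token plus one" task itself, are assumptions introduced here for illustration; any deterministic sequence rule would serve.

```python
import torch
import torch.nn.functional as F

def make_counting_batch(batch_size, seq_len, vocab_size, generator=None):
    """Hypothetical toy task: each token is the previous token plus one (mod vocab_size)."""
    start = torch.randint(0, vocab_size, (batch_size, 1), generator=generator)
    offsets = torch.arange(seq_len + 1).unsqueeze(0)      # (1, S + 1)
    return (start + offsets) % vocab_size                  # (B, S + 1), dtype long

def next_token_loss(logits, targets):
    """Cross-entropy on raw logits (B, S, V) against long targets (B, S).

    F.cross_entropy applies log_softmax internally, so the model should
    return unnormalized logits, not probabilities.
    """
    vocab_size = logits.size(-1)
    return F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))

# One extra token per row so that inputs and targets both have length S.
tokens = make_counting_batch(batch_size=32, seq_len=16, vocab_size=100)
inputs, targets = tokens[:, :-1], tokens[:, 1:]            # shift by one position
```

With i.i.d. random tokens the best achievable loss is about log(vocab_size), so a flat curve there says nothing about correctness; a rule-based task like this one gives the loss somewhere to go.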
What This Part Should Cover
Correct attention mechanics: head layout (heads moved to a batch-like axis, giving shape (B, H, S, d_k)), 1/sqrt(d_k) scaling, softmax axis, causal masking, and shape bookkeeping all correct in the rewrite.
A learnable demonstration: understands that a flat loss can mean either a bug or an unlearnable task, and constructs data that actually proves the fixes took.
End-to-end wiring: correct device placement, train() mode, zero_grad, next-token shift, and raw-logit cross-entropy with long targets.
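The attention mechanics listed above can be sketched as follows. This is one possible layout, assuming heads are moved to axis 1 so that the last two axes are (query position, key position); the function names `split_heads`, `merge_heads`, and `causal_attention` are hypothetical helpers, not part of the original code.

```python
import math
import torch
import torch.nn.functional as F

def split_heads(x, n_heads):
    """(B, S, d_model) -> (B, H, S, d_k), so attention runs per head over positions."""
    B, S, d_model = x.shape
    return x.view(B, S, n_heads, d_model // n_heads).transpose(1, 2)

def merge_heads(x):
    """(B, H, S, d_k) -> (B, S, H * d_k); contiguous() is needed before view()."""
    B, H, S, d_k = x.shape
    return x.transpose(1, 2).contiguous().view(B, S, H * d_k)

def causal_attention(q, k, v):
    """Scaled dot-product attention with a causal mask. q, k, v: (B, H, S, d_k)."""
    S, d_k = q.size(-2), q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)       # (B, H, S, S)
    # True above the diagonal marks future positions a query must not see.
    future = torch.triu(
        torch.ones(S, S, dtype=torch.bool, device=q.device), diagonal=1
    )
    scores = scores.masked_fill(future, float("-inf"))      # broadcasts over B and H
    weights = F.softmax(scores, dim=-1)                      # normalize over keys
    return weights @ v, weights
```

Every row keeps at least its diagonal entry unmasked, so filling with -inf cannot produce an all-masked row and the softmax stays finite. On PyTorch 2.x, the output should agree with `F.scaled_dot_product_attention(q, k, v, is_causal=True)` up to floating-point tolerance, which makes a convenient cross-check.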
Part C — Unit tests
Write three PyTest-style unit tests that would fail on the buggy code and pass on your fix. Good targets include: attention shape / contiguity, softmax-axis and mask/causality correctness, raw-logits-vs-probabilities, and gradient-non-None (finite) assertions.
What This Part Should Cover
Invariant-driven tests: each test asserts a defining property of correctness (rows sum to 1, future mass is zero, grads are finite) rather than a hard-coded expected value.
Discriminating power: the tests genuinely fail on the buggy code and pass on the fix; the candidate can say which bug each test traps.
Determinism awareness: controls dropout / module mode so probability-sum assertions are reproducible.
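One way to express the "future mass is zero" invariant without inspecting attention weights is a perturbation test: change the inputs from position t onward and check that outputs before t do not move. The sketch below is illustrative only. It uses two hypothetical stand-in functions on float inputs instead of the real model; for a token-id model you would perturb the ids instead, and call `model.eval()` first so dropout does not make the comparison noisy.

```python
import torch

def leaks_future(fn, x, t):
    """True if changing inputs at positions >= t changes outputs at positions < t."""
    x_perturbed = x.clone()
    x_perturbed[:, t:] = x_perturbed[:, t:] + 1.0
    with torch.no_grad():
        before, after = fn(x), fn(x_perturbed)
    return not torch.allclose(before[:, :t], after[:, :t], atol=1e-6)

def causal_stand_in(x):
    # Running mean over the prefix: position i only sees positions <= i.
    counts = torch.arange(1, x.size(1) + 1, dtype=x.dtype).view(1, -1, 1)
    return x.cumsum(dim=1) / counts

def leaky_stand_in(x):
    # Mean over the whole sequence: every position sees the future.
    return x.mean(dim=1, keepdim=True).expand_as(x)

def test_causal_function_does_not_leak():
    x = torch.randn(2, 6, 4)
    assert not leaks_future(causal_stand_in, x, t=3)

def test_leaky_function_is_caught():
    x = torch.randn(2, 6, 4)
    assert leaks_future(leaky_stand_in, x, t=3)
```

A test of this shape would trap both a missing causal mask and a softmax taken over the wrong axis, since either lets later positions influence earlier outputs.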
Part D — Complexity analysis & memory reduction
Derive the time and memory complexity of scaled dot-product attention in terms of batch B, heads H, sequence length S, and head dimension d_k. Then propose two changes that reduce memory without degrading perplexity too much, explaining the trade-offs of each.
What This Part Should Cover
Correct asymptotics: identifies the S×S score tensor as the dominant term and states the Θ(B·H·S²·d_k) time / Θ(B·H·S²) memory scaling.
Targeted reductions: each proposed change is aimed at a named cost, with the memory win separated from any quality cost, and is correctly labelled exact vs. approximate.
Trade-off fluency: credible discussion of where each technique applies (training vs. inference / decoding) and its limitations.
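To make the asymptotics concrete: for each of the B·H (batch item, head) pairs, Q·Kᵀ multiplies an (S × d_k) matrix by a (d_k × S) matrix, which costs about 2·S²·d_k floating-point operations, and multiplying the (S × S) weights by V costs the same again. The score tensor itself holds B·H·S² elements, while Q, K, and V together hold only 3·B·H·S·d_k. The small calculator below is a rough forward-pass estimate under those assumptions (materialized scores, fp32 by default); `attention_cost` is a hypothetical helper and it ignores softmax, masking, and backward-pass storage.

```python
def attention_cost(B, H, S, d_k, bytes_per_element=4):
    """Rough forward-only cost of attention that materializes the score tensor."""
    score_elements = B * H * S * S             # one (S, S) matrix per batch item and head
    qk_flops = 2 * B * H * S * S * d_k         # Q @ K^T: one multiply and one add per term
    av_flops = 2 * B * H * S * S * d_k         # weights @ V
    return {
        "score_elements": score_elements,
        "score_bytes": score_elements * bytes_per_element,
        "matmul_flops": qk_flops + av_flops,
        "qkv_elements": 3 * B * H * S * d_k,
    }

# Toy configuration from the problem statement: d_k = d_model / n_heads = 64 / 4.
toy = attention_cost(B=32, H=4, S=16, d_k=16)
# Same model at 100x the sequence length.
longer = attention_cost(B=32, H=4, S=1600, d_k=16)

print(toy)
print(longer["score_bytes"] / toy["score_bytes"])      # quadratic in S
print(longer["qkv_elements"] / toy["qkv_elements"])    # linear in S
```

The two ratios show why memory-reduction techniques target the score tensor: at 100 times the sequence length it grows by a factor of 10,000, while the Q/K/V activations grow by only 100.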
What a Strong Answer Covers
These dimensions span all four parts:
Two-failure-class framing throughout: consistently distinguishes learning-blocking bugs from numerics-corrupting bugs when reasoning across Parts A–C.
Prioritization: surfaces the single highest-impact bug first (the one that makes the loss flat regardless of everything else) rather than treating all defects as equal.
Internal consistency: the corrected code (Part B), the tests (Part C), and the complexity claims (Part D) all agree on the same attention layout and shapes.
Communication under time pressure: explains reasoning clearly and concisely, naming the mechanism behind each fix rather than only the patch.
Follow-up Questions
Which single bug, fixed alone, would move the loss the most — and which would change nothing on its own?
How would you adapt this model and training loop to handle variable-length, padded sequences in a batch (mask construction, loss masking, padding_idx)?
At 100x the sequence length, what breaks first, and what is the first change you would make?
How would your answer to Part D change for inference / decoding (KV cache) versus training?
How would you detect and guard against NaN / inf in a long training run (mixed precision, loss scaling, gradient clipping)?
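For the last follow-up, one simple guard is to check the loss and the gradient norm for non-finite values before every optimizer step, and to clip the gradient norm while doing so. The sketch below assumes plain fp32 training; `guarded_step` is a hypothetical helper, and it does not cover mixed-precision loss scaling, which would need its own handling.

```python
import torch
import torch.nn as nn

def guarded_step(model, optimizer, loss, max_norm=1.0):
    """Backward pass plus clipped optimizer step; skips the update if anything is non-finite.

    Returns True if the parameters were updated, False if the step was skipped.
    """
    optimizer.zero_grad(set_to_none=True)      # clear stale gradients first
    if not torch.isfinite(loss):
        return False                           # NaN/inf loss: do not backpropagate
    loss.backward()
    # clip_grad_norm_ rescales gradients in place and returns the pre-clip total norm.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    if not torch.isfinite(grad_norm):
        optimizer.zero_grad(set_to_none=True)  # discard the corrupted gradients
        return False
    optimizer.step()
    return True
```

In a long run, counting how often the function returns False is a cheap health signal: an occasional skipped step is usually tolerable, while a rising skip rate points at a learning-rate or numerics problem worth investigating.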