This question evaluates practical debugging skills for Transformer-based sequence models, targeting machine learning engineers who must diagnose intermittent shape/dtype errors and convergence failures simultaneously. It tests deep familiarity with the full training stack — data pipelines, attention masking, positional encoding, optimizer precision, and distributed setup — and the ability to reason about how a single root cause can manifest across multiple symptoms.
A Transformer-based sequence model intermittently throws shape/dtype mismatch errors and fails to converge after several thousand steps. Describe your end-to-end debugging approach. Include: how you validate tokenization, padding, and special tokens; verify attention masks (causal vs. bidirectional) and positional embeddings; add assertions and unit tests for tensor shapes and sequence lengths; pinpoint exploding/vanishing gradients (e.g., gradient norms, clipping, optimizer/betas, LR schedule, mixed precision); isolate data bugs (truncation, BOS/EOS handling, label shifting for LM objectives); diagnose loss NaNs and numerical instability; check checkpoint/resume logic and randomness/seed control; debug multi-GPU/DP/ZeRO issues (sync, grad scaling, AMP); profile throughput and memory (operator-level hotspots, OOM sources); and construct a minimal reproducible example with targeted logging to localize the fault. Provide concrete checks, metrics, and code-level instrumentation you would add.
Debugging a Transformer That Intermittently Throws Shape/Dtype Errors and Fails to Converge
You inherit a Transformer-based sequence model (decoder-only language-model objective) whose training run exhibits two symptoms simultaneously:
It intermittently raises shape or dtype mismatch errors during training: most batches succeed, then one fails, often only after the run has been going for a while.
It fails to converge after several thousand steps: the loss stalls on a plateau, oscillates, or diverges (sometimes spiking to NaN).
Walk through your end-to-end debugging strategy. Describe the order in which you would attack the problem, the concrete checks, metrics, and code-level instrumentation you would add at each stage, and how each check lets you confirm or rule out a subsystem (data, batching/collation, masking, positional encoding, the model forward pass, the optimizer/precision stack, checkpointing, and the distributed setup). Treat the two symptoms as possibly related (e.g. a silently malformed batch could both break shapes and poison the loss) and explain how you would tell whether they share a root cause.
Constraints & Assumptions
This is a training problem (the convergence symptom rules out a pure inference issue), but the candidate may use generation/KV-cache reasoning to explain the shape crashes.
Assume access to the full source: model code, data pipeline, training loop, and config. You can add instrumentation and run experiments freely.
Assume a typical PyTorch stack: DataLoader + custom collate, AMP/GradScaler mixed precision, AdamW, an LR scheduler with warmup, and the option of multi-GPU (DDP and/or ZeRO/DeepSpeed). Do not assume any specific finding is already known; the candidate must discover the root cause.
"Intermittent" means most steps succeed; the failure correlates with some batches or with elapsed time/state, not with every step.
Clarifying Questions to Ask
A strong candidate scopes the problem before diving in. Reasonable questions include:
Is this a fresh failure or a regression? Did it start after a code change, a data refresh, a tokenizer change, or a hardware/library upgrade?
Does the shape/dtype error reproduce on a single GPU and small batch, or only at scale / under DDP / under AMP?
What is the model objective and architecture precisely: a decoder-only causal LM? Absolute, sinusoidal, or rotary positional embeddings? Is a KV cache used during training (e.g. chunked / streaming)?
Are the loss plateau and the crashes correlated in time (same step / same batch), or independent symptoms?
What does the loss curve actually look like — flat from step 0, slow decay then stall, or healthy-then-divergence? What is the current LR schedule and precision setting?
Is the tokenizer / dataset versioned and identical across all workers and any resumed checkpoint?
What a Strong Answer Covers
A strong answer is a systematic, ordered triage — not a flat list — that converts intermittent failures into deterministic, localized assertions. The interviewer is looking for breadth across the stack and depth on the high-yield culprits, with concrete code/metrics. Dimensions to look for:
Reproducibility & failure capture: seed control, deterministic flags, set_detect_anomaly, logging the failing batch index/inputs/device, and a per-step metrics panel (loss, grad norm, LR, scaler scale + overflow count).
Tokenization, padding & special tokens: a single pinned tokenizer version across workers and checkpoints, encode/decode round-trip checks, correct PAD/BOS/EOS/UNK ids, consistent padding side, and sane sequence-length distributions (nothing silently truncated).
Collation invariants: right/left padding consistency; attention_mask = 1 for real tokens and 0 for PAD; labels = ignore_index on PAD positions; runtime asserts on the collate output.
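A sketch of a right-padding collate that enforces these invariants at runtime, assuming -100 as the ignore index (the default for PyTorch's cross_entropy). The mask is built from sequence lengths rather than from input_ids != pad_id, because some tokenizers reuse the EOS id as PAD and that comparison would then mask real EOS tokens. Both function names are hypothetical.

```python
from typing import Dict, List

import torch

IGNORE_INDEX = -100  # default ignore_index of torch.nn.functional.cross_entropy


def check_batch(input_ids, attention_mask, labels, pad_id, max_len):
    """Hypothetical helper: invariants every right-padded LM batch must satisfy."""
    assert input_ids.dim() == 2 and input_ids.dtype == torch.long, (input_ids.shape, input_ids.dtype)
    assert input_ids.shape == attention_mask.shape == labels.shape
    assert input_ids.shape[1] <= max_len, f"seq len {input_ids.shape[1]} > {max_len}"
    # Right padding: each mask row is 1...1 0...0, i.e. non-increasing along time.
    assert bool((attention_mask[:, 1:] <= attention_mask[:, :-1]).all()), "mask not right-padded"
    assert bool((input_ids[attention_mask == 0] == pad_id).all()), "non-PAD token under mask 0"
    assert bool((labels[attention_mask == 0] == IGNORE_INDEX).all()), "PAD position has a label"
    assert bool((attention_mask.sum(dim=1) > 0).all()), "fully padded row"


def collate_right_padded(seqs: List[List[int]], pad_id: int, max_len: int) -> Dict[str, torch.Tensor]:
    """Hypothetical collate: right-pad, build mask/labels from lengths, then assert invariants."""
    assert seqs, "empty batch"
    seqs = [s[:max_len] for s in seqs]  # explicit truncation, never implicit
    width = max(len(s) for s in seqs)
    input_ids = torch.full((len(seqs), width), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((len(seqs), width), dtype=torch.long)
    for i, s in enumerate(seqs):
        input_ids[i, : len(s)] = torch.tensor(s, dtype=torch.long)
        attention_mask[i, : len(s)] = 1
    # Labels stay unshifted here; the loss performs the next-token shift.
    labels = input_ids.masked_fill(attention_mask == 0, IGNORE_INDEX)
    check_batch(input_ids, attention_mask, labels, pad_id, max_len)
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
```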
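Attention masks: causal vs. bidirectional correctness, mask shape/dtype (boolean vs. additive float), composing causal + padding masks, and computing the masked softmax in float32 rather than adding hard-coded large negatives (e.g. -1e9) in FP16, which overflow to -inf and produce NaN for fully masked rows.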
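One way to compose the causal and padding masks and keep the softmax numerically safe, assuming a boolean keep-mask convention (True means "may attend") and scores shaped (batch, heads, query, key). It fills with the float32 minimum instead of -inf and zeroes rows with no valid key, which left-padded causal batches produce; a -inf fill there gives NaN. Both helpers are hypothetical.

```python
import torch


def build_attention_bias(attention_mask: torch.Tensor, causal: bool = True) -> torch.Tensor:
    """Hypothetical helper: (B, T) padding mask -> boolean keep-mask of shape (B, 1, T, T)."""
    B, T = attention_mask.shape
    keep = attention_mask.bool()[:, None, None, :]  # (B, 1, 1, T): block PAD keys
    if causal:
        tril = torch.ones(T, T, dtype=torch.bool, device=attention_mask.device).tril()
        return keep & tril[None, None, :, :]  # query i may see keys 0..i
    return keep.expand(B, 1, T, T)


def masked_softmax_fp32(scores: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: softmax in float32; fully masked rows become zeros, not NaN."""
    s = scores.float().masked_fill(~keep, torch.finfo(torch.float32).min)
    probs = torch.softmax(s, dim=-1)
    probs = probs.masked_fill(~keep.any(dim=-1, keepdim=True), 0.0)
    return probs.to(scores.dtype)  # cast back to the activation dtype (e.g. FP16)
```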
Positional embeddings: position_ids within table bounds; for RoPE, consistent Q/K positions and the correct past_len offset when a KV cache is used; behavior at and beyond the maximum context length.
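The position-id and RoPE checks fit naturally into small unit tests. This sketch assumes a single scalar past_len for the whole batch (a simplification) and uses the half-split rotary convention; make_position_ids and rope are illustrative helpers, not any particular library's implementation. The RoPE test relies on the defining property that the q·k score depends only on the relative offset between the two positions.

```python
import torch


def make_position_ids(attention_mask: torch.Tensor, past_len: int, max_positions: int) -> torch.Tensor:
    """Hypothetical helper: positions that skip left padding and continue after a KV-cache offset."""
    pos = attention_mask.long().cumsum(dim=-1) - 1  # first real token gets position 0
    pos = pos.clamp(min=0) + past_len  # PAD slots are clamped; they are masked anyway
    max_pos = int(pos.max())
    assert max_pos < max_positions, f"position {max_pos} >= table size {max_positions}"
    return pos


def rope(x: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Illustrative rotary embedding (half-split pairs) for x of shape (..., T, D), pos (..., T)."""
    d = x.shape[-1]
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)  # (D/2,)
    ang = pos.float()[..., None] * inv_freq  # (..., T, D/2)
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```

If the relative-offset test fails, or position ids exceed the table after a cache offset, that could explain both index errors and degraded loss on long sequences.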
LM targets / label shifting: next-token shift correctness, no off-by-one, ignore_index on PAD, and a sanity check that the ignore_index fraction tracks the padding rate.
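A sketch of the shifted next-token loss plus the ignore-fraction check, assuming collate leaves labels unshifted and marks PAD with -100. The check assumes pure LM training with no prompt masking; if prompt tokens are also ignored, compare against padding plus prompt length instead. The extra assertion matters because a batch whose targets are all ignored makes mean-reduced cross-entropy 0/0, i.e. NaN.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100


def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Next-token loss: logits at position t are scored against labels[t + 1], in float32."""
    shift_logits = logits[:, :-1, :].float()
    shift_labels = labels[:, 1:]
    assert shift_logits.shape[:2] == shift_labels.shape, (shift_logits.shape, shift_labels.shape)
    assert bool((shift_labels != IGNORE_INDEX).any()), "all targets ignored: mean loss would be NaN"
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=IGNORE_INDEX,
    )


def check_ignore_fraction(labels: torch.Tensor, attention_mask: torch.Tensor, tol: float = 1e-6) -> float:
    """Hypothetical helper: after the shift, ignored targets should coincide with PAD targets."""
    ignored = (labels[:, 1:] == IGNORE_INDEX).float().mean().item()
    padded = (attention_mask[:, 1:] == 0).float().mean().item()
    assert abs(ignored - padded) <= tol, f"ignore frac {ignored:.4f} != pad frac {padded:.4f}"
    return ignored
```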
Gradients, optimizer & LR: global and per-layer grad norms, clipping, warmup, AdamW betas/eps, and excluding bias/LayerNorm parameters from weight decay; distinguishing exploding from vanishing gradients using these curves.
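A minimal sketch of the weight-decay grouping and per-step gradient-norm logging. It uses the common heuristic that 1-D parameters (biases, LayerNorm gain and shift) get no decay; models with other 1-D parameters that should be decayed need an explicit name filter. Both helper names are hypothetical.

```python
import torch
from torch import nn


def param_groups(model: nn.Module, weight_decay: float):
    """Hypothetical helper: AdamW groups that skip decay on biases and norm parameters."""
    decay, no_decay = [], []
    for _, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # Heuristic: 1-D tensors are biases and LayerNorm weights/biases.
        (no_decay if p.ndim < 2 else decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


def grad_norms(model: nn.Module):
    """Hypothetical helper: global L2 grad norm plus per-parameter norms for logging."""
    per_param = {
        n: p.grad.detach().float().norm().item()
        for n, p in model.named_parameters()
        if p.grad is not None
    }
    total = sum(v ** 2 for v in per_param.values()) ** 0.5
    return total, per_param
```

A typical wiring is torch.optim.AdamW(param_groups(model, 0.1), lr=..., betas=(0.9, 0.95)); that betas value is a common large-model choice, not something the source prescribes. clip_grad_norm_ returns the norm measured before clipping, so logging it shows whether clipping fires on every step (exploding gradients or an LR that is too high), while per-layer norms collapsing toward zero in early layers point at vanishing gradients.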
Numerical stability & NaNs: assert_finite checks on logits/loss, float32 softmax/loss, epsilon choice (values such as 1e-8 underflow to zero in FP16), and reading the GradScaler (frequent overflow/step-skipping is a signal).
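A sketch of the two instruments named here: an assert_finite helper (the source names the check; this implementation is an assumption) and a monitor for GradScaler step skipping. The monitor relies on GradScaler's documented behavior of multiplying the scale by backoff_factor when update() follows a step with inf/NaN gradients, so a drop in get_scale() across update() marks a skipped step.

```python
import torch


def assert_finite(name: str, t: torch.Tensor) -> None:
    """Fail fast with context instead of letting NaN/Inf propagate into the loss."""
    finite = torch.isfinite(t)
    if not bool(finite.all()):
        n_bad = int((~finite).sum())
        raise FloatingPointError(
            f"{name}: {n_bad}/{t.numel()} non-finite values (dtype={t.dtype}, shape={tuple(t.shape)})"
        )


class ScalerSkipMonitor:
    """Hypothetical helper: rolling rate of optimizer steps skipped by GradScaler.

    Intended wiring in the training loop:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)          # unscale before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        before = scaler.get_scale()
        scaler.step(optimizer); scaler.update()
        monitor.record(before, scaler.get_scale())
    """

    def __init__(self, window: int = 100):
        self.window = window
        self.history = []

    def record(self, scale_before: float, scale_after: float) -> bool:
        skipped = scale_after < scale_before  # scale backs off only after an inf/NaN step
        self.history = (self.history + [skipped])[-self.window:]
        return skipped

    @property
    def skip_rate(self) -> float:
        return sum(self.history) / max(len(self.history), 1)
```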
Data bugs: shard/tokenizer consistency, all-UNK / all-PAD / pathologically long samples, and truncation that drops EOS or breaks label shifting.
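These checks can run once over the tokenized dataset, or per shard, before training starts. A sketch assuming samples are token-id lists that should each end in EOS; scan_samples and its unk threshold are hypothetical.

```python
from collections import Counter
from typing import Iterable, List


def scan_samples(
    samples: Iterable[List[int]],
    pad_id: int,
    unk_id: int,
    eos_id: int,
    max_len: int,
    max_unk_frac: float = 0.5,
) -> Counter:
    """Hypothetical helper: count pathological samples before they reach the collate function."""
    issues = Counter()
    for ids in samples:
        issues["total"] += 1
        if not ids:
            issues["empty"] += 1
            continue
        if all(t == pad_id for t in ids):
            issues["all_pad"] += 1
        if sum(t == unk_id for t in ids) / len(ids) > max_unk_frac:
            issues["mostly_unk"] += 1
        if len(ids) > max_len:
            issues["over_max_len"] += 1  # truncating to max_len would cut off the trailing EOS
        if ids[-1] != eos_id:
            issues["missing_eos"] += 1
    return issues
```

Non-zero counts for any category other than "total" are worth tracing back to the shard they came from.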
Checkpoint/resume & RNG: saving/restoring model + optimizer + scheduler + scaler + sampler + global step; verifying loss continuity and deterministic per-rank sampler seeding after resume.
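A sketch of saving and restoring the full training state, including RNG streams, so a resumed run continues the same random sequence. It assumes one checkpoint file per rank (RNG state differs by rank) and stores sampler state as a single epoch counter, a placeholder for whatever your sampler actually needs. It loads with weights_only=False, which is only appropriate for checkpoints you trust.

```python
import random

import torch


def save_checkpoint(path, model, optimizer, scheduler, scaler, step, sampler_epoch):
    """Hypothetical helper: write everything a faithful resume needs (one file per rank)."""
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "scaler": scaler.state_dict() if scaler is not None else None,
            "step": step,
            "sampler_epoch": sampler_epoch,  # placeholder for real sampler state
            "rng": {
                "python": random.getstate(),
                "torch": torch.get_rng_state(),
                "cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
            },
        },
        path,
    )


def load_checkpoint(path, model, optimizer, scheduler, scaler):
    """Hypothetical helper: restore save_checkpoint output; returns (global_step, sampler_epoch)."""
    ckpt = torch.load(path, map_location="cpu", weights_only=False)  # trusted files only
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    if scaler is not None and ckpt["scaler"] is not None:
        scaler.load_state_dict(ckpt["scaler"])
    random.setstate(ckpt["rng"]["python"])
    torch.set_rng_state(ckpt["rng"]["torch"])
    if ckpt["rng"]["cuda"] is not None and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(ckpt["rng"]["cuda"])
    return ckpt["step"], ckpt["sampler_epoch"]
```

To verify continuity, compute the loss on a fixed batch just before saving and just after loading; with deterministic kernels on the same hardware, the two values should match.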
Multi-GPU / DDP / ZeRO / AMP: identical per-rank batch shapes, exactly-once gradient reduction, find_unused_parameters hygiene, framework-correct grad clipping under ZeRO, and per-rank logging.
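A sketch of cross-rank consistency checks built on torch.distributed.all_gather_object. check_consistent_across_ranks compares any small picklable value, such as batch shapes or the number of batches each rank will run; param_checksum fingerprints the weights, which under plain DDP should be identical on every rank after each step. Under ZeRO stage 3 parameters are sharded, so the checksum comparison does not apply as written. Both helpers are hypothetical and degrade to a no-op in single-process runs.

```python
import torch
import torch.distributed as dist


def check_consistent_across_ranks(tag: str, value) -> list:
    """Hypothetical helper: gather a picklable value from every rank and require equality."""
    if not (dist.is_available() and dist.is_initialized()):
        return [value]  # single process: nothing to compare
    # With NCCL, torch.cuda.set_device(local_rank) must have been called before this.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, value)
    if any(v != gathered[0] for v in gathered):
        raise RuntimeError(f"[rank {dist.get_rank()}] {tag} differs across ranks: {gathered}")
    return gathered


def param_checksum(model: torch.nn.Module) -> float:
    """Hypothetical helper: cheap weight fingerprint to compare across ranks after a step."""
    with torch.no_grad():
        return float(sum(p.double().sum() for p in model.parameters()))
```

A per-step call such as check_consistent_across_ranks("input_ids.shape", tuple(batch["input_ids"].shape)) or check_consistent_across_ranks("weights", param_checksum(model)) turns a silent desync into an immediate, rank-labeled error.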
Throughput & memory profiling: tokens/sec, dataloader-vs-compute time split, torch.profiler hotspots, memory_summary, and OOM triage (batch size, gradient checkpointing, padding to multiples of 8/16).
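A sketch of the operator-level profile and the data-vs-compute split using torch.profiler on CPU. On GPU runs, add ProfilerActivity.CUDA and sort by the device-time column your PyTorch version exposes; torch.cuda.max_memory_allocated() and torch.cuda.memory_summary() then complement the table for OOM triage. Helper names and the count_tokens callback are hypothetical.

```python
import time

import torch
from torch.profiler import ProfilerActivity, profile, record_function


def profile_steps(step_fn, batches, activities=(ProfilerActivity.CPU,), sort_by="cpu_time_total"):
    """Hypothetical helper: profile a few steps and return the top-10 operator table."""
    with profile(activities=list(activities), record_shapes=True, profile_memory=True) as prof:
        for batch in batches:
            with record_function("train_step"):
                step_fn(batch)
    return prof.key_averages().table(sort_by=sort_by, row_limit=10)


def throughput(loader, step_fn, count_tokens):
    """Hypothetical helper: split wall time into data wait vs. compute and report tokens/sec."""
    data_s = compute_s = 0.0
    tokens = 0
    t0 = time.perf_counter()
    for batch in loader:
        t1 = time.perf_counter()
        data_s += t1 - t0
        step_fn(batch)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # otherwise async GPU work is billed to the next data wait
        t0 = time.perf_counter()
        compute_s += t0 - t1
        tokens += count_tokens(batch)
    total = max(data_s + compute_s, 1e-9)
    return {"tokens_per_s": tokens / total, "data_frac": data_s / total}
```

A data_frac that stays high means the dataloader, not the model, bounds throughput.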
Minimal reproducible example: a tiny deterministic synthetic-task setup used to bisect model/optimizer faults from data/distributed faults, plus a symptom→root-cause mapping.
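A sketch of the synthetic-task MRE: a two-layer causal model trained on sequences where every next token is the current one plus 1 (mod 16), which any correct stack should learn quickly. The task, sizes, and hyperparameters are illustrative assumptions. If this converges with your real model, optimizer, scheduler, and AMP settings swapped in, those components are largely exonerated and the fault more likely sits in the data or distributed path; if it fails, bisect inside the model.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyLM(nn.Module):
    """Illustrative stand-in for the real model: embeddings + causal Transformer + LM head."""

    def __init__(self, vocab=16, d=32, max_len=32):
        super().__init__()
        self.tok = nn.Embedding(vocab, d)
        self.pos = nn.Embedding(max_len, d)
        layer = nn.TransformerEncoderLayer(d, nhead=4, dim_feedforward=64, dropout=0.0, batch_first=True)
        self.body = nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=False)
        self.head = nn.Linear(d, vocab)

    def forward(self, ids):
        T = ids.shape[1]
        pos = torch.arange(T, device=ids.device)
        # Boolean attn mask for nn.TransformerEncoder: True means "not allowed to attend".
        blocked = torch.triu(torch.ones(T, T, dtype=torch.bool, device=ids.device), diagonal=1)
        return self.head(self.body(self.tok(ids) + self.pos(pos), mask=blocked))


def run_mre(steps=150, seed=0, vocab=16, seq_len=24):
    """Train on the synthetic task and return the per-step losses."""
    torch.manual_seed(seed)
    model = TinyLM(vocab=vocab)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-3)
    losses = []
    for _ in range(steps):
        start = torch.randint(0, vocab, (32, 1))
        ids = (start + torch.arange(seq_len)) % vocab  # next token = current + 1 (mod vocab)
        logits = model(ids)
        loss = F.cross_entropy(logits[:, :-1].reshape(-1, vocab), ids[:, 1:].reshape(-1))
        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        losses.append(loss.item())
    return losses
```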
Follow-up Questions
The crash only appears after ~3,000 steps and never in a fresh MRE. What classes of bugs become more likely with elapsed time/state, and how would you confirm one (e.g. KV-cache/position drift, scaler state, optimizer state, or a rare long sample that only appears late in an epoch)?
Your GradScaler is skipping a large fraction of steps. What does that tell you, and what is your remediation order before you reach for "turn off AMP"?
Suppose the MRE on synthetic data converges fine but the real run still stalls. Which subsystems does that exonerate, and where do you look next?
How would you write a regression test that catches a future reintroduction of the label-shift / mask-dtype bug in CI, so this class of failure can't silently return?