Transformer Training Pipeline Debugging
Asked of: ML Engineer

What's being tested
Candidates must show practical fluency debugging the entire Transformer training pipeline: data/tokenization correctness, attention and positional masking semantics, loss-to-logit alignment, and optimizer/numerical stability. Interviewers probe whether you can triage a failing run (diverging loss, NaNs, stalled training) by isolating data, model, and optimizer failure modes and propose low-risk fixes and diagnostic instrumentation. OpenAI expects an ML Engineer to demonstrate reproducible debugging steps, safe mitigations (e.g., mixed-precision fixes), and an understanding of tradeoffs for training stability at scale.
Core knowledge
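- Tokenization & padding: Ensure pad_token_id is consistent between tokenizer and model. Padding positions must be masked in attention and excluded from the loss. Otherwise real tokens can attend to padding, and trivially predictable pad targets deflate the loss and inflate metrics. If pad_token_id is aliased to the EOS token, build masks from the tokenizer's attention mask rather than by comparing ids, or real EOS tokens get masked too.
- Attention masks: Distinguish the causal mask (blocks the upper triangle, i.e., future positions, for autoregressive models) from the padding mask (per key token). Wrong shape or broadcasting is a common silent bug, and so is inverted polarity: in a boolean attn_mask, True means "blocked" for torch.nn.MultiheadAttention but "allowed" for torch.nn.functional.scaled_dot_product_attention.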
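- Positional embeddings: Off-by-one position ids or a mismatched maximum length shift predictions, or fail outright: indexing a learned position table past its size raises an index error, often an opaque device-side assert on GPU. Check sequence lengths, truncation, and embedding table size.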
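- Loss alignment: For next-token prediction, mask the loss over padding and optionally over the input prompt. Verify that the logits at position t are scored against the token at position t+1 (labels shifted by one relative to the inputs, as in teacher forcing), and that the shift happens exactly once: some libraries shift labels inside the model, so shifting again in the data collator misaligns them by two.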
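- Optimization details: AdamW hyperparameters (lr, betas, eps, weight_decay) and learning-rate schedules (warmup, cosine decay) interact with stability; when diagnosing divergence, use linear warmup and a conservative lr.
- Mixed-precision training: With AMP (torch.cuda.amp), fp16 needs dynamic loss scaling, which protects small gradients from underflow; steps whose gradients contain Inf/NaN are skipped and the scale is reduced. Persistent NaNs usually mean overflow (fp16 tops out at 65504), which a larger loss scale makes worse, not better. Run the offending op in fp32, switch to bf16 (fp32's exponent range, typically no loss scaling needed), or temporarily train in full precision to confirm the diagnosis.
- Gradient issues: Monitor gradient norms and distributions, and apply gradient clipping by global norm. Check that parameters actually receive gradients: a parameter with requires_grad=False, or one disconnected from the loss, keeps grad None and never updates, while a misplaced optimizer.zero_grad() either lets gradients accumulate across steps (if it is skipped) or erases them before optimizer.step() (if it runs between backward and step).
- Parameter initialization & tying: Overly large initialization slows or destabilizes convergence; confirm that initialization schemes (Xavier/Kaiming) match activation functions. If the architecture ties the embedding and output-projection weights, verify the tie survives model construction and checkpoint loading; an accidentally untied (or wrongly tied) head changes both parameter count and training dynamics.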
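- Numerical checks: Log max/min/mean of activations, logits, and gradients per layer; NaN/Inf usually originates in the softmax or in an unstable layer output. Library softmax implementations already subtract the row maximum, so the usual culprits are a row that is entirely masked (all -inf, which yields NaN), mask fill values such as -1e9 that are out of range in fp16, or pre-softmax scores that themselves overflow. For the last case, clamping the logits or computing the softmax in fp32 helps.
- Metrics & evaluation parity: Ensure the offline training metric (masked token loss) matches the online eval definition; mismatches arise from unmasked losses or different vocab indices.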
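- Distributed training pitfalls: DDP broadcasts rank 0's parameters at construction, but data sharding still needs care (e.g., call DistributedSampler.set_epoch each epoch so shuffles change between epochs). BatchNorm statistics are per-replica unless converted to SyncBatchNorm, whereas LayerNorm is per-token and unaffected. With gradient accumulation, wrap non-final micro-batches in model.no_sync() to skip redundant all-reduces (Hugging Face Accelerate exposes the same idea through its sync_gradients flag). Accumulation and world size change the effective batch size; the linear-scaling rule (scale lr with batch size) is a common starting heuristic derived for SGD, and Adam-family optimizers often need gentler scaling.
- Instrumentation & reproducibility: Automate sanity checks: deterministic single-batch forward/backward, parameter checksum snapshots, and unit tests for masking shapes and label shifts.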
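The masking and loss-alignment items can be pinned down in a few lines. This is a minimal sketch assuming PyTorch, a right-padded batch, and pad id 0; build_masks, next_token_loss, and PAD_ID are hypothetical names, not part of any library. Masks follow the nn.MultiheadAttention convention (True = blocked).

```python
import torch
import torch.nn.functional as F

PAD_ID = 0  # assumption: stands in for your tokenizer's pad_token_id


def build_masks(input_ids: torch.Tensor, pad_id: int = PAD_ID):
    """Boolean masks in the nn.MultiheadAttention convention: True means 'may NOT attend'."""
    seq_len = input_ids.size(1)
    # Causal: block the strict upper triangle, so query i only sees keys 0..i.
    causal_block = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    # Padding: block padded keys; shape (batch, seq_len), as key_padding_mask expects.
    key_padding_block = input_ids.eq(pad_id)
    return causal_block, key_padding_block


def next_token_loss(logits: torch.Tensor, input_ids: torch.Tensor, pad_id: int = PAD_ID):
    """Mean cross-entropy over real targets; logits[:, t] is scored against input_ids[:, t + 1]."""
    shift_logits = logits[:, :-1, :]              # last position has no next token
    shift_labels = input_ids[:, 1:].clone()       # first token is never a target
    shift_labels[shift_labels == pad_id] = -100   # -100 is cross_entropy's default ignore_index
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,
    )


ids = torch.tensor([[5, 6, 7, 0, 0],
                    [8, 9, 0, 0, 0]])      # right-padded batch
causal, padding = build_masks(ids)
logits = torch.randn(2, 5, 10)             # stand-in for model output, vocab of 10
print(next_token_loss(logits, ids))        # averages over the 3 real targets only
```

With nn.MultiheadAttention these go in as attn_mask=causal and key_padding_mask=padding. For F.scaled_dot_product_attention, pass ~causal (or is_causal=True) because its boolean polarity is reversed. With left padding, leading pad queries see only padded keys under the causal mask, so their rows are fully masked, which can produce NaN.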
Tip: When hunting NaNs, reproduce the issue on a single GPU/CPU with a single batch — it's faster and isolates distributed nondeterminism.
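On that single batch, forward hooks can name the first module whose output turns non-finite. A minimal sketch using register_forward_hook; find_first_nonfinite and the Exp layer are hypothetical names, with Exp standing in for any unstable op. For the backward pass, torch.autograd.set_detect_anomaly(True) plays a similar role for NaN gradients, at a noticeable speed cost.

```python
import torch
import torch.nn as nn


def find_first_nonfinite(model: nn.Module, *inputs):
    """Run one forward pass; return the name of the first module whose output has NaN/Inf.

    Hooks fire as each module finishes, so children report before their parents and
    the first hit is the innermost offender. Returns None if every output is finite.
    """
    offenders, handles = [], []

    def make_hook(name):
        def hook(module, args, output):
            outs = output if isinstance(output, (tuple, list)) else (output,)
            for out in outs:
                if torch.is_tensor(out) and out.is_floating_point() and not torch.isfinite(out).all():
                    offenders.append(name or "<root>")
                    break
        return hook

    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for h in handles:
            h.remove()  # always detach hooks so they do not leak into training
    return offenders[0] if offenders else None


class Exp(nn.Module):
    """Hypothetical unstable layer: exp overflows fp32 for inputs above ~88."""

    def forward(self, x):
        return torch.exp(x)


model = nn.Sequential(Exp(), nn.Linear(4, 4))
print(find_first_nonfinite(model, torch.full((2, 4), 100.0)))  # prints: 0 (the Exp layer)
```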
Worked example — Debug a transformer training pipeline
Frame the problem: ask whether the failure is reproducible, what the observed symptom is (divergent loss, NaNs, loss stuck near its initial value), and what changed recently (code, data, hyperparameters). Declare assumptions: PyTorch with AdamW and AMP, and a next-token prediction objective.
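A quick check for the "loss not moving" symptom: a freshly initialized model with small output weights predicts close to uniformly, so its first cross-entropy should sit near ln(vocab_size). A loss that never leaves that level means nothing is being learned; a first loss far above it points at initialization or logit scale. A minimal sketch; the helper names and the 10% tolerance are illustrative assumptions, and 50257 is simply GPT-2's vocabulary size.

```python
import math

import torch
import torch.nn.functional as F


def uniform_baseline_loss(vocab_size: int) -> float:
    """Cross-entropy of a predictor that gives every token equal probability."""
    return math.log(vocab_size)


def initial_loss_looks_sane(first_loss: float, vocab_size: int, rel_tol: float = 0.1) -> bool:
    """Hypothetical check: first-step loss within rel_tol of ln(vocab_size)."""
    baseline = uniform_baseline_loss(vocab_size)
    return abs(first_loss - baseline) <= rel_tol * baseline


vocab_size = 50257
logits = torch.zeros(8, vocab_size)                 # all-zero logits are exactly uniform
targets = torch.randint(0, vocab_size, (8,))
loss = F.cross_entropy(logits, targets).item()
print(round(uniform_baseline_loss(vocab_size), 2), round(loss, 2))  # 10.82 10.82
```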
Skeleton of the investigation:
- Reproduce on a single deterministic batch and run forward/backward to see where NaN/Inf values first appear.
- Verify preprocessing: tokenization, pad_token_id, label shifting, and that loss masking excludes padding.
- Inspect model internals: attention mask shapes and broadcasting, positional embedding range, LayerNorm behavior.
- Check optimizer state and mixed-precision loss scaling; run with FP32 temporarily.
- Monitor gradient norms, activation ranges, and weight updates.
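For the gradient-monitoring step, one report after loss.backward() catches the "parameter never updates" class of bugs alongside the global norm. A minimal sketch; gradient_report and the Toy model are hypothetical, and the report clips in place using torch.nn.utils.clip_grad_norm_, which returns the pre-clipping global norm.

```python
import torch
import torch.nn as nn


def gradient_report(model: nn.Module, max_norm: float = 1.0) -> dict:
    """Call after loss.backward(). Collects per-parameter grad norms, then clips in place.

    frozen:       requires_grad=False, so the parameter never updates (intended or not).
    missing_grad: trainable but .grad is None, e.g. unused in forward or detached from the loss.
    nonfinite:    gradient contains NaN/Inf.
    """
    frozen, missing, nonfinite, per_param = [], [], [], {}
    for name, p in model.named_parameters():
        if not p.requires_grad:
            frozen.append(name)
        elif p.grad is None:
            missing.append(name)
        else:
            if not torch.isfinite(p.grad).all():
                nonfinite.append(name)
            per_param[name] = p.grad.norm().item()  # recorded before clipping
    with_grad = [p for p in model.parameters() if p.grad is not None]
    total_norm = torch.nn.utils.clip_grad_norm_(with_grad, max_norm).item()  # pre-clip norm
    return {"total_norm": total_norm, "per_param": per_param,
            "frozen": frozen, "missing_grad": missing, "nonfinite": nonfinite}


class Toy(nn.Module):
    """Hypothetical model with one frozen layer and one layer that forward never calls."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 4)
        self.unused = nn.Linear(4, 4)
        self.frozen = nn.Linear(4, 4)
        self.frozen.requires_grad_(False)

    def forward(self, x):
        return self.used(self.frozen(x)).sum()


torch.manual_seed(0)
model = Toy()
model(torch.randn(2, 4)).backward()
report = gradient_report(model, max_norm=1.0)
print(report["frozen"], report["missing_grad"])
# ['frozen.weight', 'frozen.bias'] ['unused.weight', 'unused.bias']
```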
Tradeoff to flag: switching to FP32 fixes numerical issues but roughly doubles activation memory and slows training; dynamic loss scaling with AMP is cheaper, but it silently skips steps whose gradients overflow, which can mask underlying exploding gradients (log the scale and skipped steps). Conclude by proposing low-risk fixes first (linear warmup from a near-zero lr, clamping or upcasting attention logits, lowering the initial loss scale if gradients overflow) and longer-term fixes (gradient clipping, adjusted initialization). If more time remains: add unit tests for mask shapes, implement automated single-batch regression tests, and run a minimal integration test in CI.
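A training step that combines these mitigations might look like the sketch below. It assumes fp16 AMP on a CUDA device and falls back to plain fp32 on CPU; the model, data, and hyperparameters are illustrative, not recommendations. The key ordering is scale, backward, unscale_, clip, step, update. Newer PyTorch releases also expose the scaler as torch.amp.GradScaler.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"                 # assumption: fp16 AMP on GPU only; plain fp32 on CPU
amp_dtype = torch.float16 if use_amp else torch.bfloat16  # CPU dtype unused while autocast is off

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16)).to(device)  # hypothetical
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1)
warmup_steps = 10  # illustrative
# Linear warmup: lr ramps from base_lr / warmup_steps up to base_lr, then stays constant.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # with enabled=False every scaler call is a passthrough

x = torch.randn(64, 16, device=device)
y = torch.randn(64, 16, device=device)
losses = []

for step in range(50):
    optimizer.zero_grad(set_to_none=True)  # clear before backward, never between backward and step
    with torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
        loss = nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()          # scaling protects small fp16 gradients from underflow
    scaler.unscale_(optimizer)             # unscale first so clipping sees the true gradients
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)                 # skipped automatically if any gradient is Inf/NaN
    scaler.update()                        # backs the scale off after a skipped step
    scheduler.step()
    losses.append(loss.item())
    if step % 10 == 0:
        print(f"step={step} loss={loss.item():.4f} grad_norm={grad_norm.item():.3f} "
              f"lr={scheduler.get_last_lr()[0]:.2e} scale={scaler.get_scale():.0f}")
```

Logging grad_norm and the scale each step is what makes skipped steps and slowly exploding gradients visible instead of silently absorbed.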
A second angle — Debug a broken Transformer implementation
This variant focuses more on implementation correctness than training stability. Start by confirming API contracts (encoder/decoder shapes, mask semantics). Use deterministic unit tests: pass synthetic inputs whose attention output is analytically predictable (e.g., identical keys force uniform attention weights, so the output must equal the mean of the values) to validate masking and softmax numerics. Check parameter initialization and layer ordering (pre-LN vs post-LN): a swapped LayerNorm/residual order changes optimization dynamics. Also validate weight tying and embedding-output projection dimensions; a mismatch can silently broadcast, producing wrong logits. Here the emphasis is on small-scale functional tests before scaling to full training, whereas the training-pipeline case emphasizes numerical and data-driven failure modes.
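One such functional test checks causality directly: perturb only positions k and later, and require that outputs before k do not change. A minimal sketch using torch.nn.MultiheadAttention as the module under test (swap in your own attention layer with the same call signature); leaks_future is a hypothetical helper, and running it with no mask serves as a negative control that the test can actually detect leakage.

```python
import torch
import torch.nn as nn


def leaks_future(attn: nn.MultiheadAttention, attn_mask, seq_len=6, dim=16, seed=0) -> bool:
    """True if outputs before position k change when only positions >= k are perturbed."""
    g = torch.Generator().manual_seed(seed)
    x = torch.randn(1, seq_len, dim, generator=g)
    k = seq_len // 2
    x_perturbed = x.clone()
    x_perturbed[:, k:] += torch.randn(1, seq_len - k, dim, generator=g)
    attn.eval()  # disable dropout so the perturbation is the only difference
    with torch.no_grad():
        y, _ = attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        y_perturbed, _ = attn(x_perturbed, x_perturbed, x_perturbed,
                              attn_mask=attn_mask, need_weights=False)
    return not torch.allclose(y[:, :k], y_perturbed[:, :k], atol=1e-6)


torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)
seq_len = 6
# nn.MultiheadAttention convention: True = this query may NOT attend to this key.
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

print(leaks_future(attn, causal))  # False: the causal mask holds
print(leaks_future(attn, None))    # True: the negative control detects leakage
```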
Common pitfalls
Pitfall: Blaming the optimizer first. Many candidates immediately tune lr or betas without validating data and masking, which wastes time. First reproduce on a single batch and confirm loss-mask-label alignment.
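A concrete version of that first step is a single-batch overfit test: a reasonable model should drive the loss on one fixed batch close to zero, and if it cannot, suspect data, masking, or label alignment before optimizer hyperparameters. A minimal sketch; overfit_single_batch and TinyLM are hypothetical, and the step count and lr are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def overfit_single_batch(model, input_ids, steps=300, lr=1e-2):
    """Train on one fixed (unpadded) batch and return the loss curve."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        logits = model(input_ids)  # (batch, seq_len, vocab)
        # logits[:, t] is scored against input_ids[:, t + 1]
        loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)),
                               input_ids[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


class TinyLM(nn.Module):
    """Hypothetical stand-in: token + learned position embeddings and a linear head."""

    def __init__(self, vocab_size=20, max_len=16, dim=32):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, dim)
        self.pos = nn.Embedding(max_len, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, input_ids):
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        return self.head(self.tok(input_ids) + self.pos(positions))


torch.manual_seed(0)
batch = torch.tensor([[3, 7, 3, 9, 1, 7, 3, 2]])
losses = overfit_single_batch(TinyLM(), batch)
print(f"first={losses[0]:.3f} last={losses[-1]:.4f}")
```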
Pitfall: Overlooking label shift. A tempting quick fix is to change loss normalization; the correct check is whether labels are shifted relative to logits (teacher forcing). Missing this yields plausible-looking but wrong training signals.
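A suspiciously low loss is itself a symptom. Without the shift, a useless model that merely copies its input scores near-perfectly. The sketch below constructs such a copy predictor by hand (copy_logits and lm_loss are hypothetical names) and scores it both ways; with the correct shift its loss is about 10 nats.

```python
import torch
import torch.nn.functional as F

vocab_size = 50
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])
# A useless "model" that copies its input: logit 10 on the current token, 0 elsewhere.
copy_logits = 10.0 * F.one_hot(input_ids, vocab_size).float()  # (batch, seq_len, vocab)


def lm_loss(logits, labels):
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1)).item()


unshifted = lm_loss(copy_logits, input_ids)               # bug: targets are the inputs
shifted = lm_loss(copy_logits[:, :-1], input_ids[:, 1:])  # correct: targets are the next tokens
print(f"unshifted={unshifted:.4f} shifted={shifted:.2f}")  # unshifted=0.0022 shifted=10.00
```

The reverse mistake matters too: Hugging Face Transformers causal-LM models shift labels internally when given labels, so shifting them in the collator as well leaves targets two positions ahead.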
Pitfall: Hiding issues by dropping AMP. Switching to full precision often makes NaNs disappear but can hide root causes like exploding activations or incorrect masking; document it as a diagnostic step, not a permanent solution.
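The reason fp32 "fixes" things is range, not correctness: values that fit comfortably in fp32 can exceed fp16's maximum. Loss scaling multiplies gradients, not forward activations, so it cannot rescue a forward overflow. The sketch below uses an attention-score-like dot product of illustrative values to show the gap; bf16 keeps fp32's exponent range at lower precision.

```python
import torch

print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # about 3.39e38, the same exponent range as fp32

# A dot-product score that fits easily in fp32 but not in fp16 (each term is 90000).
q = torch.full((4,), 300.0)
k = torch.full((4,), 300.0)
for dtype in (torch.float32, torch.bfloat16, torch.float16):
    score = (q.to(dtype) * k.to(dtype)).sum()
    print(dtype, score.item(), "finite" if torch.isfinite(score).item() else "overflow")
```

Measuring activation magnitudes in an fp32 run and comparing them with the fp16 limit shows which layers would overflow, which is the information that simply switching precisions throws away.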
Connections
Interviewers may pivot to adjacent topics like distributed training (gradient accumulation, DDP, bucket sizes), model serving (exporting attention masks to ONNX/TensorRT), or observability (production telemetry for drift and per-token loss). Being ready to discuss how debugging at training time changes monitoring needs in inference is useful.
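On the gradient-accumulation pivot, the core invariant is easy to check: with a mean-reduced loss and equal-size micro-batches, dividing each micro-batch loss by the number of accumulation steps reproduces the full-batch gradient. A minimal single-process sketch with a hypothetical linear model; under DDP, non-final micro-batches would additionally run inside model.no_sync(). With padding-aware token losses, micro-batches holding different numbers of real tokens break this equality unless the loss is normalized by the global token count.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 1)                      # hypothetical stand-in
x, y = torch.randn(16, 8), torch.randn(16, 1)
loss_fn = nn.MSELoss()                       # mean reduction over the batch

# Reference: one full batch of 16.
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 equal micro-batches, each loss divided by the number of micro-batches.
accum_steps = 4
model.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    (loss_fn(model(xb), yb) / accum_steps).backward()
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-5))  # True
```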
Further reading
- Attention Is All You Need (Vaswani et al., 2017): foundational architecture and masking semantics.
- Mixed Precision Training (Micikevicius et al., 2018): practical techniques for AMP and loss scaling.
Practice questions
- Debug a broken Transformer implementation (OpenAI · Machine Learning Engineer · Onsite · hard)
- Debug a transformer training pipeline (OpenAI · Machine Learning Engineer · Technical Screen · hard)
- Diagnose Transformer training and inference bugs (OpenAI · Machine Learning Engineer · Technical Screen · hard)
- Debug transformer and train classifier (OpenAI · Machine Learning Engineer · Technical Screen · hard)
Related concepts
- Transformer Attention And Masking (Machine Learning)
- Supervised ML Workflows, Interpretability And Deployment (Machine Learning)
- PyTorch Training And Model Implementation (Coding & Algorithms)
- Distributed Training Parallelism And Collectives (ML System Design)
- Production ML Pipelines And System Design (ML System Design)
- Machine Learning Project Lifecycle (Machine Learning)