Debug a transformer training pipeline
Company: OpenAI
Role: Machine Learning Engineer
Category: Machine Learning
Difficulty: hard
Interview Round: Technical Screen
##### Question
You are given a PyTorch Transformer-based training pipeline (a multi-head attention encoder-decoder, with tokenization, padding/masking, AdamW with weight decay, gradient clipping, a cosine LR scheduler with warmup, mixed precision, and teacher forcing) that misbehaves. In some runs the training loss decreases while validation accuracy stays near chance. In others the loss diverges. Runs also crash intermittently and are not reproducible. Concretely, the pipeline shows the following symptoms:
1. Occasional CUDA shape errors and `IndexError`s around attention masks.
2. Validation metrics plateau near chance even though training loss decreases (and in some runs the loss diverges entirely).
3. Intermittent crashes under mixed precision (AMP) combined with gradient accumulation.
4. Nondeterministic results across runs with the same config.
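Symptom 1 typically comes from a mask built with the wrong shape, a stale length, or inverted boolean polarity. The following is a minimal sketch that uses plain Python lists in place of tensors so the checks are easy to read. It assumes the `nn.MultiheadAttention` convention, where `True` in a boolean `key_padding_mask` or `attn_mask` means "do not attend". Note that `torch.nn.functional.scaled_dot_product_attention` uses the opposite convention (`True` means "attend"), so mixing the two APIs silently inverts a mask. `PAD_ID`, `key_padding_mask`, `causal_mask`, and `mask_problems` are hypothetical names used only for illustration.

```python
PAD_ID = 0  # hypothetical pad id; in practice read it from the tokenizer


def key_padding_mask(batch):
    """(batch, seq_len) mask; True marks PAD keys that must be ignored."""
    return [[tok == PAD_ID for tok in seq] for seq in batch]


def causal_mask(tgt_len):
    """(tgt_len, tgt_len) mask; True marks future keys j > i that query i
    must not attend to."""
    return [[j > i for j in range(tgt_len)] for i in range(tgt_len)]


def mask_problems(batch, pad_mask, attn_mask, tgt_len):
    """Return a list of detected problems; an empty list means the masks look sane."""
    problems = []
    # Shape: catches transposed (seq_len, batch) masks and masks built for a stale length
    if len(pad_mask) != len(batch) or any(
        len(m) != len(s) for m, s in zip(pad_mask, batch)
    ):
        problems.append("padding mask shape does not match the batch")
    if len(attn_mask) != tgt_len or any(len(row) != tgt_len for row in attn_mask):
        problems.append("causal mask is not (tgt_len, tgt_len)")
    # Polarity: each query must at least see itself, otherwise softmax over a
    # fully masked row produces NaN
    elif any(attn_mask[i][i] for i in range(tgt_len)):
        problems.append("causal mask blocks the diagonal (inverted polarity?)")
    # A sequence that is entirely PAD leaves its queries with no valid key
    if any(all(row) for row in pad_mask):
        problems.append("a sequence is fully padded")
    return problems


batch = [[5, 7, 9, 0], [3, 4, 0, 0]]
pad_mask = key_padding_mask(batch)
attn_mask = causal_mask(3)
# Runtime assertion to place right before the attention call
assert not mask_problems(batch, pad_mask, attn_mask, tgt_len=3)
```

In the real pipeline the same checks become assertions on the mask tensor's `.shape`, `.dtype` (expect `torch.bool`), and `.device` immediately before the attention call. A regression then fails on the first batch instead of surfacing later as a CUDA error.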
Answer the following:
1. Identify at least four distinct, plausible root-cause bugs behind these symptoms. Draw from common failure points, e.g.: incorrect attention/causal mask construction (shape, dtype, device, or polarity); off-by-one teacher-forcing label shift and PAD tokens not ignored in the loss; missing or misplaced `optimizer.zero_grad` (especially with gradient accumulation); misplaced gradient scaling, or clipping before unscaling, under AMP; wrong positional-encoding shape or broadcast (or position indices exceeding the embedding table); failing to exclude LayerNorm weights and biases from weight decay (or a filter that excludes the wrong parameters); and data leakage in batching.
2. For each bug, explain the failure mode (the signal you would look for), show a minimal code fix, and write a unit test or runtime assertion that would catch it and prevent regressions.
3. Lay out a systematic debugging plan that localizes each fault component-by-component. Give concrete checks and experiments for: data preprocessing (padding, truncation, label alignment); tokenization and attention/causal masks; positional encodings; loss computation (`ignore_index`, label smoothing, class weights); optimizer/scheduler/`zero_grad`/gradient clipping; AMP/`GradScaler` settings; seed control and deterministic kernels; and DDP/Sampler configuration.
4. Specify the sanity checks (e.g., overfit a tiny subset / copy task, gradient and activation statistics, NaN/Inf detection), the metrics you would monitor, and a small experiment to verify each fix.
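For the teacher-forcing bug, here is a minimal sketch of the intended alignment. It assumes a right-padded target that already ends in EOS, hypothetical `BOS_ID`/`PAD_ID` values, and the label value `-100`, which is the default `ignore_index` of `torch.nn.CrossEntropyLoss`. If the labels accidentally equal the decoder inputs, each position can copy its own input, so training loss collapses while autoregressive validation stays at chance. If the target is shifted twice, the model is asked to predict two tokens ahead.

```python
PAD_ID = 0           # hypothetical pad id
BOS_ID = 1           # hypothetical begin-of-sequence id
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def make_decoder_io(target):
    """Teacher forcing for one right-padded target that already ends in EOS.
    Returns (decoder_input, labels) of the same length."""
    # Decoder input: the target shifted right by one position, starting with BOS
    decoder_input = [BOS_ID] + target[:-1]
    # Labels: the unshifted target, with PAD positions excluded from the loss
    labels = [IGNORE_INDEX if tok == PAD_ID else tok for tok in target]
    return decoder_input, labels


def check_alignment(decoder_input, labels, target):
    """Runtime/unit-test assertions for the shift and the PAD handling."""
    # Equal lengths, so the logits at position t pair with labels[t]
    assert len(decoder_input) == len(labels) == len(target)
    # Position t sees target[t-1] as its newest input and must predict target[t]
    for t in range(1, len(target)):
        assert decoder_input[t] == target[t - 1], f"shift is off at position {t}"
    # No PAD id may survive as a training target
    assert PAD_ID not in labels, "PAD tokens are still counted in the loss"


target = [5, 6, 2, 0, 0]  # two real tokens, EOS (hypothetical id 2), two PADs
decoder_input, labels = make_decoder_io(target)
check_alignment(decoder_input, labels, target)
```

In PyTorch, logits of shape `(batch, tgt_len, vocab)` and labels of shape `(batch, tgt_len)` are flattened to `(batch * tgt_len, vocab)` and `(batch * tgt_len,)` before `CrossEntropyLoss`. Running a check like `check_alignment` on a few examples from the real collate function makes a cheap unit test.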
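For the `zero_grad` bug under gradient accumulation, a toy simulation shows the two invariants worth asserting. The simulation uses a hypothetical scalar linear model with MSE loss and a float standing in for `param.grad`; it is not PyTorch code. First, after `accum_steps` micro-batches, each with its loss divided by `accum_steps`, the accumulated gradient should equal the full-batch gradient. Second, without `zero_grad` after each step, the previous step's gradient leaks into the next one. The weight is held fixed so the per-step gradients are directly comparable.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w * x - y) ** 2) over one micro-batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def grads_seen_by_optimizer(w, micro_batches, accum_steps, zero_grad_after_step=True):
    """Simulate param.grad across optimizer steps and return the gradient
    each optimizer.step() would consume. w is held fixed so steps are comparable."""
    grad = 0.0  # stands in for param.grad
    seen = []
    for i, (xs, ys) in enumerate(micro_batches, start=1):
        # backward() adds into .grad; dividing by accum_steps turns the
        # sum over micro-batches into the mean over the effective batch
        grad += grad_mse(w, xs, ys) / accum_steps
        if i % accum_steps == 0:
            seen.append(grad)  # optimizer.step() consumes the accumulated gradient
            if zero_grad_after_step:
                grad = 0.0  # optimizer.zero_grad()
    return seen


# Two optimizer steps, each made of two single-example micro-batches
micro_batches = [([1.0], [2.0]), ([2.0], [3.0])] * 2
full_batch_grad = grad_mse(0.0, [1.0, 2.0], [2.0, 3.0])
correct = grads_seen_by_optimizer(0.0, micro_batches, accum_steps=2)
buggy = grads_seen_by_optimizer(0.0, micro_batches, accum_steps=2, zero_grad_after_step=False)
```

The PyTorch regression test follows the same idea, assuming equal-sized micro-batches and a mean-reduced loss. Run one step with `accum_steps=k` on micro-batches of size `b/k` and one step with `accum_steps=1` on a batch of size `b`, with dropout disabled, then compare the `.grad` tensors with `torch.allclose`. Calling `optimizer.zero_grad(set_to_none=True)` right after the optimizer step keeps the ordering explicit. With PAD-ignored token-level means, micro-batches hold different token counts, so expect small differences rather than exact equality.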
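For the AMP ordering bug, recall that after `scaler.scale(loss).backward()` the gradients in `.grad` are multiplied by the current loss scale. Clipping before `scaler.unscale_(optimizer)` therefore compares a scaled norm against `max_norm` and shrinks the real update by roughly the scale factor. The pure-Python sketch below mimics that arithmetic with two scalar gradients. The `clip` helper approximates `torch.nn.utils.clip_grad_norm_`, which also adds a small epsilon. The value `65536` matches `GradScaler`'s default initial scale.

```python
import math


def global_norm(grads):
    """Global L2 norm over all gradient entries."""
    return math.sqrt(sum(g * g for g in grads))


def clip(grads, max_norm):
    """Rescale grads so their global L2 norm is at most max_norm
    (the behavior of clip_grad_norm_, minus its small epsilon)."""
    norm = global_norm(grads)
    factor = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * factor for g in grads]


true_grads = [0.3, 0.4]  # global norm 0.5, below max_norm, so no clipping is needed
scale = 65536.0          # GradScaler's default init_scale (2**16)
max_norm = 1.0

# .grad after scaler.scale(loss).backward() holds the gradients times the scale
scaled = [g * scale for g in true_grads]

# Buggy order: clip the still-scaled grads, then unscale (as scaler.step does internally)
buggy = [g / scale for g in clip(scaled, max_norm)]

# Fixed order: scaler.unscale_(optimizer) first, then clip
fixed = clip([g / scale for g in scaled], max_norm)
```

In PyTorch, the fixed order on each accumulation boundary is `scaler.unscale_(optimizer)`, then `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)`, then `scaler.step(optimizer)` and `scaler.update()`. `GradScaler` raises a `RuntimeError` if `unscale_` is called twice on the same optimizer before `update()`. Calling `unscale_` on every micro-batch is therefore one concrete source of the intermittent crashes in symptom 3.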
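For the weight-decay bug, here is a sketch that compares two ways of building AdamW parameter groups. The first is a dimension-based rule (an assumed heuristic: tensors with fewer than two dimensions get no decay). The second is a commonly copied name-based filter that only recognizes parameters literally named `LayerNorm`. The parameter names and shapes are illustrative, loosely following `nn.MultiheadAttention` and `nn.LayerNorm` naming. The `group_problems` check is the regression test.

```python
def split_by_ndim(named_shapes):
    """Assumed rule: tensors with < 2 dims (biases, LayerNorm weight/bias) get
    no weight decay; weight matrices and embeddings do."""
    decay = [n for n, shape in named_shapes if len(shape) >= 2]
    no_decay = [n for n, shape in named_shapes if len(shape) < 2]
    return decay, no_decay


def split_by_name(named_shapes, keys=("bias", "LayerNorm.weight")):
    """A name-based filter; it silently misses norms that are not literally
    named 'LayerNorm' (e.g. 'norm1.weight')."""
    no_decay = [n for n, _ in named_shapes if any(k in n for k in keys)]
    decay = [n for n, _ in named_shapes if n not in no_decay]
    return decay, no_decay


def group_problems(named_shapes, decay, no_decay):
    """Every parameter must be in exactly one group, and no 1-D tensor may be decayed."""
    problems = []
    if sorted(decay + no_decay) != sorted(n for n, _ in named_shapes):
        problems.append("a parameter is missing or duplicated")
    shapes = dict(named_shapes)
    problems += [f"{n} is 1-D but decayed" for n in decay if len(shapes[n]) < 2]
    return problems


# Illustrative (name, shape) pairs
params = [
    ("embed.weight", (1000, 32)),
    ("enc.self_attn.in_proj_weight", (96, 32)),
    ("enc.self_attn.in_proj_bias", (96,)),
    ("enc.norm1.weight", (32,)),
    ("enc.norm1.bias", (32,)),
]
```

In PyTorch the pairs would come from `[(n, tuple(p.shape)) for n, p in model.named_parameters()]`. The two groups are then passed to `torch.optim.AdamW` as parameter groups, e.g. `[{"params": decay_params, "weight_decay": 0.01}, {"params": no_decay_params, "weight_decay": 0.0}]`, where `0.01` is only an example value.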
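For positional encodings, two cheap checks catch most bugs: compare the sequence length against the size of the position table, and check the broadcast shape of the addition. The sketch implements the NumPy/PyTorch broadcasting rule in plain Python so the shape logic can be unit-tested without tensors. `MAX_POSITIONS` and the shapes are hypothetical. The dangerous case is a silent broadcast. With sequence-first activations and a `(seq, d)` table slice, the addition raises or silently succeeds depending on the batch size, which is one way shape errors become intermittent.

```python
from itertools import zip_longest

MAX_POSITIONS = 512  # hypothetical size of the learned position-embedding table


def broadcast_shape(a, b):
    """Shape of a + b under NumPy/PyTorch broadcasting: align trailing dims;
    each pair must be equal or contain a 1."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(max(x, y))
    return tuple(reversed(out))


def check_positions(seq_len, max_positions=MAX_POSITIONS):
    # Position ids run 0..seq_len-1; any id >= max_positions is an out-of-range
    # embedding lookup (IndexError on CPU, device-side assert on CUDA)
    assert seq_len <= max_positions, f"seq_len {seq_len} exceeds {max_positions}"


batch, seq, d = 32, 10, 64
check_positions(seq)
# Intended: batch-first activations plus an explicit (1, seq, d) positional slice
assert broadcast_shape((batch, seq, d), (1, seq, d)) == (batch, seq, d)
# Trap: a (seq, d) slice added to (seq, batch, d) activations "works" when
# batch == seq, adding position j's encoding to batch element j instead
assert broadcast_shape((seq, seq, d), (seq, d)) == (seq, seq, d)
```

Reshaping the positional slice explicitly to `(1, seq, d)` and asserting the activation layout turns the silent case into a loud shape error.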
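For the scheduler checks, here is a sketch of linear warmup followed by cosine decay. It is written as a plain function of the optimizer step so it can be tested without a model. The exact formula is an assumption, since pipelines differ in whether warmup starts at zero and whether there is a floor. What matters are the checked properties: a nonzero first step, a peak at `base_lr`, monotone decay, and a floor at the end. A frequent bug with gradient accumulation is advancing the scheduler once per micro-batch instead of once per optimizer step. That runs through the schedule `accum_steps` times too fast.

```python
import math


def lr_at(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
    """Learning rate at optimizer step `step` (0-based)."""
    if step < warmup_steps:
        # Linear warmup: base_lr / warmup_steps at step 0, base_lr at step warmup_steps - 1
        return base_lr * (step + 1) / warmup_steps
    # Cosine decay from base_lr to min_lr over the remaining steps, then hold min_lr
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


base_lr, warmup, total, accum_steps = 1e-3, 10, 100, 4
schedule = [lr_at(s, base_lr, warmup, total) for s in range(total + 1)]
assert schedule[0] > 0, "first update has zero learning rate"
assert math.isclose(max(schedule), base_lr)
assert all(a >= b for a, b in zip(schedule[warmup:], schedule[warmup + 1:]))

# Bug demo: stepping the scheduler per micro-batch instead of per optimizer step
opt_step = 25
lr_correct = lr_at(opt_step, base_lr, warmup, total)
lr_buggy = lr_at(opt_step * accum_steps, base_lr, warmup, total)  # already at the floor
```

In PyTorch the same function can drive `torch.optim.lr_scheduler.LambdaLR` by returning `lr_at(step, ...) / base_lr` as the multiplier, with `scheduler.step()` called once per optimizer step.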
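For the DDP/Sampler and seeding checks, here is a pure-Python imitation of how `DistributedSampler` splits data. It assumes `shuffle=True` and `drop_last=False`, and uses `random.Random` in place of `torch.randperm`. The sketch makes two invariants testable. First, all ranks must shard one shared permutation, which requires the same `seed` on every rank. Second, the permutation must change across epochs, which in PyTorch only happens if `sampler.set_epoch(epoch)` is called every epoch. Seeding the sampler per rank breaks the first invariant: ranks draw from different permutations, so some samples are duplicated and others skipped.

```python
import random


def shard_indices(n, rank, world_size, seed, epoch):
    """Mimic DistributedSampler: every rank shuffles with the same seed + epoch,
    pads by wrapping around to a multiple of world_size, then takes every
    world_size-th index starting at its rank."""
    indices = list(range(n))
    random.Random(seed + epoch).shuffle(indices)
    total = -(-n // world_size) * world_size  # round n up to a multiple of world_size
    indices += indices[: total - n]
    return indices[rank:total:world_size]


def check_shards(n, world_size, seed, epoch):
    shards = [shard_indices(n, r, world_size, seed, epoch) for r in range(world_size)]
    # Unequal shard sizes give ranks different step counts, which can hang collectives
    assert len({len(s) for s in shards}) == 1, "ranks see different numbers of samples"
    # Every sample must be covered by some rank
    assert set().union(*shards) == set(range(n)), "some samples are never seen"
    return shards


shards = check_shards(10, world_size=4, seed=0, epoch=0)
# Reproducible: the same seed and epoch give the same split
assert shards == check_shards(10, world_size=4, seed=0, epoch=0)
```

Run-to-run reproducibility on the PyTorch side also needs `torch.manual_seed` (plus `random` and NumPy seeds) and a seeded `generator` and `worker_init_fn` for `DataLoader` workers. It further needs `torch.use_deterministic_algorithms(True)` with `CUBLAS_WORKSPACE_CONFIG=:4096:8` set. With that flag, ops that lack a deterministic implementation raise an error instead of running nondeterministically, which helps localize symptom 4.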
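For the NaN/Inf and gradient-statistics checks, here is a sketch of a per-parameter report. It is written over plain lists with hypothetical names so it runs anywhere. In a real loop the values would come from each `p.grad` after `scaler.unscale_(optimizer)`, so the norms are on the true scale. Under AMP an occasional non-finite step is expected, because `GradScaler` skips that optimizer step and lowers the scale. Non-finite gradients that persist, or that also appear in an fp32 run, point to a real bug. The first offending parameter is a reasonable place to start, and `torch.autograd.set_detect_anomaly(True)` can then locate the op that produced the NaN.

```python
import math


def grad_report(named_grads):
    """Per-parameter gradient L2 norm (over finite entries) and count of NaN/Inf entries.
    named_grads maps a parameter name to a flat list of gradient values."""
    report = {}
    for name, values in named_grads.items():
        finite = [v for v in values if math.isfinite(v)]
        report[name] = {
            "norm": math.sqrt(sum(v * v for v in finite)),
            "nonfinite": len(values) - len(finite),
        }
    return report


def first_bad_param(report):
    """Name of the first parameter (in insertion order) with NaN/Inf gradients, or None."""
    return next((name for name, r in report.items() if r["nonfinite"]), None)


def dead_params(report):
    """Parameters whose gradient is exactly zero, e.g. cut off from the loss by a bad mask."""
    return [name for name, r in report.items() if r["nonfinite"] == 0 and r["norm"] == 0.0]
```

Logging the global and per-layer gradient norms and `scaler.get_scale()` every step, alongside training and validation loss, makes divergence visible before the loss itself becomes NaN.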
Quick Answer: An OpenAI ML engineer technical-screen debugging question. You are given a PyTorch Transformer training pipeline whose loss diverges and whose validation accuracy stays near chance (plus mask `IndexError`s, AMP crashes, and nondeterminism), and you must identify the distinct bugs and fix them. It tests Transformer internals (attention/causal masking, teacher-forcing label shift, positional encodings, `ignore_index` in the loss, optimizer and `GradScaler` ordering, determinism, and DDP) along with writing unit tests and a structured debugging plan.