
Diagnose Transformer training and inference bugs

Last updated: Jun 25, 2026

Quick Overview

This question evaluates practical debugging skills for Transformer-based sequence models, targeting machine learning engineers who must diagnose intermittent shape/dtype errors and convergence failures simultaneously. It tests deep familiarity with the full training stack — data pipelines, attention masking, positional encoding, optimizer precision, and distributed setup — and the ability to reason about how a single root cause can manifest across multiple symptoms.


Company: OpenAI

Role: Machine Learning Engineer

Category: Machine Learning

Difficulty: hard

Interview Round: Technical Screen

A Transformer-based sequence model intermittently throws shape/dtype mismatch errors and fails to converge after several thousand steps. Describe your end-to-end debugging approach. Include: how you validate tokenization, padding, and special tokens; verify attention masks (causal vs. bidirectional) and positional embeddings; add assertions and unit tests for tensor shapes and sequence lengths; pinpoint exploding/vanishing gradients (e.g., gradient norms, clipping, optimizer/betas, LR schedule, mixed precision); isolate data bugs (truncation, BOS/EOS handling, label shifting for LM objectives); diagnose loss NaNs and numerical instability; check checkpoint/resume logic and randomness/seed control; debug multi-GPU/DP/ZeRO issues (sync, grad scaling, AMP); profile throughput and memory (operator-level hotspots, OOM sources); and construct a minimal reproducible example with targeted logging to localize the fault. Provide concrete checks, metrics, and code-level instrumentation you would add.




Aug 11, 2025, 12:00 AM

Debugging a Transformer That Intermittently Throws Shape/Dtype Errors and Fails to Converge

You inherit a Transformer-based sequence model (decoder-only language-model objective) whose training run exhibits two symptoms simultaneously:

  1. It intermittently raises shape or dtype mismatch errors during training — most batches succeed, then one fails, often only after the run has been going for a while.
  2. It fails to converge after several thousand steps: the loss stalls on a plateau, oscillates, or diverges (sometimes spiking to NaN).

Walk through your end-to-end debugging strategy. You should describe the order in which you would attack the problem, the concrete checks, metrics, and code-level instrumentation you would add at each stage, and how each check lets you confirm or rule out a subsystem (data, batching/collation, masking, positional encoding, the model forward pass, the optimizer/precision stack, checkpointing, and the distributed setup). Treat the two symptoms as possibly related (e.g. a silently-malformed batch could both break shapes and poison the loss) and explain how you would tell whether they share a root cause.
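
One way to test whether the two symptoms share a root cause is to validate every collated batch against a few invariants before it reaches the model, so that a malformed batch fails loudly at its source instead of surfacing later as a shape error or a poisoned loss. The sketch below is framework-agnostic and works on plain nested lists. `find_batch_violations`, `PAD_ID`, `IGNORE_INDEX`, and `MAX_LEN` are hypothetical names, and their values are assumptions (-100 matches the default `ignore_index` of PyTorch's cross-entropy loss).

```python
from typing import Dict, List

PAD_ID = 0           # assumed PAD token id; read the real one from the tokenizer
IGNORE_INDEX = -100  # assumed label value that the loss skips
MAX_LEN = 8          # assumed model context length


def find_batch_violations(batch: Dict[str, List[List[int]]]) -> List[str]:
    """Return human-readable invariant violations (an empty list if the batch is clean)."""
    ids, mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
    if not (len(ids) == len(mask) == len(labels)):
        return [f"batch sizes differ: {len(ids)}, {len(mask)}, {len(labels)}"]
    lengths = {len(row) for field in (ids, mask, labels) for row in field}
    if len(lengths) != 1:
        return [f"ragged or mismatched sequence lengths: {sorted(lengths)}"]

    problems = []
    seq_len = next(iter(lengths))
    if seq_len > MAX_LEN:
        problems.append(f"sequence length {seq_len} exceeds context length {MAX_LEN}")
    for b, (id_row, mask_row, label_row) in enumerate(zip(ids, mask, labels)):
        if sum(mask_row) == 0:
            problems.append(f"row {b}: sample is entirely padding")
        for t, (tok, m, lab) in enumerate(zip(id_row, mask_row, label_row)):
            if m not in (0, 1):
                problems.append(f"row {b}, pos {t}: attention_mask value {m} is not 0 or 1")
            elif m == 0 and tok != PAD_ID:
                problems.append(f"row {b}, pos {t}: masked position holds non-PAD token {tok}")
            elif m == 0 and lab != IGNORE_INDEX:
                problems.append(f"row {b}, pos {t}: PAD position has label {lab}, so it contributes to the loss")
    return problems


good = {"input_ids": [[5, 6, 7, 0]], "attention_mask": [[1, 1, 1, 0]], "labels": [[5, 6, 7, -100]]}
bad = {"input_ids": [[5, 6, 7, 0]], "attention_mask": [[1, 1, 1, 0]], "labels": [[5, 6, 7, 0]]}
print(find_batch_violations(good))
print(find_batch_violations(bad))
```

Returning a list of messages rather than raising on the first problem is a deliberate choice for this sketch: logging the violation count per step next to the loss makes it easy to see whether malformed batches line up with loss spikes.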

Constraints & Assumptions

  • This is a training problem (the convergence symptom rules out a pure inference issue), but the candidate may use generation/KV-cache reasoning to explain the shape crashes.
  • Assume access to the full source: model code, data pipeline, training loop, and config. You can add instrumentation and run experiments freely.
  • Assume a typical PyTorch stack: DataLoader + custom collate, AMP/GradScaler mixed precision, AdamW, an LR scheduler with warmup, and the option of multi-GPU (DDP and/or ZeRO/DeepSpeed). Do not assume any specific finding is already known; the candidate must discover the root cause.
  • "Intermittent" means most steps succeed; the failure correlates with some batches or with elapsed time/state, not with every step.
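
Because the failure correlates with particular batches or with elapsed state, a useful first piece of instrumentation is a wrapper that records exactly which step failed and what the batch looked like. The following is a minimal sketch under stated assumptions: `run_with_capture` and `summarize` are hypothetical helpers, `step_fn` stands in for whatever function runs one training step, and the toy step at the bottom only mimics a shape error.

```python
import json
import traceback
from typing import Any, Callable, Dict, Iterable, Optional


def summarize(batch: Dict[str, Any]) -> Dict[str, Any]:
    """Describe each field by shape/dtype (or type/length) without dumping its contents."""
    summary = {}
    for name, value in batch.items():
        shape = getattr(value, "shape", None)  # present on tensors and arrays
        if shape is not None:
            summary[name] = {"shape": list(shape), "dtype": str(getattr(value, "dtype", "?"))}
        else:
            length = len(value) if hasattr(value, "__len__") else None
            summary[name] = {"type": type(value).__name__, "len": length}
    return summary


def run_with_capture(
    step_fn: Callable[[Dict[str, Any]], None],
    batches: Iterable[Dict[str, Any]],
    log_path: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Run step_fn on each batch; on the first exception, return a report of the failure."""
    for step, batch in enumerate(batches):
        try:
            step_fn(batch)
        except Exception as exc:
            report = {
                "step": step,
                "error": f"{type(exc).__name__}: {exc}",
                "batch": summarize(batch),
                "traceback": traceback.format_exc(),
            }
            if log_path is not None:
                with open(log_path, "w") as fh:
                    json.dump(report, fh, indent=2)
            return report
    return None


def toy_step(batch: Dict[str, Any]) -> None:
    # Stand-in for a forward pass: fails when ids and mask disagree in length.
    if len(batch["input_ids"]) != len(batch["attention_mask"]):
        raise ValueError("input_ids and attention_mask lengths differ")


batches = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1]},
    {"input_ids": [4, 5], "attention_mask": [1, 1]},
    {"input_ids": [6, 7, 8], "attention_mask": [1, 1]},  # malformed on purpose
]
report = run_with_capture(toy_step, batches)
print(report["step"], report["error"])
```

In a real loop one would probably also save the raw batch and re-raise after writing the report; the sketch returns the report only to keep the example short.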

Clarifying Questions to Ask

A strong candidate scopes the problem before diving in. Reasonable questions include:

  • Is this a fresh failure or a regression? Did it start after a code change, a data refresh, a tokenizer change, or a hardware/library upgrade?
  • Does the shape/dtype error reproduce on a single GPU and small batch, or only at scale / under DDP / under AMP?
  • What is the model objective and architecture precisely — decoder-only causal LM? Absolute, sinusoidal, or rotary positional embeddings? Is a KV cache used during training (e.g. chunked / streaming)?
  • Are the loss plateau and the crashes correlated in time (same step / same batch), or independent symptoms?
  • What does the loss curve actually look like — flat from step 0, slow decay then stall, or healthy-then-divergence? What is the current LR schedule and precision setting?
  • Is the tokenizer / dataset versioned and identical across all workers and any resumed checkpoint?

What a Strong Answer Covers

A strong answer is a systematic, ordered triage — not a flat list — that converts intermittent failures into deterministic, localized assertions. The interviewer is looking for breadth across the stack and depth on the high-yield culprits, with concrete code/metrics. Dimensions to look for:

  • Reproducibility & failure capture: seed control, deterministic flags, set_detect_anomaly, logging the failing batch index/inputs/device, and a per-step metrics panel (loss, grad norm, LR, scaler scale + overflow count).
  • Tokenization, padding & special tokens: single pinned tokenizer across workers/checkpoints, round-trip checks, correct PAD/BOS/EOS/UNK ids, padding side consistency, and sane sequence-length distributions (not silently truncated).
  • Collation invariants: right/left padding consistency; attention_mask = 1 for real tokens, 0 for PAD; labels = ignore_index on PAD; runtime asserts on the collate output.
  • Attention masks: causal vs. bidirectional correctness, mask shape/dtype, composing causal + padding masks, and applying the mask with a dtype-aware fill value before a float32 softmax (a hard-coded large negative such as -1e9 is not representable in FP16, whose largest finite magnitude is 65504).
  • Positional embeddings: position_ids within table bounds; for RoPE, consistent Q/K positions and correct past_len offset with a KV cache; behavior at/over max context length.
  • LM targets / label shifting: next-token shift correctness, no off-by-one, ignore_index on PAD, and a sanity check that the ignore_index fraction tracks the padding rate.
  • Gradients, optimizer & LR: global and per-layer grad norms, clipping, warmup, AdamW betas/eps, and excluding bias/LayerNorm from weight decay; distinguishing exploding vs. vanishing grads from the curves.
  • Numerical stability & NaNs: assert_finite on logits/loss, float32 softmax/loss, epsilon choice, and reading the GradScaler (frequent overflow/step-skipping as a signal).
  • Data bugs: shard/tokenizer consistency, all-UNK / all-PAD / pathologically long samples, truncation that drops EOS or breaks label shifting.
  • Checkpoint/resume & RNG: saving/restoring model + optimizer + scheduler + scaler + sampler + global step; verifying loss continuity and per-rank deterministic sampler seeding after resume.
  • Multi-GPU / DDP / ZeRO / AMP: identical per-rank batch shapes, exactly-once gradient reduction, find_unused_parameters hygiene, framework-correct grad clipping under ZeRO, and per-rank logging.
  • Throughput & memory profiling: tokens/sec, dataloader-vs-compute split, torch.profiler hotspots, memory_summary, and OOM triage (batch size, gradient checkpointing, padding to multiples of 8/16).
  • Minimal reproducible example: a tiny deterministic synthetic-task setup used to bisect model/optimizer faults from data/distributed faults, plus a symptom→root-cause mapping.
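
As an illustration of the attention-mask item above, the sketch below composes a causal mask with a padding mask and applies it in a way that stays finite in FP16. It is one possible approach, not the model's actual code: `build_attention_bias` and `masked_softmax` are hypothetical helpers, and the (B, 1, T, T) additive-bias layout is an assumption. The fill value comes from `torch.finfo(dtype).min` rather than a hard-coded constant, and the softmax runs in float32.

```python
import torch


def build_attention_bias(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Combine causal and padding masks into an additive bias of shape (B, 1, T, T).

    attention_mask: (B, T), 1 for real tokens and 0 for PAD.
    The bias is 0 where attention is allowed and finfo(dtype).min elsewhere.
    """
    assert attention_mask.dim() == 2, f"expected (B, T), got {tuple(attention_mask.shape)}"
    batch, seq_len = attention_mask.shape
    device = attention_mask.device
    causal = torch.ones(seq_len, seq_len, device=device).tril().bool()  # (T, T)
    key_is_real = attention_mask.to(torch.bool)[:, None, None, :]       # (B, 1, 1, T)
    allowed = causal[None, None, :, :] & key_is_real                    # (B, 1, T, T)
    bias = torch.zeros(batch, 1, seq_len, seq_len, dtype=dtype, device=device)
    return bias.masked_fill(~allowed, torch.finfo(dtype).min)


def masked_softmax(scores: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Add the bias and normalize in float32, then cast back to the input dtype."""
    assert scores.shape[-2:] == bias.shape[-2:], (scores.shape, bias.shape)
    probs = torch.softmax(scores.float() + bias.float(), dim=-1)
    return probs.to(scores.dtype)


torch.manual_seed(0)
mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])    # first sample is left-padded
scores = torch.randn(2, 1, 4, 4).to(torch.float16)   # stand-in for Q @ K^T / sqrt(d)
bias = build_attention_bias(mask, torch.float16)
probs = masked_softmax(scores, bias)

assert bias.dtype == torch.float16 and tuple(bias.shape) == (2, 1, 4, 4)
assert torch.isfinite(probs.float()).all()  # fully masked query rows stay finite
```

A finite fill value matters for the left-padded sample: its PAD query rows have every key masked, and filling with negative infinity there would turn the softmax output into NaN, whereas the finite fill yields a harmless distribution that the loss ignores at PAD positions.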

Follow-up Questions

  • The crash only appears after ~3,000 steps and never in a fresh MRE. What classes of bugs become more likely with elapsed time/state, and how would you confirm one (e.g. KV-cache/position drift, scaler state, optimizer state, or a rare long sample that only appears late in an epoch)?
  • Your GradScaler is skipping a large fraction of steps. What does that tell you, and what is your remediation order before you reach for "turn off AMP"?
  • Suppose the MRE on synthetic data converges fine but the real run still stalls. Which subsystems does that exonerate, and where do you look next?
  • How would you write a regression test that catches a future reintroduction of the label-shift / mask-dtype bug in CI, so this class of failure can't silently return?
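
For the last follow-up, a regression test might pin down the label-shift contract on a tiny hand-written sequence. The sketch below is framework-agnostic and uses plain lists; `shift_for_causal_lm` is a hypothetical stand-in for the project's real shifting code, the token ids are made up, and the tests are written so they can be collected by pytest or called directly.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # assumed label value that the loss skips


def shift_for_causal_lm(input_ids: List[int], attention_mask: List[int]) -> Tuple[List[int], List[int]]:
    """Return (inputs, targets) where targets[t] is the token that follows inputs[t].

    A target is replaced by IGNORE_INDEX when either the input position or the
    target position is padding, so PAD never contributes to the loss.
    """
    assert len(input_ids) == len(attention_mask), "ids and mask must align"
    inputs = input_ids[:-1]
    targets = [
        nxt if (m_in == 1 and m_out == 1) else IGNORE_INDEX
        for nxt, m_in, m_out in zip(input_ids[1:], attention_mask[:-1], attention_mask[1:])
    ]
    return inputs, targets


# Made-up ids: 101 = BOS, 102 = EOS, 0 = PAD; the sample is right-padded.
IDS = [101, 7, 8, 9, 102, 0, 0]
MASK = [1, 1, 1, 1, 1, 0, 0]


def test_targets_are_next_tokens():
    inputs, targets = shift_for_causal_lm(IDS, MASK)
    assert len(inputs) == len(targets) == len(IDS) - 1
    for t, target in enumerate(targets):
        if target != IGNORE_INDEX:
            assert target == IDS[t + 1], f"off-by-one at position {t}"
    assert targets[3] == 102  # EOS must remain a prediction target


def test_ignored_count_tracks_padding():
    _, targets = shift_for_causal_lm(IDS, MASK)
    assert targets.count(IGNORE_INDEX) == MASK[1:].count(0)


test_targets_are_next_tokens()
test_ignored_count_tracks_padding()
```

The second test encodes the sanity check from the list above: if the number of ignored targets drifts away from the number of padded positions, either real tokens are being dropped from the loss or PAD is leaking into it.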