Transformer Attention And Masking
Asked of: ML Engineer

What's being tested
Candidates must demonstrate practical mastery of Transformer attention and masking semantics as they appear in real training and inference pipelines: correct causal/padding masking, target alignment for language-model and sequence-to-sequence objectives, and numerical/stability issues that break training. Interviewers probe ability to triage pipeline failures (loss plateauing, NaNs, incorrect evaluation) by isolating tokenization, mask construction, loss configuration, and mixed-precision interactions. For an ML Engineer at OpenAI, the focus is operational: design reproducible tests, implement minimal unit tests for masking/shift logic, and choose pragmatic fixes that preserve offline/online parity.
Core knowledge
- Scaled dot-product attention formula: $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$, where M is the additive mask (0 for allowed positions, a large negative value for disallowed positions).
- Causal mask (autoregressive): a triangular boolean or additive mask (built with torch.triu or torch.tril), with a shape broadcastable to the attention logits, e.g. (B, H, T, T); it ensures position t cannot attend to positions > t.
- Padding mask: a per-token mask (1 for real tokens, 0 for padding) applied by adding a large negative value to the attention logits at padded key positions, and by using ignore_index in the loss (e.g. CrossEntropyLoss(ignore_index=pad_id)). Zeroing padded embeddings is sometimes used as well, but on its own it does not block attention, because a zero logit still receives softmax weight.
- Target shift (teacher forcing): for a causal LM over a sequence x_0..x_T, inputs are [x_0..x_{T-1}] and targets are [x_1..x_T]. An off-by-one here typically shows up either as near-zero loss (the model can copy the token it is asked to predict) or as a loss that plateaus well above expectation (targets shifted too far); always unit-test with a known sequence.
- Mask combination: combine causal and padding masks via logical AND on boolean "allowed" masks (or by summing additive masks) so padded positions are blocked even where the causal mask allows them; pay attention to dtype and broadcast order.
- KV cache for inference: cache per-layer keys/values with shapes (layer, batch, seq_len_cached, d_k) to avoid recomputing past K, V; memory grows as O(seq_len * d_model) per layer and per sequence, and the cache complicates relative positional encodings.
- Relative positional encodings: many implementations (T5, Transformer-XL, ALiBi) change how attention is computed, so KV-cache logic must respect that; naive caching can break the positional bias.
- Numerical stability & AMP: use gradient scaling (torch.cuda.amp) to avoid gradient underflow, monitor grad norm, and check for NaNs by rerunning forward/backward in FP32; mixed precision can hide mask dtype bugs (bool vs float).
- Loss config mismatches: label smoothing, reduction (sum vs mean), and ignore_index all change loss magnitude; verify that the metric (perplexity) calculation uses the correct token counts: $\mathrm{PPL} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log p(x_i \mid x_{<i})\right)$, where N counts only non-padding target tokens.
- Unit tests you should have: a deterministic mini-batch where the attention mask yields known zeroed attention weights, a toy sequence where shifted targets match, and an incremental decode that matches full-sequence decode when replaying the cache.
- Performance tradeoffs: recomputing full attention at every decode step costs O(T^2) per step; KV caching reduces the attention cost per new token to O(T) but increases memory and complicates batch merging and beam search.
- Implementation idioms: use masked_fill(mask == 0, -1e9) for float logits, or torch.where; prefer boolean masks for clarity; check dtype and device when constructing masks. Note that -1e9 is outside the FP16 range, so under mixed precision a dtype-aware value such as torch.finfo(dtype).min is safer.
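The mask items above can be tied together in a short sketch. This is a minimal illustration, not a reference implementation: the helper names `build_allowed_mask` and `masked_attention`, the right-padded toy batch, and the choice of `torch.finfo(...).min` as the fill value are all assumptions made for the example.

```python
import math
import torch


def build_allowed_mask(pad_mask: torch.Tensor) -> torch.Tensor:
    """pad_mask: (B, T) bool, True for real tokens.

    Returns a (B, 1, T, T) bool mask, True where attention is allowed.
    """
    T = pad_mask.shape[1]
    # Lower-triangular: query t may see keys 0..t.
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=pad_mask.device))
    keys_ok = pad_mask[:, None, None, :]       # (B, 1, 1, T): block padded keys
    return causal[None, None, :, :] & keys_ok  # logical AND -> (B, 1, T, T)


def masked_attention(q, k, v, allowed):
    """q, k, v: (B, H, T, d_k); allowed must broadcast to (B, H, T, T)."""
    d_k = q.shape[-1]
    logits = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    # Fill (not add) with the most negative finite value for this dtype.
    logits = logits.masked_fill(~allowed, torch.finfo(logits.dtype).min)
    weights = torch.softmax(logits, dim=-1)
    return weights @ v, weights


torch.manual_seed(0)
B, H, T, d_k = 2, 2, 4, 8
# Hypothetical batch: second sequence has two right-padded positions.
pad_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.bool)
q, k, v = (torch.randn(B, H, T, d_k) for _ in range(3))

allowed = build_allowed_mask(pad_mask)
out, weights = masked_attention(q, k, v, allowed)

# Disallowed positions should receive exactly zero attention weight.
leaked = weights[~allowed.expand_as(weights)].sum().item()
```

Checking `leaked` against zero on a tiny deterministic batch like this is one way to realize the "known zeroed" unit test mentioned in the list.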
Worked example — Debug a transformer training pipeline
Start by clarifying: "Is the model trained with a causal LM objective or seq2seq? Which tokenizer and pad_id are used? Are we using AMP or distributed training?" A strong candidate organizes the investigation into three pillars: (1) data & token alignment (tokenizer, pad id, target shift), (2) mask correctness (causal vs padding and how they're combined), and (3) numerical/optimizer issues (LR, gradient clipping, AMP scaling). Practically, they'd run small unit tests: a forward pass on a single known sequence to verify target-shift behavior and the effect of the attention mask, and a loss computation with and without padding to detect ignore_index misuse.
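The target-shift and ignore_index checks described here might look like the following sketch. `PAD_ID = 0`, the toy token ids, and the helper names `shift_for_causal_lm` and `masked_loss` are hypothetical, and the random logits stand in for a real model's output.

```python
import math
import torch
import torch.nn.functional as F

PAD_ID = 0  # hypothetical pad id, chosen for this example only


def shift_for_causal_lm(tokens: torch.Tensor):
    """tokens: (B, T+1). Inputs are x_0..x_{T-1}, targets are x_1..x_T."""
    return tokens[:, :-1], tokens[:, 1:]


def masked_loss(logits, targets, reduction):
    """Cross-entropy that skips positions whose target is PAD_ID."""
    return F.cross_entropy(
        logits.reshape(-1, logits.shape[-1]),
        targets.reshape(-1),
        ignore_index=PAD_ID,
        reduction=reduction,
    )


tokens = torch.tensor([[5, 6, 7, 8], [5, 6, PAD_ID, PAD_ID]])
inputs, targets = shift_for_causal_lm(tokens)

torch.manual_seed(0)
vocab = 10
logits = torch.randn(2, 3, vocab)  # stand-in for model(inputs)

loss_sum = masked_loss(logits, targets, "sum")
loss_mean = masked_loss(logits, targets, "mean")
n_tokens = (targets != PAD_ID).sum()  # count only real target tokens

# Perplexity from the summed loss and the non-pad token count.
ppl = math.exp((loss_sum / n_tokens).item())

# Padding check: changing logits at padded target positions must not move the loss.
pad_positions = targets == PAD_ID
perturbed = logits.clone()
perturbed[pad_positions] = torch.randn_like(perturbed[pad_positions])
loss_mean_perturbed = masked_loss(perturbed, targets, "mean")
```

If `loss_mean_perturbed` differs from `loss_mean`, padded positions are leaking into the loss, which points at a wrong pad id or a missing ignore_index.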
Pitfall: Enabling aggressive label smoothing can hide token-level bugs by lowering loss variance, so temporarily disable it while debugging.
Close by proposing automated checks (unit tests above) and next steps: if root cause remains, run FP32 debug trace and isolate whether NaNs originate in softmax (mask error) or optimizer step (LR spike).
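One way a mask error can produce NaNs inside softmax is a row with no allowed key, for example a padded query row under some padding layouts. The sketch below reproduces that case in FP32 on a toy row; the tensor names are illustrative, and the finite fill value is shown only as a contrast, not as a recommended fix for the underlying mask bug.

```python
import torch

logits = torch.zeros(1, 3)
# Hypothetical buggy mask: no key position is allowed for this query row.
allowed = torch.zeros(1, 3, dtype=torch.bool)

# Filling with -inf leaves softmax computing exp(-inf - (-inf)), which is NaN.
with_inf = torch.softmax(logits.masked_fill(~allowed, float("-inf")), dim=-1)

# A large finite fill value keeps the row finite (it degrades to uniform weights),
# which hides the NaN but not the fact that the mask is wrong.
finite_fill = torch.finfo(logits.dtype).min
with_min = torch.softmax(logits.masked_fill(~allowed, finite_fill), dim=-1)

has_nan = bool(torch.isnan(with_inf).any())
fully_masked_rows = int((~allowed.any(dim=-1)).sum())  # cheap sanity counter
```

Counting `fully_masked_rows` before the softmax is a quick way to decide whether a NaN is a mask problem or belongs to the optimizer side of the investigation.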
A second angle — Explain KV cache in Transformer inference
When focusing on KV caching, the problem shifts from training correctness to latency/memory engineering and parity with training. Emphasize shapes and bookkeeping: each decode step appends new K,V per layer; implementations must ensure cached K,V and new Q use the same positional encoding semantics. Key pitfalls include forgetting to detach cached tensors (causing autograd buildup) and failing to handle relative positional encodings where attention depends on relative indexes rather than absolute cached positions. Beam search and batch-merge strategies complicate caching because different beams have different cache lengths—design must support copy-on-write or packed-batch cache reshaping. An ML Engineer will measure memory vs latency tradeoffs, and add instrumentation to report per-layer cache size and hit rates.
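The cache bookkeeping described above can be checked with a parity test between full causal attention and step-by-step decoding. This is a deliberately reduced sketch under stated assumptions: a single head, a single layer, no positional encoding, batch size 1, and random projection matrices `Wq`, `Wk`, `Wv` that stand in for trained weights.

```python
import math
import torch

torch.manual_seed(0)
T, d = 5, 8
x = torch.randn(1, T, d)  # (batch, seq, d_model); toy hidden states
Wq, Wk, Wv = (torch.randn(d, d) / math.sqrt(d) for _ in range(3))


def full_causal_attention(x):
    """Recompute attention over the whole sequence with a causal mask."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    logits = q @ k.transpose(-2, -1) / math.sqrt(d)
    n = x.shape[1]
    causal = torch.tril(torch.ones(n, n, dtype=torch.bool))
    logits = logits.masked_fill(~causal, float("-inf"))
    return torch.softmax(logits, dim=-1) @ v


@torch.no_grad()  # keeps cached tensors out of the autograd graph
def incremental_decode(x):
    """Process one position at a time, appending K and V to a cache."""
    k_cache = torch.empty(1, 0, d)
    v_cache = torch.empty(1, 0, d)
    outputs = []
    for t in range(x.shape[1]):
        x_t = x[:, t : t + 1, :]                       # (1, 1, d)
        k_cache = torch.cat([k_cache, x_t @ Wk], dim=1)  # (1, t+1, d)
        v_cache = torch.cat([v_cache, x_t @ Wv], dim=1)
        q_t = x_t @ Wq
        # The new query sees only cached keys, so no explicit causal mask is needed.
        logits = q_t @ k_cache.transpose(-2, -1) / math.sqrt(d)
        outputs.append(torch.softmax(logits, dim=-1) @ v_cache)
    return torch.cat(outputs, dim=1), k_cache


full = full_causal_attention(x)
inc, k_cache = incremental_decode(x)
max_diff = (full - inc).abs().max().item()
```

With absolute or relative positional terms added, the same parity check is where a mismatch between cached positions and new-query positions would be expected to surface.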
Common pitfalls
Pitfall: Assuming boolean masks "just work" across ops.
Many bugs come from mask dtype/device mismatches or a wrong broadcast order (shape (B,1,1,T) vs (B,1,T,T)), which broadcasts silently and allows illegal attention. Always assert mask shapes and sanity-check mask.sum().
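The silent-broadcast failure can be shown on a toy tensor. In this sketch, `check_causal_mask` is a hypothetical assertion helper, and the convention that True means "allowed" is an assumption of the example rather than a universal one.

```python
import torch

B, H, T = 2, 4, 3
logits = torch.zeros(B, H, T, T)

# Key-padding mask only: shape (B, 1, 1, T). It broadcasts against
# (B, H, T, T) without error, but carries no causal structure.
key_padding = torch.ones(B, 1, 1, T, dtype=torch.bool)
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))[None, None]  # (1, 1, T, T)


def check_causal_mask(allowed: torch.Tensor, T: int) -> None:
    """Fail loudly if the mask is not (..., T, T) or lets a query see a future key."""
    assert allowed.dtype == torch.bool, f"expected bool mask, got {allowed.dtype}"
    assert allowed.shape[-2:] == (T, T), f"expected (..., {T}, {T}), got {tuple(allowed.shape)}"
    future = torch.triu(torch.ones(T, T, dtype=torch.bool, device=allowed.device), diagonal=1)
    assert not (allowed & future).any(), "mask allows attention to future positions"


# Using the padding mask alone: the first query attends to every key, including future ones.
leaky = torch.softmax(logits.masked_fill(~key_padding, float("-inf")), dim=-1)
future_weight = leaky[0, 0, 0, 2].item()  # weight from query 0 on key 2

# Combining both masks gives the intended (B, 1, T, T) shape and passes the check.
combined = causal & key_padding
check_causal_mask(combined, T)
allowed_count = int(combined.sum())  # mask.sum() sanity value: B * T * (T + 1) / 2
```

Calling `check_causal_mask(key_padding, T)` raises an AssertionError on the shape check, which is the kind of early failure the pitfall argues for.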
Pitfall: Not asking the objective early.
Candidates who fail to confirm whether it's causal LM vs encoder‑decoder often propose incorrect shift/masking fixes; explicitly state the loss alignment you expect and how you'd verify it.
Pitfall: Treating AMP NaNs as model bugs.
Mixed precision can expose gradient underflow/overflow; reproduce the failure in FP32 before changing masking or architecture. If FP32 is fine, focus on scaler/optimizer config rather than attention math.
Connections
This topic commonly connects to efficient attention algorithms (e.g., FlashAttention) and to decoding strategies (beam search, sampling) where masking and caching interact. It also links to model evaluation metrics (perplexity, token-level accuracy) and to serving concerns like batching and memory accounting for cached KV states.
Further reading
- Attention Is All You Need (Vaswani et al., 2017): foundational paper for scaled dot-product attention and masking semantics.
- The Illustrated Transformer (Jay Alammar): concise visual intuition for masks, attention, and decoder behavior.
- FlashAttention (Dao et al.): practical paper on fast, memory-efficient attention that interacts with caching and masking.
Practice questions
- Debug a broken Transformer implementation (OpenAI · Machine Learning Engineer · Onsite · hard)
- Explain KV cache in Transformer inference (OpenAI · Machine Learning Engineer · Onsite · medium)
- Debug a transformer training pipeline (OpenAI · Machine Learning Engineer · Technical Screen · hard)
- Diagnose Transformer training and inference bugs (OpenAI · Machine Learning Engineer · Technical Screen · hard)
- Debug transformer and train classifier (OpenAI · Machine Learning Engineer · Technical Screen · hard)
Related concepts
- Transformer Architectures And Attention (Machine Learning)
- Transformer Architecture and Attention Internals
- Transformer Training Pipeline Debugging (Machine Learning)
- Generative AI Training, Attention, And Post-Training (ML System Design)
- LLM Foundations, Embeddings, Prompts, And Fine-Tuning
- ML Fundamentals: Backprop, Attention, And RL (Machine Learning)