Transformer Architecture and Attention Internals
Asked of: ML Engineer
What's being tested
Candidates must demonstrate concrete mastery of self-attention math, compute/memory complexity, and practical tradeoffs when training or serving large transformer models. Interviewers probe whether you can derive and reason about the scaled dot-product attention formula, its gradient flow, implementation bottlenecks (memory, parallelism), and mitigation strategies used in production training/serving. OpenAI cares because MLEs must diagnose training instabilities, optimize throughput (GPU/TPU), and choose attention variants that meet latency and accuracy targets.
Core knowledge
- Scaled dot-product attention formula and shape calculus: Attention(Q, K, V) = softmax(Q K^T / √d_k) V, where Q, K, V are (B, L, d), or (B, H, L, d_k) after the head split; the final output projection returns (B, L, d_model).
- Multi-head attention: split into H heads of dimension d_k = d_model / H. This lets heads learn correlations in different subspaces. It costs extra activation memory, because every head stores its own (L, L) attention map alongside its per-head projections.
- Time & memory complexity: dense attention costs O(L^2·d) compute and O(L^2) activation memory per head and per sequence (O(B·H·L^2) for a batch). In practice, L of roughly 4–8k is the limit on commodity GPUs without special kernels; beyond that, consider sparsity or chunking.
- Softmax numerical stability: subtract the per-row max before exponentiating. Softmax is shift-invariant, so the result is unchanged. Without this, large logits overflow (especially in mixed precision) and training produces NaNs.
- Scaling factor 1/√d_k: if query and key components have unit variance, q·k has variance d_k. Unscaled logits therefore grow with d_k and push the softmax into saturation, where its gradients vanish. Dividing by √d_k keeps logit variance near 1.
- Masks: a causal mask enforces autoregressive dependency by setting future logits to -inf before the softmax. Padding masks exclude pad positions, but they still compute the padded scores. Saving compute requires removing padded positions altogether (e.g., packing variable-length sequences), which must be implemented carefully.
- Position information: absolute positional encodings (sin/cos) vs relative positional bias. Relative methods (e.g., T5's) help generalize to longer sequences and reduce dependence on absolute positions.
- Training stability primitives: LayerNorm placement (pre-LN vs post-LN), residual connections, learning-rate warmup, and optimizers like Adam with weight decay (AdamW). Pre-LN often improves stability for very deep stacks.
- Efficient kernels & algorithms:
  - FlashAttention reduces memory by tiling Q, K, V through on-chip memory and computing the softmax incrementally in a fused kernel. The full (L, L) matrix is never materialized, although compute is still O(L^2·d).
  - Reformer uses locality-sensitive hashing to get ~O(L log L) behavior.
  - Linear-attention approximations make cost linear in L (e.g., O(L·d^2) instead of O(L^2·d)) but can alter expressivity.
- Memory/time engineering: use mixed precision (bfloat16/float16), gradient checkpointing (trading compute for memory), tensor/model parallelism (Megatron-style), and activation partitioning to scale to large models.
- Gradients through attention: derivatives w.r.t. Q and K pass through the softmax and the matmuls via the chain rule. The softmax Jacobian is dense within each row, and the backward pass needs the full (L, L) attention probabilities (stored or recomputed). Watch for O(L^2) backprop memory and O(L^2·d) backprop compute.
- Serving implications: causal decoding caches K and V for every past token. Each new token then costs O(L·d) instead of re-running attention over the whole prefix, so an efficient past-key-value cache is critical for latency.
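To make the formula, the max-subtraction trick, and the causal mask concrete, here is a minimal single-head sketch in plain Python. It uses lists instead of tensors and has no batch or head dimensions. The names `softmax` and `attention` are illustrative, not from any library. A real implementation would use batched matmuls over (B, H, L, d_k) tensors.

```python
import math

def softmax(row):
    """Numerically stable softmax: subtract the row max before exp (softmax is shift-invariant)."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]  # masked -inf entries become exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V, causal=False):
    """Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Q: (L_q, d_k), K: (L_k, d_k), V: (L_k, d_v), all given as lists of rows.
    """
    d_k = len(Q[0])
    scale = 1.0 / math.sqrt(d_k)
    out = []
    for i, q in enumerate(Q):
        # Scaled logits of query i against every key: shape (L_k,)
        logits = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        if causal:
            # Causal mask: query i may only attend to keys j <= i.
            logits = [x if j <= i else float("-inf") for j, x in enumerate(logits)]
        weights = softmax(logits)
        # Attention-weighted sum of value rows: shape (d_v,)
        out.append([sum(w * v[c] for w, v in zip(weights, V)) for c in range(len(V[0]))])
    return out
```

With `causal=True`, the first query can only see the first key, so its output is exactly the first value row. Without the max subtraction, `softmax([1000.0, 1000.0])` would overflow in `math.exp`.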
Worked example
(Example interview prompt: "Explain and derive the scaled dot-product attention computation, its complexity, and implementation choices for training a 1B-parameter model on GPUs.")
First 30 seconds: clarify shapes (batch size B, sequence length L, heads H, per-head dim d_k), whether attention is encoder-only or autoregressive, and target hardware (single GPU vs multi-GPU). Skeleton answer pillars: (1) write the equation and shapes, (2) derive compute and memory costs (forward and backward), (3) list numerical/stability and parallelism issues, (4) propose concrete optimizations for training/serving. Explicit tradeoff: use FlashAttention or fused kernels to cut activation memory and avoid gradient checkpointing, but that requires specific CUDA/XLA support and may complicate multi-node parallelism. Close by noting measurable checks: profile peak GPU memory with a toy batch, validate logits distribution for softmax stability, and say "if I had more time I'd prototype a FlashAttention kernel and compare throughput and memory footprint against naive matmul+softmax on our target cluster."
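A useful way to explain what FlashAttention changes is the online-softmax recurrence it builds on. Keys are visited block by block while keeping only three things: a running max, a running denominator, and an unnormalized output. The (L, L) score matrix therefore never has to exist in full. The sketch below illustrates that recurrence in plain Python for a single query row and compares it against the naive materialize-then-softmax version. It is not the FlashAttention kernel: there is no GPU tiling and no fused backward. `block` is an arbitrary assumed tile size, and both function names are hypothetical.

```python
import math

def naive_row_attention(q, K, V):
    """Reference: materialize all L logits for one query, then softmax and weight V."""
    scale = 1.0 / math.sqrt(len(q))
    logits = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(logits)
    w = [math.exp(x - m) for x in logits]
    s = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, V)) / s for c in range(len(V[0]))]

def tiled_row_attention(q, K, V, block=2):
    """Online-softmax attention for one query, visiting keys one block at a time.

    Only a running max (m), running denominator (l) and unnormalized output (acc)
    are carried between blocks; the full row of L logits is never stored.
    """
    scale = 1.0 / math.sqrt(len(q))
    m = float("-inf")
    l = 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        logits = [scale * sum(a * b for a, b in zip(q, k)) for k in Kb]
        m_new = max(m, max(logits))
        # Rescale earlier partial sums to the new running max (0.0 on the first block).
        correction = math.exp(m - m_new)
        p = [math.exp(x - m_new) for x in logits]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pi * v[c] for pi, v in zip(p, Vb))
               for c, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]
```

The tiled result should match the naive one up to floating-point rounding for any block size. Only O(block) logits are alive at a time, which is the memory saving the fused kernel exploits.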
A second angle
(Alternate prompt: "How do you make attention work for very long sequences (e.g., L=100k) for retrieval-augmented generation?")
The base concept (the attention matrix) is the same, but different constraints drive the solutions:
(1) Use sparse attention patterns (local + strided/global) to cut O(L^2) memory to roughly O(L·w), where w is the window size, while preserving locality.
(2) Apply chunked attention with overlapping windows and recombination, trading some accuracy for O(L·w) compute.
(3) Use linearized attention or kernel approximations to make cost linear in L, but validate on downstream tasks, since these can change the model's inductive biases.
For autoregressive decoding, maintain incremental caches and consider a retrieval/attention hybrid (attend only to retrieved passages plus recent context) to make latency predictable. Be explicit about evaluation: always benchmark perplexity/ROUGE tradeoffs after replacing dense attention.
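To see where the O(L·w) figure comes from, the sketch below builds a toy local-plus-global mask and checks a closed-form count of the local band against it. The window semantics (|i - j| <= w, non-causal, w < L) and the helper names are assumptions for illustration. Real sparse-attention implementations do not build the dense boolean mask.

```python
def local_global_mask(L, window, global_tokens=()):
    """Dense (L, L) boolean mask: True where query i may attend to key j.

    Local band: |i - j| <= window. Global tokens attend to, and are attended by, every position.
    """
    g = set(global_tokens)
    return [[abs(i - j) <= window or i in g or j in g for j in range(L)]
            for i in range(L)]

def allowed_pairs(mask):
    """Number of (query, key) pairs that survive the mask."""
    return sum(sum(row) for row in mask)

def band_pairs(L, window):
    """Closed-form size of the local band (window < L): L*(2w+1) minus the clipped edge cells."""
    return L * (2 * window + 1) - window * (window + 1)

# Check the closed form against the dense mask on a small example.
assert band_pairs(50, 5) == allowed_pairs(local_global_mask(50, 5))

# At L = 4096, w = 128 the band keeps about 6% of the L^2 dense entries.
L, w = 4096, 128
print(band_pairs(L, w), L * L, band_pairs(L, w) / (L * L))
```

Each global token adds roughly 2L extra pairs. A handful of global tokens keeps the total at O(L·w) while still giving every position a route to distant context.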
Common pitfalls
Pitfall: Dropping the 1/√d_k scaling when explaining attention.
Many candidates recite the formula but omit the scaling; that omission leads to wrong reasoning about gradient magnitudes and can break training intuition. Always explain why scaling exists and its effect on softmax saturation.
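A quick way to back up the scaling argument: for q and k with i.i.d. standard-normal entries, Var(q·k) = d_k, while Var(q·k / √d_k) = 1. The Monte Carlo sketch below should show the unscaled variance tracking d_k and the scaled one staying near 1. The unit-variance inputs are an assumption, and the sample sizes are arbitrary.

```python
import math
import random

def logit_variances(d_k, n=2000, seed=0):
    """Empirical variance of q.k and of q.k / sqrt(d_k), with q, k entries i.i.d. N(0, 1)."""
    rng = random.Random(seed)
    raw = []
    for _ in range(n):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        raw.append(sum(a * b for a, b in zip(q, k)))

    def var(xs):
        mean = sum(xs) / len(xs)
        return sum((x - mean) ** 2 for x in xs) / len(xs)

    scaled = [x / math.sqrt(d_k) for x in raw]
    return var(raw), var(scaled)

for d_k in (16, 64, 256):
    unscaled, scaled = logit_variances(d_k)
    print(f"d_k={d_k:4d}  Var(q.k)={unscaled:7.1f}  Var(q.k/sqrt(d_k))={scaled:.2f}")
```

Logits with standard deviation around √256 = 16 put almost all softmax mass on one key. That is the saturated regime where gradients vanish.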
Pitfall: Treating attention as free in backprop.
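People forget that backprop through attention is also O(L^2). The softmax Jacobian is dense within each row, and the backward pass needs the full (L, L) attention probabilities for every head, either stored from the forward pass or recomputed. This shows up as memory blowups during training. State the backward complexity and the mitigations (fused kernels, checkpointing, FlashAttention).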
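The per-row softmax backward pass shows both halves of this point. The Jacobian diag(p) - p p^T is dense, but its product with the upstream gradient collapses to p * (dp - <dp, p>), which costs only O(L) per row. What it does require is the probability row p over all L keys. Keeping those rows for every query and head is where the O(L^2) memory goes. The sketch below checks the formula against central finite differences; the helper names are illustrative.

```python
import math

def softmax(s):
    """Stable softmax over one row of logits."""
    m = max(s)
    e = [math.exp(x - m) for x in s]
    z = sum(e)
    return [x / z for x in e]

def softmax_backward(p, dp):
    """Vector-Jacobian product for one attention row.

    J = diag(p) - p p^T is dense (L x L) and symmetric, but J dp = p * (dp - <dp, p>)
    needs only O(L) work, given the stored (or recomputed) probabilities p.
    """
    dot = sum(a * b for a, b in zip(dp, p))
    return [pi * (g - dot) for pi, g in zip(p, dp)]

def finite_difference_grad(s, dp, eps=1e-6):
    """Numerical gradient of f(s) = <dp, softmax(s)> by central differences."""
    grad = []
    for i in range(len(s)):
        plus = list(s)
        plus[i] += eps
        minus = list(s)
        minus[i] -= eps
        f_plus = sum(a * b for a, b in zip(dp, softmax(plus)))
        f_minus = sum(a * b for a, b in zip(dp, softmax(minus)))
        grad.append((f_plus - f_minus) / (2 * eps))
    return grad
```

FlashAttention-style kernels avoid storing p by recomputing it block by block in the backward pass from Q, K and saved per-row softmax statistics. This trades extra compute for O(L) memory.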
Pitfall: Vagueness about masking and serving implications.
Saying "we mask future tokens" without describing incremental cache design or padding-aware batching makes your answer shallow. Explain how caching K,V reduces per-token cost and how to manage memory/eviction in long sessions.
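Here is a minimal single-head sketch of the caching idea. Each decode step appends the new token's key and value, and then the new query attends over the cache. Per-token work is therefore O(n·d) for n cached tokens, rather than recomputing attention over the whole prefix. `KVCache` and `decode_step` are hypothetical names. The `deque(maxlen=...)` eviction (drop the oldest token) is just one simple policy assumed here. Once the cap is hit, it turns the model into sliding-window attention.

```python
import math
from collections import deque

def attend(q, keys, values):
    """Scaled dot-product attention of one query over the given keys/values (stable softmax)."""
    scale = 1.0 / math.sqrt(len(q))
    logits = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(logits)
    w = [math.exp(x - m) for x in logits]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, values)) / z for c in range(len(values[0]))]

class KVCache:
    """Hypothetical single-head cache holding one key row and one value row per past token."""

    def __init__(self, max_len):
        # deque(maxlen=...) silently drops the oldest entry once full (sliding-window eviction).
        self.keys = deque(maxlen=max_len)
        self.values = deque(maxlen=max_len)

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)

def decode_step(cache, q, k, v):
    """Cache the new token's K/V, then attend over the cache: O(n * d) for n cached tokens."""
    cache.append(k, v)
    return attend(q, cache.keys, cache.values)
```

In a real model the cache typically exists per layer, holds (B, H, L, d_k) tensors, and is often preallocated so decoding does not reallocate memory every step.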
Connections
Attention internals commonly lead to adjacent topics like efficient transformer variants (FlashAttention, Reformer, sparse attention) and model-parallel training strategies (tensor vs pipeline parallelism). Interviewers may pivot to implementation-level performance debugging (profiling CUDA kernels, memory fragmentation) or to quantization/pruning tradeoffs for latency-sensitive serving.
Further reading
- Attention Is All You Need (the original transformer paper; essential for formulas and architecture).
- FlashAttention, paper and implementation (a practical high-throughput, memory-efficient attention kernel used in production).
- Reformer: The Efficient Transformer (LSH-based attention that reduces complexity for long sequences).
Related concepts
- Transformer Attention And Masking (Machine Learning)
- Transformer Architectures And Attention (Machine Learning)
- Safety, Alignment, Guardrails, and Responsible LLM Deployment
- LLM Serving, Inference Scaling, KV Cache, and Latency-Cost Tradeoffs
- LLM Architecture, Tuning, And Evaluation (Machine Learning)
- Multimodal LLM System Design