Implement attention and Transformer with backward pass
Company: Tesla
Role: Machine Learning Engineer
Category: Machine Learning
Difficulty: hard
Interview Round: Technical Screen
Implement scaled dot-product attention and a Transformer block from scratch (no autograd). Provide both forward and backward passes for the attention module; assume the backward for softmax is already implemented and available. Given input X of shape (B, T, d_model), multi-head attention with h heads (head_dim = d_model / h), learnable parameters W_q, W_k, W_v in R^{d_model x d_model} and W_o in R^{d_model x d_model}, and a causal mask for autoregressive training:
1) Forward: compute Q, K, V, split into heads, compute scores = Q K^T / sqrt(head_dim), apply the causal mask, apply softmax, compute the attention output, merge heads, and apply W_o (a sketch follows after this list). If building a full block, include the residual connections, LayerNorm, and a position-wise feed-forward sublayer.
2) Backward: derive and implement gradients for W_q, W_k, W_v, W_o and for the inputs (dX), including gradient flow through the scaling, the matmuls, the masking, the head reshaping/transpose, and the output projection; use the provided softmax backward to map the gradient of the attention weights (dP) back to the gradient of the pre-softmax scores (see the backward sketch below).
3) Autoregressive training: define the next-token cross-entropy loss with teacher forcing, show where the causal mask is applied, and compute perplexity (see the loss sketch below).
4) Verification: implement finite-difference gradient checks on small tensors and report the maximum relative errors; discuss numerical stability (e.g., a stabilized softmax using the log-sum-exp trick), as in the gradient-check sketch below. Provide clear function signatures and tensor shapes at each step, plus a time/memory complexity analysis.
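
A minimal NumPy sketch of the forward pass in part (1), under the shapes stated above; the names softmax and mha_forward and the cache layout are illustrative choices rather than anything mandated by the question, and the full-block extras (residual, LayerNorm, feed-forward) are omitted for brevity.

import numpy as np

def softmax(x, axis=-1):
    # stabilized softmax: subtract the row max before exponentiating
    x_max = np.max(x, axis=axis, keepdims=True)
    e = np.exp(x - x_max)
    return e / np.sum(e, axis=axis, keepdims=True)

def mha_forward(X, Wq, Wk, Wv, Wo, h):
    """Causal multi-head self-attention forward pass.

    X: (B, T, d_model); Wq, Wk, Wv, Wo: (d_model, d_model); h: number of heads.
    Returns Y: (B, T, d_model) and a cache of intermediates for the backward pass.
    """
    B, T, d = X.shape
    hd = d // h                                              # head_dim
    Q, K, V = X @ Wq, X @ Wk, X @ Wv                         # each (B, T, d)

    def split(M):                                            # (B, T, d) -> (B, h, T, hd)
        return M.reshape(B, T, h, hd).transpose(0, 2, 1, 3)

    Qh, Kh, Vh = split(Q), split(K), split(V)
    scores = Qh @ Kh.transpose(0, 1, 3, 2) / np.sqrt(hd)     # (B, h, T, T)
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)         # True strictly above the diagonal
    scores = np.where(mask, -1e9, scores)                    # causal mask: no attending to the future
    P = softmax(scores, axis=-1)                             # attention weights, (B, h, T, T)
    Oh = P @ Vh                                              # (B, h, T, hd)
    O = Oh.transpose(0, 2, 1, 3).reshape(B, T, d)            # merge heads
    Y = O @ Wo                                               # output projection
    cache = (X, Qh, Kh, Vh, P, O, Wq, Wk, Wv, Wo, mask, h, hd)
    return Y, cache

Splitting into heads costs only a reshape and transpose; the dominant cost is the two (B, h, T, T)-shaped matmuls, giving O(B * h * T^2 * head_dim) time and O(B * h * T^2) memory for the attention matrix.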
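A matching backward-pass sketch for part (2), reusing the cache returned by mha_forward above; softmax_backward stands in for the softmax backward that the question says is already provided.

def softmax_backward(dP, P):
    # "provided" routine: maps grad w.r.t. softmax output P to grad w.r.t. its logits
    return P * (dP - np.sum(dP * P, axis=-1, keepdims=True))

def mha_backward(dY, cache):
    """Gradients of causal multi-head self-attention.

    dY: (B, T, d_model) upstream gradient; cache: intermediates from mha_forward.
    Returns dX and gradients for Wq, Wk, Wv, Wo.
    """
    X, Qh, Kh, Vh, P, O, Wq, Wk, Wv, Wo, mask, h, hd = cache
    B, T, d = X.shape

    # output projection: Y = O @ Wo
    dWo = O.reshape(B * T, d).T @ dY.reshape(B * T, d)
    dO = dY @ Wo.T
    dOh = dO.reshape(B, T, h, hd).transpose(0, 2, 1, 3)      # un-merge heads

    # weighted sum of values: Oh = P @ Vh
    dP = dOh @ Vh.transpose(0, 1, 3, 2)
    dVh = P.transpose(0, 1, 3, 2) @ dOh

    # softmax backward, then zero masked positions and undo the 1/sqrt(head_dim) scaling
    dScores = softmax_backward(dP, P)
    dScores = np.where(mask, 0.0, dScores) / np.sqrt(hd)

    # scores = Qh @ Kh^T
    dQh = dScores @ Kh
    dKh = dScores.transpose(0, 1, 3, 2) @ Qh

    def merge(Mh):                                           # (B, h, T, hd) -> (B, T, d)
        return Mh.transpose(0, 2, 1, 3).reshape(B, T, d)

    dQ, dK, dV = merge(dQh), merge(dKh), merge(dVh)

    # input projections: Q = X @ Wq, K = X @ Wk, V = X @ Wv
    X2 = X.reshape(B * T, d)
    dWq = X2.T @ dQ.reshape(B * T, d)
    dWk = X2.T @ dK.reshape(B * T, d)
    dWv = X2.T @ dV.reshape(B * T, d)
    dX = dQ @ Wq.T + dK @ Wk.T + dV @ Wv.T
    return dX, dWq, dWk, dWv, dWo

Note that dX accumulates contributions from all three projection paths, mirroring how X feeds Q, K, and V in the forward pass.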
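For part (3), one way to set up teacher-forced next-token training: the model consumes tokens 0..T-1, the causal mask inside the attention above keeps position t from seeing positions beyond t, and the targets are the same sequence shifted left by one. The helper below is a hypothetical sketch; it assumes a final projection has already produced logits of shape (B, T, vocab).

def cross_entropy_and_perplexity(logits, targets):
    """Next-token cross-entropy with teacher forcing.

    logits:  (B, T, vocab) -- predictions for the token at the next position.
    targets: (B, T) integer token ids, shifted by one relative to the inputs.
    Returns the mean loss, perplexity, and the gradient w.r.t. the logits.
    """
    B, T, V = logits.shape
    # log-softmax via the log-sum-exp trick for numerical stability
    m = logits.max(axis=-1, keepdims=True)
    log_probs = logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
    b_idx = np.arange(B)[:, None]
    t_idx = np.arange(T)[None, :]
    nll = -log_probs[b_idx, t_idx, targets]   # (B, T) per-token negative log-likelihood
    loss = nll.mean()                         # mean NLL over all positions
    perplexity = np.exp(loss)                 # ppl = exp(mean NLL)
    # gradient w.r.t. logits: (softmax - one_hot) / (B*T)
    dlogits = np.exp(log_probs)
    dlogits[b_idx, t_idx, targets] -= 1.0
    dlogits /= (B * T)
    return loss, perplexity, dlogits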
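For part (4), a central-difference gradient checker; grad_check is a hypothetical helper name, not something the question supplies.

def grad_check(f, x, analytic_grad, eps=1e-5):
    """Compare df/dx (central differences) against an analytic gradient.

    f: scalar-valued function of x; analytic_grad: array with the same shape as x.
    Returns the maximum relative error over all entries of x.
    """
    max_rel_err = 0.0
    it = np.nditer(x, flags=['multi_index'])
    while not it.finished:
        idx = it.multi_index
        old = x[idx]
        x[idx] = old + eps; f_plus = f(x)     # perturb up
        x[idx] = old - eps; f_minus = f(x)    # perturb down
        x[idx] = old                          # restore
        num = (f_plus - f_minus) / (2 * eps)
        ana = analytic_grad[idx]
        rel = abs(num - ana) / max(abs(num) + abs(ana), 1e-12)
        max_rel_err = max(max_rel_err, rel)
        it.iternext()
    return max_rel_err

One way to use it is to define a scalar loss such as Y.sum() over mha_forward, obtain analytic gradients from mha_backward with dY = np.ones_like(Y), and check each parameter in turn; on small float64 tensors the maximum relative error is typically around 1e-6 or below, while a large error usually points at the masking, scaling, or reshape/transpose steps.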
Quick Answer: This question tests whether the candidate can implement scaled dot-product multi-head self-attention and a Transformer block by hand, including the forward and backward passes, gradient derivations for the projection matrices, causal masking for autoregressive next-token prediction, a numerically stable softmax, and time/memory complexity analysis. It emphasizes practical implementation grounded in linear algebra and backpropagation mechanics, and is commonly asked to verify mastery of attention mechanisms and gradient reasoning without autodiff, along with the ability to compute cross-entropy and perplexity, run finite-difference gradient checks, and reason about algorithmic trade-offs.