PracHub

Implement Masked Multi-Head Self-Attention

Last updated: Jun 17, 2026

Quick Overview

This question evaluates both implementation skill and conceptual understanding of masked multi-head self-attention, covering scaled dot-product attention, separate linear projections for queries, keys, and values, head-wise tensor reshaping, and the construction and application of padding and causal masks.

  • easy
  • Apple
  • Machine Learning
  • Machine Learning Engineer


Company: Apple

Role: Machine Learning Engineer

Category: Machine Learning

Difficulty: easy

Interview Round: Technical Screen

```hint Where to start
Get single-head scaled dot-product attention working first: $\text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$. Multi-head is then "do this in parallel across a head axis," which is a `reshape` + `transpose`, not new math.
```

```hint Reshape vs. multiple projections
You do NOT need `num_heads` separate `Linear` layers. Use one `Linear(d_model, d_model)` per role (Q, K, V), then `view`/`reshape` the last dim into `(num_heads, d_k)` and move the head axis next to the batch axis so each head attends independently. Track how `(B, T, d_model)` becomes `(B, H, T, d_k)`.
```

```hint Where the mask acts
A mask changes which keys a query is allowed to see, so it belongs on the scores tensor before the softmax, not on the probabilities afterward. Think about the shape mismatch: a per-token padding signal and the score matrix don't have the same rank, so part of the work is figuring out how to line them up.
```

```hint Two kinds of masking
A padding mask and a causal (look-ahead) mask answer different questions: "is this key a real token?" vs. "is this key in the future relative to the query?" Decide what shape each one naturally has, how they compose when both apply, and whether your reshape back to `(B, T, d_model)` needs anything special after you've moved the head axis around.
```


Related Interview Questions

  • Compare DCN v1 vs v2 and A/B test - Apple (medium)
  • Explain dataset size, generalization, and U-Net skips - Apple (medium)
  • Analyze vision model failures - Apple (medium)
  • Compare audio preprocessing and training - Apple (medium)
  • Design Siri-vs-GPT query routing - Apple (medium)
Apple
Apr 4, 2026, 12:00 AM
Machine Learning Engineer
Technical Screen
Machine Learning

Implement the core self-attention module used inside a Transformer encoder, from scratch, in a deep-learning framework of your choice (PyTorch-style pseudocode is fine).

Given an input tensor X of shape (batch_size, sequence_length, d_model), first implement scaled dot-product self-attention, then extend it to multi-head self-attention with num_heads heads. A complete answer must:

  1. Learn separate linear projections for queries, keys, and values.
  2. Split the projected tensors into num_heads independent heads.
  3. Compute attention scores as $\frac{QK^\top}{\sqrt{d_k}}$.
  4. Apply an optional attention mask before the softmax so masked positions receive zero weight.
  5. Apply softmax over the key dimension.
  6. Multiply the attention weights by the values.
  7. Concatenate all heads and apply a final output projection.
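
Steps 1 to 7 can be sketched in PyTorch as follows. The sketch rests on a few assumptions that the clarifying questions would pin down. The mask is boolean, with True meaning "keep", and must already be broadcastable to (B, H, T, T). Dropout is omitted. The names `masked_sdpa` and `MultiHeadSelfAttention` are illustrative only; `masked_sdpa` is a hand-written stand-in, not `torch.nn.functional.scaled_dot_product_attention`. Shapes are annotated with B = batch_size, T = sequence_length, H = num_heads.

```python
import math

import torch
import torch.nn as nn


def masked_sdpa(q, k, v, mask=None):
    """Scaled dot-product attention (hypothetical helper name).

    q, k, v: (..., T, d_k). mask: bool, broadcastable to (..., T, T), True = keep.
    """
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)          # (..., T, T)
    if mask is not None:
        # Mask the logits BEFORE softmax, using the most negative finite value of the dtype.
        scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
    weights = torch.softmax(scores, dim=-1)                     # normalize over the key axis
    return weights @ v, weights                                 # (..., T, d_k), (..., T, T)


class MultiHeadSelfAttention(nn.Module):
    """Hypothetical module: one Linear per role; heads come from reshaping, not extra layers."""

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def _split_heads(self, x):
        b, t, _ = x.shape                                              # (B, T, d_model)
        return x.view(b, t, self.num_heads, self.d_k).transpose(1, 2)  # (B, H, T, d_k)

    def forward(self, x, mask=None):
        b, t, d_model = x.shape                                  # (B, T, d_model)
        q = self._split_heads(self.w_q(x))                       # (B, H, T, d_k)
        k = self._split_heads(self.w_k(x))                       # (B, H, T, d_k)
        v = self._split_heads(self.w_v(x))                       # (B, H, T, d_k)
        context, weights = masked_sdpa(q, k, v, mask)            # (B, H, T, d_k), (B, H, T, T)
        # transpose leaves a non-contiguous tensor, so .contiguous() is required before .view.
        context = context.transpose(1, 2).contiguous().view(b, t, d_model)  # (B, T, d_model)
        return self.w_o(context), weights                        # (B, T, d_model), (B, H, T, T)
```

The helper only needs the last two axes to be (T, d_k). It therefore serves single-head attention on (B, T, d_k) tensors and the multi-head path on (B, H, T, d_k) tensors alike; the head axis simply behaves as another batch axis.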

State the expected tensor shape at each major step, and explain how you would construct and apply both a padding mask and a causal (look-ahead) mask.
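
One possible construction of both masks, again assuming True means "keep", is sketched below. The helper names and the `lengths` argument (the number of real tokens per sequence, assuming right-padding) are hypothetical. The padding mask acts on keys and is reshaped to (B, 1, 1, T) so it broadcasts over heads and query positions. The causal mask is a lower-triangular (T, T) matrix. Combining the two is a logical AND that broadcasts to (B, 1, T, T).

```python
import torch


def padding_mask(lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Hypothetical helper. lengths: (B,) real-token counts -> (B, 1, 1, T), True = real key."""
    positions = torch.arange(seq_len, device=lengths.device)     # (T,)
    keep = positions.unsqueeze(0) < lengths.unsqueeze(1)         # (B, T)
    return keep[:, None, None, :]                                # broadcasts over H and queries


def causal_mask(seq_len: int, device=None) -> torch.Tensor:
    """Hypothetical helper. (T, T) lower-triangular: query i may see keys j <= i."""
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))


def combined_mask(lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
    # (B, 1, 1, T) & (T, T) -> (B, 1, T, T): a key is visible only if it is a real
    # token AND it is not in the future relative to the query.
    return padding_mask(lengths, seq_len) & causal_mask(seq_len, lengths.device)


if __name__ == "__main__":
    m = combined_mask(torch.tensor([3, 5]), seq_len=5)
    print(m.shape)                    # torch.Size([2, 1, 5, 5])
    print(m[0, 0, 4].int().tolist())  # [1, 1, 1, 0, 0]
```

Any of the three results can be passed to an attention function that expects a boolean keep-mask broadcastable to (B, H, T, T).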

Constraints & Assumptions

  • d_model is divisible by num_heads; define $d_k = d_{\text{model}} / \text{num\_heads}$.
  • Self-attention: queries, keys, and values are all derived from the same input X (so the query, key, and value sequence lengths all equal sequence_length).
  • The mask is optional. When supplied, it may be a padding mask of shape (batch_size, sequence_length), a causal mask of shape (sequence_length, sequence_length), or both combined.
  • Typical scale to keep in mind: d_model in the hundreds-to-thousands (e.g. 512), num_heads 8–16, sequence lengths up to a few thousand. Do not assume a fixed value—write the code parametrically.
  • Numerical stability matters: a masked-out logit must be driven low enough that softmax assigns it ~0 weight.
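
A small numerical illustration of the last constraint uses made-up scores for one query over four keys, with the last key treated as padding. Filling the masked logit before the softmax gives that key exactly zero weight, and the remaining weights still sum to 1. Zeroing the probability after the softmax leaves a row that no longer sums to 1. The fill value `torch.finfo(dtype).min` is one common dtype-aware choice; it is assumed here, not mandated by the question. A hard-coded constant such as -1e9 falls outside the float16 range, whose largest finite magnitude is 65504.

```python
import torch

scores = torch.tensor([[2.0, 1.0, 0.5, 3.0]])     # (1 query, 4 keys), illustrative values
keep = torch.tensor([[True, True, True, False]])  # True = keep; the last key is padding

# Pre-softmax masking with a dtype-aware large negative fill.
fill = torch.finfo(scores.dtype).min
pre = torch.softmax(scores.masked_fill(~keep, fill), dim=-1)

# Post-softmax zeroing: the padding key still inflated the normalizer of the real keys.
post = torch.softmax(scores, dim=-1) * keep

print(pre)         # last entry is exactly 0 and the row sums to ~1
print(post.sum())  # well below 1
```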

Clarifying Questions to Ask

  • Is this self-attention (Q, K, V from the same X) or cross-attention (K, V from a different source)? This affects the projection inputs and the mask shapes.
  • What is the mask convention: does True/1 indicate keep or mask out? Is it a 0/1 tensor or already an additive -inf bias?
  • Should the module include dropout on the attention weights, and/or a scaling/bias term in the projections?
  • Do we need to support variable-length sequences in a batch (i.e. a real padding mask), or are all sequences the same length?
  • Are we optimizing for clarity/correctness, or should I also discuss a fused/flash-attention-style kernel and memory complexity?
  • Should the output keep shape (B, T, d_model) to slot directly into a residual + LayerNorm block?
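
The mask conventions raised in the clarifying questions become interchangeable once normalized. The sketch below uses made-up scores to show that three forms produce identical attention weights: a boolean keep-mask applied with `masked_fill`, an additive bias of 0/-inf, and an inverted "True means mask out" tensor. It uses -inf only because every row here keeps at least one key.

```python
import torch

scores = torch.tensor([[0.5, -1.0, 2.0, 0.1],
                       [1.2, 0.3, -0.7, 0.0]])    # (2 queries, 4 keys), illustrative
keep = torch.tensor([[True, True, False, True],
                     [True, False, False, True]])  # boolean, True = keep

# Convention A: boolean keep-mask applied with masked_fill.
via_fill = torch.softmax(scores.masked_fill(~keep, float("-inf")), dim=-1)

# Convention B: additive bias that is 0 where kept and -inf where masked.
bias = torch.zeros_like(scores).masked_fill(~keep, float("-inf"))
via_bias = torch.softmax(scores + bias, dim=-1)

# Convention C: True means "mask out"; the only change is dropping the negation.
mask_out = ~keep
via_inverted = torch.softmax(scores.masked_fill(mask_out, float("-inf")), dim=-1)
```

Converting whatever convention is supplied into one internal form at the module boundary leaves the attention code itself unchanged.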

What a Strong Answer Covers

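  • Correct math: the $1/\sqrt{d_k}$ scaling and why it exists. If the components of a query and a key are independent with zero mean and unit variance, their dot product has variance $d_k$; dividing by $\sqrt{d_k}$ restores unit variance, which keeps the softmax out of its saturated regime where gradients vanish.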
  • Shape bookkeeping: an explicit shape at every step (projection, head split, scores, masked scores, softmax, context, head merge, output projection).
  • Single-projection-then-split trick rather than per-head layers (efficiency and parameter parity).
  • Mask correctness: pre-softmax application, broadcasting a (B, T) padding mask, building a triangular causal mask, and combining the two.
  • Numerical care: a large-negative fill before the softmax instead of zeroing probabilities after it; awareness of the degenerate case where an entire row is masked (all -inf).
  • Framework hygiene: .contiguous() before view, the assertion d_model % num_heads == 0, and an output shape that matches the input.
  • Complexity awareness: the attention-matrix cost is $O(T^2 \cdot d_{\text{model}})$ time and $O(H \cdot T^2)$ memory, while the Q/K/V/output projections add $O(T \cdot d_{\text{model}}^2)$; knowing which term dominates in which $T$-vs-$d_{\text{model}}$ regime, and where this becomes the bottleneck.
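
A rough per-sequence count of multiply-accumulates makes the crossover concrete. It drops the batch factor B and ignores softmax, bias, and masking costs, which are lower order:

$$
\begin{aligned}
\text{projections } (W_Q, W_K, W_V, W_O)&:\quad 4\,T\,d_{\text{model}}^2 \\
QK^\top \text{ and weights} \times V&:\quad 2 \cdot H \cdot T^2 \cdot d_k = 2\,T^2 d_{\text{model}} \\
\text{ratio}&:\quad \frac{2\,T^2 d_{\text{model}}}{4\,T\,d_{\text{model}}^2} = \frac{T}{2\,d_{\text{model}}}
\end{aligned}
$$

Under this count, the $T^2$ terms overtake the projections once $T > 2\,d_{\text{model}}$, for example $T > 1024$ when $d_{\text{model}} = 512$. On the memory side, materializing the weights takes $H \cdot T^2$ entries per sequence. With $H = 8$ and $T = 4096$ that is $8 \cdot 4096^2 = 134{,}217{,}728$ values, about 512 MiB in float32 for a single layer.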

Follow-up Questions

  • How does this change for cross-attention (e.g. in a decoder), where keys/values come from the encoder output of a different length?
  • Attention is $O(T^2)$ in sequence length. What approaches reduce this cost (FlashAttention's tiling/recomputation, which avoids materializing the $T \times T$ matrix but still performs $O(T^2)$ work, or linear/sparse attention), and what do they trade off?
  • Where do rotary or other relative position embeddings (RoPE, ALiBi) plug into this module, and why might they be preferred over absolute positional encodings added to X?
  • An entire row of the attention logits is masked out (every key is masked for some query, e.g. a sequence in the batch that is entirely padding). What happens to the softmax, and how would you prevent NaNs from propagating?
  • Why is multi-head attention empirically better than a single head of the same total width, given the parameter count is roughly equal?
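
The degenerate fully-masked row can be demonstrated minimally, assuming a boolean True-means-keep mask and made-up scores. With a -inf fill, a row in which every key is masked comes out of PyTorch's softmax as NaN: the row maximum is itself -inf, and (-inf) - (-inf) is undefined. One common mitigation, shown here as one option among several, has two parts. First, use a finite fill, so the row becomes uniform instead of NaN. Second, multiply the weights by a "has any visible key" indicator computed from the mask, so such queries produce a zero context vector.

```python
import torch

scores = torch.tensor([[0.3, -1.2, 0.8],     # query 0: one visible key
                       [1.5, 0.2, -0.4]])    # query 1: every key masked
keep = torch.tensor([[True, False, False],
                     [False, False, False]])  # True = keep

# Naive: -inf fill. Row 1's maximum is -inf, so softmax yields NaN for that row.
naive = torch.softmax(scores.masked_fill(~keep, float("-inf")), dim=-1)

# Mitigation, part 1: a finite fill keeps every row finite (row 1 becomes uniform).
finite = torch.softmax(scores.masked_fill(~keep, torch.finfo(scores.dtype).min), dim=-1)
# Mitigation, part 2: zero the weights of any row with no visible key at all.
has_key = keep.any(dim=-1, keepdim=True)    # (queries, 1)
safe = finite * has_key

print(naive)
print(safe)
```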



© 2026 PracHub. All rights reserved.