Implement scaled dot-product attention
Company: Meta
Role: Machine Learning Engineer
Category: Coding & Algorithms
Difficulty: medium
Interview Round: Onsite
Quick Answer: This question tests whether a candidate can implement single-head scaled dot-product attention: the matrix operations on Q, K, and V, masked attention handling, and a numerically stable softmax. It draws on linear algebra and knowledge of machine learning model internals.
Constraints
- 1 <= d, dv (each row of Q/K has length d; each row of V has length dv)
- 0 <= Tq, Tk (Tq == 0 returns an empty list)
- K and V share the same first dimension Tk
- mask, when provided, has shape (Tq, Tk); a falsy entry forbids attending to that key
- All numeric inputs are real-valued floats; output values are rounded to 6 decimal places
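The shape constraints above can be checked before any arithmetic is done. The following is a minimal sketch, assuming inputs are nested Python lists as in the examples; the helper name `check_shapes` and its error messages are hypothetical and not part of the original problem.

```python
def check_shapes(Q, K, V, mask=None):
    """Hypothetical helper: raise ValueError if the stated shape constraints are violated."""
    Tq, Tk = len(Q), len(K)

    # K and V share the same first dimension Tk.
    if len(V) != Tk:
        raise ValueError("K and V must have the same number of rows (Tk)")

    # Every row of Q and K has the same length d >= 1.
    qk_rows = list(Q) + list(K)
    if qk_rows:
        d = len(qk_rows[0])
        if d < 1 or any(len(row) != d for row in qk_rows):
            raise ValueError("every row of Q and K must have the same length d >= 1")

    # Every row of V has the same length dv >= 1.
    if V:
        dv = len(V[0])
        if dv < 1 or any(len(row) != dv for row in V):
            raise ValueError("every row of V must have the same length dv >= 1")

    # The mask, when provided, has shape (Tq, Tk).
    if mask is not None:
        if len(mask) != Tq or any(len(row) != Tk for row in mask):
            raise ValueError("mask must have shape (Tq, Tk)")

    return Tq, Tk
```

Whether an interviewer expects explicit validation is not stated in the problem; the sketch only makes the constraints concrete.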
Examples
Input: ([[1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 2.0], [3.0, 4.0]], None)
Expected Output: [[1.660477, 2.660477]]
Explanation: Single query [1,0], d=2, so raw scores are divided by sqrt(2). Raw scores = [1,0]; scaled = [0.7071, 0]. Softmax ~ [0.6698, 0.3302]; output = 0.6698*[1,2] + 0.3302*[3,4] = [1.660477, 2.660477].
Input: ([[0.0, 0.0]], [[5.0, -3.0], [2.0, 8.0]], [[10.0, 0.0], [0.0, 10.0]], None)
Expected Output: [[5.0, 5.0]]
Explanation: Query is all zeros, so every score is 0 regardless of keys; softmax gives uniform weights [0.5, 0.5]; output = 0.5*[10,0] + 0.5*[0,10] = [5,5].
Input: ([[1.0, 1.0], [2.0, -1.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0], [2.0]], [[1, 0], [1, 1]])
Expected Output: [[1.0], [1.107042]]
Explanation: Row 0 masks out key 1, so only key 0 is attended -> weight 1.0 -> output V[0]=[1.0]. Row 1 attends to both keys: raw scores = [2,-1]; scaled = [1.4142, -0.7071]; stable softmax ~ [0.8930, 0.1070]; output = 0.8930*[1] + 0.1070*[2] = [1.107042].
Input: ([[2.0, 2.0]], [[1.0, 1.0], [1.0, 1.0]], [[4.0, 8.0], [4.0, 8.0]], None)
Expected Output: [[4.0, 8.0]]
Explanation: Both keys are identical so scores tie; softmax is uniform [0.5,0.5], and both value rows are equal, so the output equals either value row: [4,8].
Input: ([], [[1.0]], [[1.0]], None)
Expected Output: []
Explanation: Edge case: empty query matrix (Tq=0) returns an empty output list.
Input: ([[3.0]], [[1.0], [2.0]], [[7.0, 1.0], [9.0, 2.0]], [[0, 0]])
Expected Output: [[8.0, 1.5]]
Explanation: Fully-masked row (both keys disallowed). Instead of NaN, fall back to a uniform distribution [0.5,0.5] over all keys: output = 0.5*[7,1] + 0.5*[9,2] = [8,1.5].
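The arithmetic in the first example can be reproduced step by step. This is a short sketch using only the standard library; the variable names are illustrative and the inputs are copied from that example.

```python
import math

# Inputs from the first example: one query, two keys, d = 2.
q = [1.0, 0.0]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]

d = len(q)

# Raw dot products of the query with each key.
raw = [sum(qi * ki for qi, ki in zip(q, k)) for k in K]

# Scale by 1/sqrt(d) before the softmax.
scaled = [s / math.sqrt(d) for s in raw]

# Stable softmax: subtract the maximum score before exponentiating.
m = max(scaled)
exps = [math.exp(s - m) for s in scaled]
total = sum(exps)
weights = [e / total for e in exps]

# Weighted sum of the value rows, rounded to 6 decimal places.
output = [round(sum(w * v[c] for w, v in zip(weights, V)), 6) for c in range(len(V[0]))]

print(weights)
print(output)
```

The printed weights should agree with the approximate softmax values quoted in the explanation, and the output with the expected result for that example.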
Hints
- Process one query row at a time: compute its Tk scores, softmax them, then combine the value rows.
- Numerically stable softmax: subtract the row's maximum *finite* score before exponentiating; this prevents overflow for large magnitudes and leaves the result unchanged.
- Treat a masked entry as score -inf, which exponentiates to weight 0. Guard the fully-masked row (all -inf) separately to avoid dividing by zero; here we fall back to a uniform distribution over all keys.
- Don't forget the 1/sqrt(d) scaling on the raw dot products before softmax.
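Putting the hints together, one possible solution looks like the sketch below. It is pure Python on nested lists, and the function name `scaled_dot_product_attention` is an assumption, since the original problem does not give a signature. The case Tk == 0 with Tq > 0 is not specified by the problem, so the sketch assumes at least one key whenever there is a query.

```python
import math

NEG_INF = float("-inf")

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Single-head attention on nested lists; returns Tq rows of length dv."""
    if not Q:
        return []  # Tq == 0 returns an empty list

    # Assumption: Tk >= 1 here (Tk == 0 with Tq > 0 is unspecified).
    d = len(Q[0])
    dv = len(V[0])
    scale = 1.0 / math.sqrt(d)

    output = []
    for i, q in enumerate(Q):
        # Scaled dot products; a falsy mask entry becomes -inf.
        scores = []
        for j, k in enumerate(K):
            if mask is not None and not mask[i][j]:
                scores.append(NEG_INF)
            else:
                scores.append(scale * sum(a * b for a, b in zip(q, k)))

        finite = [s for s in scores if s != NEG_INF]
        if finite:
            # Stable softmax: subtract the maximum finite score.
            m = max(finite)
            exps = [math.exp(s - m) for s in scores]  # exp(-inf) == 0.0
            total = sum(exps)
            weights = [e / total for e in exps]
        else:
            # Fully-masked row: fall back to uniform weights over all keys.
            weights = [1.0 / len(scores)] * len(scores)

        # Weighted sum of value rows, rounded to 6 decimal places.
        output.append([
            round(sum(w * v[c] for w, v in zip(weights, V)), 6)
            for c in range(dv)
        ])
    return output
```

For instance, `scaled_dot_product_attention([[3.0]], [[1.0], [2.0]], [[7.0, 1.0], [9.0, 2.0]], [[0, 0]])` takes the fully-masked branch and returns `[[8.0, 1.5]]`, matching the last example.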