Implement Multi-Head Attention from Scratch in NumPy
Company: Amazon
Role: Machine Learning Engineer
Category: Coding & Algorithms
Difficulty: medium
Interview Round: Onsite
Quick Answer: This coding question tests a machine learning engineer's understanding of Transformer internals by requiring a from-scratch NumPy implementation of multi-head scaled dot-product attention. It evaluates mastery of linear projections, head splitting, numerically stable softmax, and causal masking — key competencies for ML roles involving large language models.
Constraints
- 1 <= seq_len <= 512
- 1 <= d_model <= 1024
- 1 <= num_heads <= 16
- d_model % num_heads == 0
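- Arguments arrive in the order (X, W_q, W_k, W_v, W_o, num_heads, mask). X is (seq_len, d_model); all four weight matrices are square (d_model, d_model) and multiply from the right, e.g. Q = X @ W_q.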
- mask is either None or a (seq_len, seq_len) matrix of 0/1 entries: mask[i][j] = 1 lets query position i attend to key position j, and 0 blocks it. The same mask applies to every head.
- Inputs are real-valued float64; return a float64 matrix.
- NumPy only — no PyTorch/TensorFlow/JAX or pre-built attention/softmax.
- Softmax must be numerically stable (subtract the per-row max).
- Scaling factor is 1/sqrt(d_h), where d_h = d_model / num_heads is the per-head dimension.
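The stability requirement can be illustrated with a small sketch. `stable_softmax` is a hypothetical helper name, not part of the required interface; it subtracts the per-row max as the constraint asks, and the snippet contrasts it with a naive version that overflows on large scores.

```python
import numpy as np

def stable_softmax(scores: np.ndarray) -> np.ndarray:
    """Row-wise softmax over the last axis, shifted by the per-row max (hypothetical helper)."""
    # Subtracting the row max cancels in the ratio, so the result is unchanged
    # mathematically, but every exponent is now <= 0 and exp() cannot overflow.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)

scores = np.array([[1000.0, 1001.0, 1002.0],
                   [-5.0, 0.0, 5.0]])
probs = stable_softmax(scores)

# Naive softmax: exp(1000.0) overflows to inf, and inf / inf is nan.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
print(np.isnan(naive[0]).any())  # True
```

Shifting by the max changes nothing except the size of the intermediate exponents, which is why the constraint can demand it without affecting the expected outputs.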
Examples
Input: ([[1.6243, -0.6118, -0.5282, -1.073, 0.8654, -2.3015, 1.7448, -0.7612], [0.319, -0.2494, 1.4621, -2.0601, -0.3224, -0.3841, 1.1338, -1.0999], [-0.1724, -0.8779, 0.0422, 0.5828, -1.1006, 1.1447, 0.9016, 0.5025], [0.9009, -0.6837, -0.1229, -0.9358, -0.2679, 0.5304, -0.6917, -0.3968]], [[-0.4168, -0.0563, -2.1362, 1.6403, -1.7934, -0.8417, 0.5029, -1.2453], [-1.058, -0.909, 0.5515, 2.2922, 0.0415, -1.1179, 0.5391, -0.5962], [-0.0191, 1.175, -0.7479, 0.009, -0.8781, -0.1564, 0.2566, -0.9888], [-0.3388, -0.2362, -0.6377, -1.1876, -1.4212, -0.1535, -0.2691, 2.2314], [-2.4348, 0.1127, 0.3704, 1.3596, 0.5019, -0.8442, 0.0, 0.5424], [-0.3135, 0.771, -1.8681, 1.7312, 1.4677, -0.3357, 0.6113, 0.048], [-0.8291, 0.0877, 1.0004, -0.3811, -0.3757, -0.0745, 0.4335, 1.2784], [-0.6347, 0.5084, 0.2161, -1.8586, -0.4193, -0.1323, -0.0396, 0.326]], [[1.7886, 0.4365, 0.0965, -1.8635, -0.2774, -0.3548, -0.0827, -0.627], [-0.0438, -0.4772, -1.3139, 0.8846, 0.8813, 1.7096, 0.05, -0.4047], [-0.5454, -1.5465, 0.9824, -1.1011, -1.185, -0.2056, 1.4861, 0.2367], [-1.0238, -0.713, 0.6252, -0.1605, -0.7688, -0.23, 0.7451, 1.9761], [-1.2441, -0.6264, -0.8038, -2.4191, -0.9238, -1.0239, 1.124, -0.1319], [-1.6233, 0.6467, -0.3563, -1.7431, -0.5966, -0.5886, -0.8739, 0.0297], [-2.2483, -0.2678, 1.0132, 0.8528, 1.1082, 1.1194, 1.4875, -1.1183], [0.8458, -1.8609, -0.6029, -1.9145, 1.0481, 1.3337, -0.1974, 1.7746]], [[0.0506, 0.5, -0.9959, 0.6936, -0.4183, -1.5846, -0.6477, 0.5986], [0.3323, -1.1475, 0.6187, -0.088, 0.4251, 0.3323, -1.1568, 0.351], [-0.6069, 1.547, 0.7233, 0.0461, -0.983, 0.0544, 0.1599, -1.2089], [2.2234, 0.3943, 1.6924, -1.1128, 1.6357, -1.361, -0.6512, 0.5425], [0.048, -2.3581, -1.1056, 0.8378, 2.0879, 0.9148, -0.2762, 0.7965], [-1.1438, 0.5099, -1.3475, -0.0094, -0.1307, 0.8021, -0.303, 1.202], [-0.1967, 0.8365, 0.7866, -1.8409, 0.0375, 0.0359, -0.7787, 0.1794], [-1.4555, 0.5562, 0.5098, 0.3004, 2.4766, 0.3523, 0.0675, -0.7323]], [[0.4412, -0.3309, 2.4308, -0.2521, 0.1096, 1.5825, -0.9092, -0.5916], [0.1876, -0.3299, -1.1928, -0.2049, -0.3588, 0.6035, -1.6648, -0.7002], [1.1514, 1.8573, -1.5112, 0.6448, -0.9806, -0.8569, -0.8719, -0.4225], [0.9964, 0.7124, 0.0591, -0.3633, 0.0033, -0.1059, 0.7931, -0.6316], [-0.0062, -0.1011, -0.0523, 0.2492, 0.1977, 1.3348, -0.0869, 1.5615], [-0.3059, -0.4777, 0.1007, 0.3554, 0.2696, 1.292, 1.1393, 0.4944], [-0.3363, -0.1006, 1.4134, 0.2213, -1.3108, -0.6896, -0.5775, 1.1522], [-0.1072, 2.2601, 0.6566, 0.1248, -0.4357, 0.9722, -0.2407, -0.8241]], 2, None)
Expected Output: [[-1.6774464060400325, -2.3308259808018166, -10.147355089761813, -0.03710910832461859, -5.297247101988509, -7.9985667447919475, -10.395951315328318, -8.001171608937256], [-1.3381561276571692, -4.658096498728458, 3.570946110390738, -2.189729088899968, 0.8997262291489199, -8.128793103268073, 2.32608358128664, -5.798166962397295], [-2.745328322448973, -4.688335517535206, 0.5400229118943058, -3.6299586435690294, 3.159016300993926, -1.0848650517369798, 5.533654624015087, -0.506179438622], [-1.8086656391548803, -5.9941660261373, 2.639222681897024, -3.603045453256781, 2.10117262664187, -7.824204113496304, 4.380985175339585, -5.790013915222155]]
Explanation: seq_len=4, d_model=8, num_heads=2 (d_h=4), no mask. Standard random forward pass — verifies the full projection -> per-head split -> scaled dot-product -> softmax -> concat -> output-projection path.
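The per-head split step for this example's shapes (seq_len=4, d_model=8, num_heads=2) can be sketched as follows, assuming a random matrix stands in for a projected Q. It checks that slicing column blocks and reshaping to (num_heads, seq_len, d_h) select the same entries, and that concatenating the blocks undoes the split.

```python
import numpy as np

# Shapes from the first example; the matrix values here are random stand-ins.
seq_len, d_model, num_heads = 4, 8, 2
d_h = d_model // num_heads  # 4

rng = np.random.default_rng(0)
Q = rng.standard_normal((seq_len, d_model))  # plays the role of a projected Q

# Column-block slicing: head i owns columns [i*d_h, (i+1)*d_h).
blocks = [Q[:, i * d_h:(i + 1) * d_h] for i in range(num_heads)]

# Equivalent view: (seq_len, d_model) -> (seq_len, num_heads, d_h) -> (num_heads, seq_len, d_h).
heads = Q.reshape(seq_len, num_heads, d_h).transpose(1, 0, 2)
for i in range(num_heads):
    assert np.array_equal(blocks[i], heads[i])

# Concatenating along the feature axis restores the original (seq_len, d_model) layout.
merged = np.concatenate(blocks, axis=-1)
assert np.array_equal(merged, Q)
print(heads.shape)  # (2, 4, 4)
```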
Input: ([[1.3316, 0.7153, -1.5454, -0.0084], [0.6213, -0.7201, 0.2655, 0.1085], [0.0043, -0.1746, 0.433, 1.203]], [[1.7495, -0.2861, -0.4846, -2.6533], [-0.0083, -0.3196, -0.5366, 0.3154], [0.4211, -1.0656, -0.8862, -0.4757], [0.6897, 0.5612, -1.3055, -1.1195]], [[0.473, -0.6814, 0.2424, -1.7007], [0.7531, -1.5347, 0.0051, -0.1202], [-0.807, 2.8718, -0.5978, 0.4725], [1.096, -1.2152, 1.3424, -0.1221]], [[-0.7124, 0.7538, -0.0445, 0.4518], [1.3451, 0.5323, 1.3502, 0.8612], [1.4787, -1.0454, -0.789, -1.2616], [0.5628, -0.2433, 0.9137, 0.3174]], [[1.5513, 0.0792, 0.174, -0.0723], [-2.0043, 0.1447, -1.5012, 0.2111], [-0.5582, 1.0845, -0.1863, 0.0147], [-1.0756, 0.6423, -0.1803, 0.6203]], 1, [[1, 0, 0], [1, 1, 0], [1, 1, 1]])
Expected Output: [[-14.134777783596, 4.583941245019, -5.868077101530002, 2.7924963705800003], [-13.993535039561984, 4.522066533232772, -5.805753577295718, 2.761269855204458], [-9.879523498200859, 2.9035769642868234, -4.015431321004437, 1.8861799397799646]]
Explanation: seq_len=3, d_model=4, num_heads=1, with a causal (lower-triangular) mask. Position 0 attends only to itself; position 1 to {0,1}; position 2 to all. Confirms masked entries get ~0 post-softmax weight.
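A sketch of how a lower-triangular mask like this one can be built with `np.tril` and applied before the softmax. It assumes random scores in place of real Q_i @ K_i.T / sqrt(d_h) values and uses -1e9 as the fill value, as the hints suggest.

```python
import numpy as np

seq_len = 3
# Lower-triangular 0/1 mask: row = query position, column = key position.
mask = np.tril(np.ones((seq_len, seq_len)))

rng = np.random.default_rng(1)
scores = rng.standard_normal((seq_len, seq_len))  # stand-in for scaled Q_i @ K_i.T

# Blocked (mask == 0) entries get a large negative score BEFORE the softmax.
masked = np.where(mask == 0, -1e9, scores)
shifted = masked - masked.max(axis=-1, keepdims=True)
weights = np.exp(shifted)
weights /= weights.sum(axis=-1, keepdims=True)

print(mask.astype(int).tolist())  # [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
# exp(-1e9 - row_max) underflows to exactly 0.0, so future keys get zero weight.
print(weights[np.triu_indices(seq_len, k=1)])  # [0. 0. 0.]
```

Because every row keeps at least its diagonal entry, the row max is always a real score, so the masked exponents underflow cleanly to 0 rather than producing nan.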
Input: ([[0.8839, 0.1959, 0.3575, -2.3433]], [[-0.052, -0.1112, 1.0418, -1.2567], [0.7454, -1.7111, -0.2059, -0.2346], [1.1281, -0.0126, -0.6132, 1.3737], [1.611, -0.6892, 0.6919, -0.4481]], [[-0.0919, -1.4634, 1.0818, -0.2393], [-0.4911, -1.0023, 0.9188, -1.1036], [0.6265, -0.5615, 0.0289, -0.2308], [0.5878, 0.7523, -1.0585, 1.056]], [[0.667, 0.0258, -0.7776, 0.9486], [0.7017, -1.0511, -0.3675, -1.1375], [-1.3221, 1.7723, -0.3475, 0.6701], [0.3223, 0.0603, -1.0435, -1.0099]], [[1.3292, -0.77, -0.3163, -0.9908], [-1.0708, -1.4387, 0.5644, 0.2957], [-1.6264, 0.2196, 0.6788, 1.8893], [0.9615, 0.104, -0.4812, 0.8502]], 2, None)
Expected Output: [[-0.4391145576630004, 0.618841842283, -0.1572706040650001, 6.277270940486001]]
Explanation: Edge case: seq_len=1, d_model=4, num_heads=2, no mask. A single query can only attend to itself, so every head's softmax weight is exactly 1.0 and the output reduces to the token's value projection followed by the output projection W_o.
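Because of this reduction, the expected output can be checked straight from the example's fourth and fifth matrices without running attention at all. A quick check, assuming those tuple positions hold the value and output projection weights:

```python
import numpy as np

# Values copied from the seq_len=1 example above.
X = np.array([[0.8839, 0.1959, 0.3575, -2.3433]])
W_v = np.array([[0.667, 0.0258, -0.7776, 0.9486], [0.7017, -1.0511, -0.3675, -1.1375],
                [-1.3221, 1.7723, -0.3475, 0.6701], [0.3223, 0.0603, -1.0435, -1.0099]])
W_o = np.array([[1.3292, -0.77, -0.3163, -0.9908], [-1.0708, -1.4387, 0.5644, 0.2957],
                [-1.6264, 0.2196, 0.6788, 1.8893], [0.9615, 0.104, -0.4812, 0.8502]])
expected = np.array([[-0.4391145576630004, 0.618841842283,
                      -0.1572706040650001, 6.277270940486001]])

# With one token every head's weight is 1.0, so split and concat are no-ops.
out = X @ W_v @ W_o
print(np.allclose(out, expected))  # True
```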
Input: ([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]], [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]], [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]], [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]], 2, None)
Expected Output: [[0.6697615493266569, 0.3302384506733431, 0.0, 0.0], [0.3302384506733431, 0.6697615493266569, 0.0, 0.0]]
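Explanation: Interpretable case: one-hot inputs with identity projection weights, num_heads=2 (d_h=2), no mask. In head 1 (features 0-1) each token scores 1/sqrt(2) against itself and 0 against the other token, giving softmax weights of about 0.6698 and 0.3302. In head 2 (features 2-3) every projection is zero, so that head contributes zeros to output columns 2-3.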
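The first output row can be reproduced by hand. Writing Q_1, K_1, V_1 for head 1's column block and Q_2, K_2, V_2 for head 2's (identity weights make Q = K = V = X):

```latex
\begin{aligned}
Q = K = V = X &= \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{bmatrix}, \qquad d_h = 2,\\
\text{head 1 (cols 0-1):}\quad \frac{Q_1 K_1^\top}{\sqrt{d_h}} &= \begin{bmatrix} 1/\sqrt{2} & 0 \\ 0 & 1/\sqrt{2} \end{bmatrix},\\
\operatorname{softmax}\!\left(\left[\tfrac{1}{\sqrt{2}},\ 0\right]\right)
  &= \left[\frac{e^{1/\sqrt{2}}}{e^{1/\sqrt{2}} + 1},\ \frac{1}{e^{1/\sqrt{2}} + 1}\right]
  \approx \left[\frac{2.0281}{3.0281},\ \frac{1}{3.0281}\right] \approx [0.6698,\ 0.3302],\\
\text{head 2 (cols 2-3):}\quad Q_2 = K_2 = V_2 = 0 &\;\Rightarrow\; \text{weights } [0.5,\ 0.5] \text{ applied to zero values} = [0,\ 0].
\end{aligned}
```

Since V_1 is the 2x2 identity, head 1's output row equals its weight row, and W_o = I leaves the concatenation unchanged.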
Input: ([[-1.2641, 1.5279, -0.9707, 0.4706], [-0.1007, 0.3038, -1.726, 1.5851], [0.1343, -1.1069, 1.5782, 0.1075], [-0.764, -0.7752, 1.3838, 0.7604]], [[-0.4148, -0.3334, 0.0811, -0.791], [-0.2186, -0.7632, -0.7771, 1.8494], [-0.7056, -0.086, 0.2879, -0.1314], [-0.9827, -0.9188, 1.1994, -0.3414]], [[-0.3489, 0.9837, 0.5809, 0.0703], [0.7775, 0.582, 1.4718, 1.6632], [-0.2612, -0.6887, -0.6949, 1.9404], [1.8054, 0.4563, -0.5748, 0.1142]], [[-0.3189, -1.603, -1.5352, -0.5704], [-0.2167, 0.2549, -0.1494, 2.0108], [-0.0968, 0.4222, -0.2255, -0.6379], [-0.0163, 1.0442, -1.0849, -2.2059]], [[0.2438, -0.7473, -1.5612, -0.4643], [-0.3521, -1.2815, 0.2893, 0.98], [0.4779, 0.4508, 0.7524, -0.5106], [-0.7058, -0.4243, -0.2322, 1.8151]], 2, [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]])
Expected Output: [[-2.543462270274, -4.110033183436, 0.7607251173969999, 7.773488717141], [-2.1399012265762902, -3.716706213850191, 0.600477708698735, 6.446191068229188], [0.40369605959343324, -0.6619269071968319, -0.4262948685336292, -1.8345314587086328], [-0.07865437303802453, -0.9837336978783758, -0.27608868032046113, -0.21315151748291822]]
Explanation: seq_len=4, d_model=4, num_heads=2, with a causal mask. Combines head splitting with masking, checking that the same mask is applied to each head's score matrix independently.
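One way to apply a single (seq_len, seq_len) mask to every head at once is NumPy broadcasting. The sketch below is a vectorized alternative to the per-head loop described in the hints; the name `mha_vectorized` and the argument order (X, W_q, W_k, W_v, W_o, num_heads, mask) are assumptions inferred from the example inputs, and it returns an ndarray rather than a list.

```python
import numpy as np

def mha_vectorized(X, W_q, W_k, W_v, W_o, num_heads, mask=None):
    """All heads at once via reshape (hypothetical name); one mask broadcasts over heads."""
    X = np.asarray(X, dtype=np.float64)
    seq_len, d_model = X.shape
    d_h = d_model // num_heads

    def split(M):
        # (seq_len, d_model) -> (num_heads, seq_len, d_h); head i = columns [i*d_h, (i+1)*d_h)
        return M.reshape(seq_len, num_heads, d_h).transpose(1, 0, 2)

    Q = split(X @ np.asarray(W_q, dtype=np.float64))
    K = split(X @ np.asarray(W_k, dtype=np.float64))
    V = split(X @ np.asarray(W_v, dtype=np.float64))

    # Batched matmul: (num_heads, seq_len, d_h) @ (num_heads, d_h, seq_len).
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_h)
    if mask is not None:
        # A (seq_len, seq_len) mask broadcasts against (num_heads, seq_len, seq_len),
        # so each head's score matrix is masked with the same pattern.
        scores = np.where(np.asarray(mask) == 0, -1e9, scores)

    scores = scores - scores.max(axis=-1, keepdims=True)  # stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)

    heads = weights @ V  # (num_heads, seq_len, d_h)
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)  # inverse of split
    return concat @ np.asarray(W_o, dtype=np.float64)
```

With a causal mask, the first query can only see itself, so its row reduces to X[0] @ W_v @ W_o; that makes a convenient sanity check alongside the examples.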
Hints
- Do the three projections first as full matrix multiplies, then slice each (seq_len, d_model) result into num_heads column-blocks of width d_h = d_model // num_heads.
- For each head, scores = (Q_i @ K_i.T) / sqrt(d_h) has shape (seq_len, seq_len) — rows are queries, columns are keys. Softmax over the last axis (axis=-1).
- Apply the mask BEFORE the softmax: set masked (mask==0) score entries to a large negative number (e.g. -1e9) so exp() makes them ~0.
- Numerically stable softmax: subtract scores.max(axis=-1, keepdims=True) before exp, then divide by the row sum.
- Concatenate the per-head (seq_len, d_h) outputs along the feature axis back to (seq_len, d_model), then apply the output projection W_o. In Python, accept lists, do the math with np.asarray, and return out.tolist().
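Putting the hints together gives a minimal reference sketch. The function name `multi_head_attention` is hypothetical; the argument order (X, W_q, W_k, W_v, W_o, num_heads, mask) and the X @ W convention are inferred from the example inputs rather than stated by the problem.

```python
import numpy as np

def multi_head_attention(X, W_q, W_k, W_v, W_o, num_heads, mask=None):
    """Multi-head scaled dot-product attention, following the hints step by step.

    Accepts nested lists or arrays and returns a nested list (out.tolist()).
    """
    X = np.asarray(X, dtype=np.float64)
    W_q, W_k, W_v, W_o = [np.asarray(W, dtype=np.float64) for W in (W_q, W_k, W_v, W_o)]
    seq_len, d_model = X.shape
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    d_h = d_model // num_heads
    m = None if mask is None else np.asarray(mask)

    # 1) Full projections, each (seq_len, d_model).
    Q, K, V = X @ W_q, X @ W_k, X @ W_v

    head_outputs = []
    for i in range(num_heads):
        cols = slice(i * d_h, (i + 1) * d_h)  # 2) this head's column block
        Q_i, K_i, V_i = Q[:, cols], K[:, cols], V[:, cols]

        # 3) Scaled scores, (seq_len, seq_len): rows = queries, columns = keys.
        scores = Q_i @ K_i.T / np.sqrt(d_h)

        # 4) Mask before softmax: blocked positions get a large negative score.
        if m is not None:
            scores = np.where(m == 0, -1e9, scores)

        # 5) Numerically stable softmax over the key axis.
        scores = scores - scores.max(axis=-1, keepdims=True)
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)

        head_outputs.append(weights @ V_i)  # (seq_len, d_h)

    # 6) Concatenate heads back to (seq_len, d_model), then apply the output projection.
    out = np.concatenate(head_outputs, axis=-1) @ W_o
    return out.tolist()
```

The explicit per-head loop mirrors the hints one to one, which keeps it easy to explain in an interview. The -1e9 fill assumes every query row has at least one unmasked key, as is the case for a causal mask.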