Implement Distributed Matrix Multiplication
Company: xAI
Role: Machine Learning Engineer
Category: Coding & Algorithms
Difficulty: hard
Interview Round: Technical Screen
Quick Answer: This question evaluates distributed-systems and parallel-algorithm skills: distributed matrix multiplication, data partitioning (row- and column-wise sharding), inter-device communication and synchronization, and practical implementation with NumPy and threads.
Part 1: Data-Parallel Matrix Multiplication
Constraints
- 1 <= M, K, N <= 100
- 1 <= num_devices <= 100
- a is an M x K rectangular integer matrix
- b is a K x N rectangular integer matrix
- The row split of A must follow numpy.array_split semantics
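The numpy.array_split rule can be reproduced with integer division: the first (M mod num_devices) chunks get one extra row, and the remaining chunks may be empty. The sketch below is plain Python and returns half-open (start, stop) row ranges. The helper name array_split_bounds is hypothetical and is not part of NumPy.

```python
def array_split_bounds(n, parts):
    """Hypothetical helper: (start, stop) ranges matching numpy.array_split(range(n), parts)."""
    base, extra = divmod(n, parts)
    bounds, start = [], 0
    for i in range(parts):
        size = base + (1 if i < extra else 0)  # first `extra` chunks are one longer
        bounds.append((start, start + size))
        start += size
    return bounds


print(array_split_bounds(3, 2))  # [(0, 2), (2, 3)]  -> 2 rows, then 1 row
print(array_split_bounds(1, 4))  # [(0, 1), (1, 1), (1, 1), (1, 1)]  -> three empty chunks
```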
Examples
Input: ([[1, 2], [3, 4], [5, 6], [7, 8]], [[1, 0, 2], [0, 1, 3]], 2)
Expected Output: [[1, 2, 8], [3, 4, 18], [5, 6, 28], [7, 8, 38]]
Explanation: A is split into two row chunks: [[1, 2], [3, 4]] and [[5, 6], [7, 8]]. Each chunk is multiplied by the full B, then the results are concatenated.
Input: ([[2, 0, 1], [1, 3, 2], [4, 1, 0]], [[1, 2], [0, 1], [3, 4]], 2)
Expected Output: [[5, 8], [7, 13], [4, 9]]
Explanation: The row split is uneven: the first device gets 2 rows and the second gets 1 row.
Input: ([[2, 3]], [[4], [5]], 4)
Expected Output: [[23]]
Explanation: Edge case: there are more devices than rows, so three of the four devices receive empty chunks and contribute no output rows.
Input: ([[1, 2, 3], [4, 5, 6]], [[7], [8], [9]], 1)
Expected Output: [[50], [122]]
Explanation: Edge case: with one device, the algorithm reduces to ordinary matrix multiplication.
Hints
- Compute the row-chunk sizes from the quotient and remainder of M divided by num_devices: the first (M mod num_devices) chunks each get one extra row.
- After computing each chunk's product, gather the chunk results in the same device order used during the split.
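Here is a minimal sketch of Part 1 that combines the two hints. It makes these assumptions: matrices are Python lists of lists, each "device" is simulated by a threading.Thread, and the names data_parallel_matmul and matmul are hypothetical. A NumPy version would typically use numpy.array_split for the split and numpy.concatenate for the gather.

```python
import threading


def matmul(a, b):
    # Plain product of an R x K list-of-lists by a K x N list-of-lists (R may be 0).
    n = len(b[0])
    return [[sum(x * b[k][j] for k, x in enumerate(row)) for j in range(n)] for row in a]


def data_parallel_matmul(a, b, num_devices):
    """Hypothetical sketch: each thread owns one row chunk of A and a full copy of B."""
    # Row ranges follow the numpy.array_split rule.
    base, extra = divmod(len(a), num_devices)
    bounds, start = [], 0
    for d in range(num_devices):
        stop = start + base + (1 if d < extra else 0)
        bounds.append((start, stop))
        start = stop

    results = [None] * num_devices  # slot d holds device d's output rows

    def device(d):
        r0, r1 = bounds[d]
        results[d] = matmul(a[r0:r1], b)  # an empty chunk yields an empty block

    threads = [threading.Thread(target=device, args=(d,)) for d in range(num_devices)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # Gather in the same device order used for the split.
    out = []
    for block in results:
        out.extend(block)
    return out


print(data_parallel_matmul([[2, 0, 1], [1, 3, 2], [4, 1, 0]], [[1, 2], [0, 1], [3, 4]], 2))
# [[5, 8], [7, 13], [4, 9]]
```

The threads finish in arbitrary order, but each one writes only to its own slot in results. The gather is therefore deterministic. In standard CPython builds, the global interpreter lock means these threads illustrate the data flow rather than speed up pure-Python arithmetic.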
Part 2: Fully Sharded Matrix Multiplication
Constraints
- 1 <= M, K, N <= 100
- 1 <= num_devices <= 100
- a is an M x K rectangular integer matrix
- b is a K x N rectangular integer matrix
- The row split of A and the column split of B must follow numpy.array_split semantics
- Some row shards or column shards may be empty when num_devices is larger than M or N
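The column split of B follows the same array_split rule along the other axis. The sketch below uses a hypothetical helper, split_columns, written in plain Python. It tags each shard with its original index and column range, which is the extra information the Part 2 hints say a rotating shard must carry. Empty shards keep K rows of width zero.

```python
def split_columns(b, parts):
    """Hypothetical helper: split the columns of a K x N matrix per numpy.array_split.

    Returns (shard_index, col_start, col_stop, block) tuples; an empty shard has
    col_start == col_stop and a block made of K empty rows.
    """
    n = len(b[0])
    base, extra = divmod(n, parts)
    shards, start = [], 0
    for i in range(parts):
        stop = start + base + (1 if i < extra else 0)
        shards.append((i, start, stop, [row[start:stop] for row in b]))
        start = stop
    return shards


for shard in split_columns([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], 3):
    print(shard)
# (0, 0, 2, [[1, 2], [6, 7]])
# (1, 2, 4, [[3, 4], [8, 9]])
# (2, 4, 5, [[5], [10]])
```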
Examples
Input: ([[1, 2], [3, 4]], [[1, 2, 3, 4], [5, 6, 7, 8]], 2)
Expected Output: [[11, 14, 17, 20], [23, 30, 37, 44]]
Explanation: B is split into two column shards: the first two columns and the last two columns. Each device starts with one column shard. After one rotation, every row shard has been multiplied by both column shards.
Input: ([[1, 0], [2, 1], [0, 3]], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], 3)
Expected Output: [[1, 2, 3, 4, 5], [8, 11, 14, 17, 20], [18, 21, 24, 27, 30]]
Explanation: This case has uneven column shards because 5 columns are split across 3 devices as sizes 2, 2, and 1.
Input: ([[2, 1, 0]], [[1, 2], [3, 4], [5, 6]], 4)
Expected Output: [[5, 8]]
Explanation: Edge case: with 1 row, 2 columns, and 4 devices, three row shards and two column shards are empty, but the final answer is still correct.
Input: ([[2, 3], [4, 5]], [[6], [7]], 1)
Expected Output: [[33], [59]]
Explanation: Edge case: with one device, the sharded algorithm becomes standard matrix multiplication.
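The rotation in these examples generalizes. With P devices, P - 1 rotations (P compute steps) are enough for every row shard to meet every column shard, and the two-device case needs exactly one rotation. The sketch below assumes one common ring direction, in which device d passes its shard to device (d + 1) mod P. The function name ring_schedule is hypothetical.

```python
def ring_schedule(num_devices):
    """Hypothetical helper: column-shard index held by each device at each step.

    Assumed ring direction: after every step, device d passes its shard to device
    (d + 1) % num_devices, so at step s device d holds shard (d - s) % num_devices.
    """
    return [[(d - s) % num_devices for d in range(num_devices)] for s in range(num_devices)]


for step, holders in enumerate(ring_schedule(3)):
    print(f"step {step}: device -> shard {holders}")
# step 0: device -> shard [0, 1, 2]
# step 1: device -> shard [2, 0, 1]
# step 2: device -> shard [1, 2, 0]
```

Each column of the schedule is a permutation of 0..P-1. This means every device sees every column shard exactly once, so every entry of the output is written exactly once.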
Hints
- Each rotating B shard needs to carry its original column-shard index so you know where to place the partial product.
- Let each device build a full-width local result block and fill only the columns covered by the shard it currently holds.
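Here is a minimal end-to-end sketch of Part 2 that follows both hints. It assumes matrices are Python lists of lists and that each device is a threading.Thread. The names sharded_matmul and split_bounds are hypothetical. Each device keeps its row shard of A and computes the partial product for whichever tagged column shard it currently holds. It writes that partial product into a full-width local block and then passes the shard around a ring. A threading.Barrier separates the reads and writes of each rotation, so no device overwrites a shard before its neighbor has picked it up.

```python
import threading


def split_bounds(n, parts):
    # numpy.array_split rule: the first n % parts ranges are one element longer.
    base, extra = divmod(n, parts)
    bounds, start = [], 0
    for i in range(parts):
        stop = start + base + (1 if i < extra else 0)
        bounds.append((start, stop))
        start = stop
    return bounds


def sharded_matmul(a, b, num_devices):
    """Hypothetical sketch: shard rows of A and columns of B; rotate B shards in a ring."""
    p, n = num_devices, len(b[0])
    row_bounds = split_bounds(len(a), p)
    col_bounds = split_bounds(n, p)
    # held[d] = (original column-shard index, K x width block) currently on device d.
    held = [(d, [row[c0:c1] for row in b]) for d, (c0, c1) in enumerate(col_bounds)]
    results = [None] * p
    barrier = threading.Barrier(p)

    def device(d):
        r0, r1 = row_bounds[d]
        a_shard = a[r0:r1]                   # this device's rows of A (may be empty)
        local = [[0] * n for _ in a_shard]   # full-width local result block
        for step in range(p):
            idx, block = held[d]
            c0, c1 = col_bounds[idx]         # columns this partial product fills
            for i, row in enumerate(a_shard):
                for j in range(c1 - c0):
                    local[i][c0 + j] = sum(x * block[k][j] for k, x in enumerate(row))
            if step < p - 1:                 # rotate: receive the shard of device d - 1
                incoming = held[(d - 1) % p]
                barrier.wait()               # every device has read before anyone writes
                held[d] = incoming
                barrier.wait()               # every device has written before the next step
        results[d] = local

    threads = [threading.Thread(target=device, args=(d,)) for d in range(p)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    out = []
    for block in results:                    # gather row blocks in device order
        out.extend(block)
    return out


print(sharded_matmul([[1, 0], [2, 1], [0, 3]], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], 3))
# [[1, 2, 3, 4, 5], [8, 11, 14, 17, 20], [18, 21, 24, 27, 30]]
```

The shared held list stands in for point-to-point sends. On real hardware, each step would be an actual send/receive between neighboring devices, and only one column shard of B would live on each device at a time.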