Implement Distributed Matrix Multiplication
Company: xAI
Role: Machine Learning Engineer
Category: Coding & Algorithms
Difficulty: hard
Interview Round: Technical Screen
You are given a simplified distributed-computing setup that simulates communication between devices using Python threads and message queues.
A helper class `Communicator` is available with the following behavior:
- `send(src, dst, data)`: send the NumPy array `data` from device `src` to device `dst`.
- `recv(dst)`: block until a message arrives for device `dst`, then return the tuple `(src, data)`.
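To make the setup concrete, here is one plausible sketch of the helper, assuming (this is an assumption, not the provided implementation) that `Communicator` is backed by one FIFO queue per destination device:

```python
import queue

import numpy as np


class Communicator:
    """Hypothetical stand-in for the provided helper:
    one thread-safe FIFO queue per destination device."""

    def __init__(self, num_devices):
        self.num_devices = num_devices
        self._queues = [queue.Queue() for _ in range(num_devices)]

    def send(self, src, dst, data):
        # Tag the payload with the sender rank so recv can return (src, data).
        self._queues[dst].put((src, data))

    def recv(self, dst):
        # Blocks until a message for device `dst` is available.
        return self._queues[dst].get()
```

Because `queue.Queue` handles locking internally, threads simulating devices can send without blocking and block only on `recv`, which matches the described behavior.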
Implement two functions for matrix multiplication `C = A @ B` using `num_devices` simulated devices.
## Part 1: Data Parallel multiplication
Implement a per-device worker and the driver function for a data-parallel strategy:
- Split matrix `A` row-wise into `num_devices` chunks.
- Replicate the full matrix `B` to every device.
- Each device computes its local output `A_chunk @ B`.
- Gather all partial outputs to device `0` and concatenate them in the correct row order.
- Return the full result matrix.
Function signatures:
- `compute_fn(comm, rank, a_chunk, b, result)`
- `dp_mat_mul(a, b, num_devices)`
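A minimal sketch of the data-parallel strategy, under two stated assumptions: the queue-backed `Communicator` below is a stand-in for the provided helper, and the gather on device `0` happens inside `compute_fn`, with `result` acting as a shared output container:

```python
import queue
import threading

import numpy as np


class Communicator:
    # Stand-in for the provided helper (assumption): one FIFO queue per device.
    def __init__(self, num_devices):
        self.num_devices = num_devices
        self._queues = [queue.Queue() for _ in range(num_devices)]

    def send(self, src, dst, data):
        self._queues[dst].put((src, data))

    def recv(self, dst):
        return self._queues[dst].get()


def compute_fn(comm, rank, a_chunk, b, result):
    # Each device multiplies its row chunk against the replicated B.
    partial = a_chunk @ b
    if rank == 0:
        parts = {0: partial}
        # Gather the remaining partials; each message carries the sender rank.
        for _ in range(comm.num_devices - 1):
            src, data = comm.recv(0)
            parts[src] = data
        # Concatenate in rank order to restore the original row order.
        result.append(np.concatenate(
            [parts[r] for r in range(comm.num_devices)], axis=0))
    else:
        comm.send(rank, 0, partial)


def dp_mat_mul(a, b, num_devices):
    comm = Communicator(num_devices)
    chunks = np.array_split(a, num_devices, axis=0)  # row-wise split
    result = []  # filled by rank 0
    threads = [threading.Thread(target=compute_fn,
                                args=(comm, r, chunks[r], b, result))
               for r in range(num_devices)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return result[0]
```

Keying the gathered partials by sender rank rather than by arrival order is what guarantees the concatenation reproduces `A`'s original row order even when threads finish out of order.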
## Part 2: Fully Sharded multiplication
Implement `fsdp_mat_mul(a, b, num_devices)` from scratch using a sharded strategy:
- Split `A` row-wise across devices.
- Split `B` column-wise across devices.
- Each device initially owns one row shard of `A` and one column shard of `B`.
- Use an all-gather style ring rotation of `B` shards: in each round, every device passes its current `B` shard to its neighbor, so that after `num_devices` rounds every device has multiplied its local `A` shard against every column shard of `B`.
- Each device should accumulate the correct local output rows for its shard of `A`.
- Finally, gather the row shards and return the full matrix `C`.
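The steps above can be sketched as follows. This is one possible implementation, not the reference solution; it assumes the same queue-backed `Communicator` stand-in, and it tags messages (`"rot"` vs. `"out"`) so that device `0` can tell rotation traffic apart from gather messages that arrive early:

```python
import queue
import threading

import numpy as np


class Communicator:
    # Stand-in for the provided helper (assumption): one FIFO queue per device.
    def __init__(self, num_devices):
        self.num_devices = num_devices
        self._queues = [queue.Queue() for _ in range(num_devices)]

    def send(self, src, dst, data):
        self._queues[dst].put((src, data))

    def recv(self, dst):
        return self._queues[dst].get()


def fsdp_worker(comm, rank, a_shard, b_shard, col_offsets, result):
    n = comm.num_devices
    local = np.zeros((a_shard.shape[0], col_offsets[-1]),
                     dtype=np.result_type(a_shard, b_shard))
    cur, owner = b_shard, rank  # `owner` = original rank of the shard in hand
    gathered = {}               # rank 0 may receive gather messages early
    for step in range(n):
        lo, hi = col_offsets[owner], col_offsets[owner + 1]
        local[:, lo:hi] = a_shard @ cur  # fill this shard's column block
        if step < n - 1:
            # Ring rotation: pass the current B shard to the next device.
            comm.send(rank, (rank + 1) % n, ("rot", cur))
            while True:
                src, (tag, data) = comm.recv(rank)
                if tag == "rot":
                    cur = data
                    break
                gathered[src] = data  # stash an early "out" message
            owner = (owner - 1) % n
    if rank == 0:
        gathered[0] = local
        while len(gathered) < n:
            src, (_, data) = comm.recv(0)
            gathered[src] = data
        result.append(np.concatenate([gathered[r] for r in range(n)], axis=0))
    else:
        comm.send(rank, 0, ("out", local))


def fsdp_mat_mul(a, b, num_devices):
    comm = Communicator(num_devices)
    a_shards = np.array_split(a, num_devices, axis=0)  # row shards of A
    b_shards = np.array_split(b, num_devices, axis=1)  # column shards of B
    # Column offsets of each B shard, following array_split semantics.
    col_offsets = [0]
    for s in b_shards:
        col_offsets.append(col_offsets[-1] + s.shape[1])
    result = []
    threads = [threading.Thread(target=fsdp_worker,
                                args=(comm, r, a_shards[r], b_shards[r],
                                      col_offsets, result))
               for r in range(num_devices)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return result[0]
```

Tracking `owner = (rank - step) % n` is what lets each device write `a_shard @ cur` into the correct column block of its local output, since after `step` rotations the shard in hand originated at that rank.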
## Requirements
- Use NumPy for matrix operations.
- Use Python threads to simulate devices.
- Use the provided `Communicator` for cross-device communication.
- The final output must match `A @ B`.
You may assume:
- `A` has shape `(M, K)`.
- `B` has shape `(K, N)`.
- `num_devices >= 1`.
- The row split of `A` and column split of `B` follow `numpy.array_split` semantics.
Your implementation should correctly handle uneven splits as long as the shapes remain compatible.
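For reference, `numpy.array_split` handles uneven splits by making the first `len % n` chunks one element longer than the rest:

```python
import numpy as np

a = np.arange(7 * 3).reshape(7, 3)
chunks = np.array_split(a, 3, axis=0)  # 7 rows into 3 chunks
print([c.shape[0] for c in chunks])    # -> [3, 2, 2]
```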
Quick Answer: This question evaluates distributed-systems and parallel-algorithm skills: distributed matrix multiplication, data partitioning (row- and column-wise sharding), inter-device communication and synchronization, and practical implementation ability with NumPy and Python threads.