You are given two matrices A (shape m x k) and B (shape k x n), whose inner dimensions match, and a positive integer num_devices. Simulate a fully sharded matrix multiplication strategy for C = A x B.
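Split A row-wise into num_devices shards using numpy.array_split semantics: shards are contiguous, and when the length is not evenly divisible, the first (length mod num_devices) shards each get one extra row. Split B column-wise into num_devices shards using the same rule on columns. If num_devices exceeds the number of rows or columns, the trailing shards are empty. Device i initially owns row shard A_i and column shard B_i.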
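The splitting rule can be sketched without NumPy. The helper name shard_bounds is hypothetical (it is not part of the problem statement); it returns the half-open index range each shard covers, following the array_split rule that the first (length mod num_shards) shards get one extra element.

```python
def shard_bounds(length, num_shards):
    """Return (start, stop) index pairs, one per shard, in shard order.

    Hypothetical helper: mirrors the sizing rule of numpy.array_split,
    where the first (length % num_shards) shards are one element longer.
    """
    base, extra = divmod(length, num_shards)
    bounds = []
    start = 0
    for shard in range(num_shards):
        size = base + 1 if shard < extra else base
        bounds.append((start, start + size))
        start += size
    return bounds


print(shard_bounds(5, 3))  # [(0, 2), (2, 4), (4, 5)]
print(shard_bounds(2, 4))  # [(0, 1), (1, 2), (2, 2), (2, 2)]
```

An empty shard shows up as a range whose start equals its stop, which is why the later multiplication step can simply skip it.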
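Now simulate num_devices rounds of rotation. In each round, every device multiplies its fixed row shard of A with the B shard it currently holds, writes the resulting block into its rows of C at the global column positions that the shard occupies in the original B, and then passes its current B shard to the next device in a circular fashion (the last device passes to the first). Because shards move between devices, each device must track which column range its current shard covers.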
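To see which shard each device holds in each round, the rotation can be sketched on shard indices alone. This assumes "next device" means device i passes to device (i + 1) mod num_devices; the final matrix does not depend on the direction, only on every device seeing every shard once. The name rotation_schedule is hypothetical.

```python
def rotation_schedule(num_devices):
    """Return schedule[r][i]: index of the B shard device i holds in round r."""
    held = list(range(num_devices))  # device i starts with shard i
    schedule = []
    for _ in range(num_devices):
        schedule.append(list(held))
        # Assumption: device i hands its shard to device (i + 1) % num_devices,
        # so device i receives from device (i - 1) % num_devices.
        held = [held[(i - 1) % num_devices] for i in range(num_devices)]
    return schedule


for round_index, held in enumerate(rotation_schedule(3)):
    print(round_index, held)
# 0 [0, 1, 2]
# 1 [2, 0, 1]
# 2 [1, 2, 0]
```

Under this assumption, device i holds shard (i - r) mod num_devices in round r, so each column of the schedule is a permutation of all shard indices.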
After all rounds, gather each device's row block of C in device order (device 0 first) and return the full matrix C.
For this standalone problem, you only need to simulate the sharding and rotation logic. Real threads or message queues are not required.
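One possible end-to-end simulation is sketched below. It is not a reference solution: it uses plain Python lists instead of NumPy, and the names shard_bounds and sharded_matmul are hypothetical. The rotation direction (device i passes to device i + 1) is also an assumption.

```python
def shard_bounds(length, num_shards):
    """(start, stop) per shard, following the numpy.array_split sizing rule."""
    base, extra = divmod(length, num_shards)
    bounds, start = [], 0
    for shard in range(num_shards):
        size = base + 1 if shard < extra else base
        bounds.append((start, start + size))
        start += size
    return bounds


def sharded_matmul(a, b, num_devices):
    """Simulate the row-sharded A / rotating column-sharded B multiplication."""
    n_rows, n_inner = len(a), len(b)
    n_cols = len(b[0]) if b else 0
    row_bounds = shard_bounds(n_rows, num_devices)
    col_bounds = shard_bounds(n_cols, num_devices)

    # Device i owns a fixed row shard of A and an output block of the same height.
    a_shards = [a[lo:hi] for lo, hi in row_bounds]
    c_shards = [[[0] * n_cols for _ in shard] for shard in a_shards]

    # Each device holds (shard_index, column slice of B); the index travels
    # with the data so the receiver knows the global column positions.
    held = [(j, [row[lo:hi] for row in b]) for j, (lo, hi) in enumerate(col_bounds)]

    for _ in range(num_devices):
        for device in range(num_devices):
            shard_index, b_shard = held[device]
            col_start, col_stop = col_bounds[shard_index]
            for r, a_row in enumerate(a_shards[device]):
                for c in range(col_stop - col_start):
                    c_shards[device][r][col_start + c] = sum(
                        a_row[k] * b_shard[k][c] for k in range(n_inner)
                    )
        # Rotate: device d receives the shard previously held by device d - 1.
        held = [held[(d - 1) % num_devices] for d in range(num_devices)]

    # Gather the row blocks of C in device order.
    return [row for shard in c_shards for row in shard]


print(sharded_matmul([[1, 2], [3, 4]], [[1, 2, 3, 4], [5, 6, 7, 8]], 2))
# [[11, 14, 17, 20], [23, 30, 37, 44]]
```

Empty shards need no special case here: an empty row shard contributes no output rows, and an empty column shard has a zero-width column range, so the inner loops do nothing.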
Examples
Input: ([[1, 2], [3, 4]], [[1, 2, 3, 4], [5, 6, 7, 8]], 2)
Expected Output: [[11, 14, 17, 20], [23, 30, 37, 44]]
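Explanation: B is split into two column shards: the first two columns and the last two columns. With two devices there are two rounds: each device first multiplies by its own column shard, then by the other shard after one rotation, so every row shard has seen both column shards.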
Input: ([[1, 0], [2, 1], [0, 3]], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], 3)
Expected Output: [[1, 2, 3, 4, 5], [8, 11, 14, 17, 20], [18, 21, 24, 27, 30]]
Explanation: This case has uneven column shards because 5 columns are split across 3 devices as sizes 2, 2, and 1.
Input: ([[2, 1, 0]], [[1, 2], [3, 4], [5, 6]], 4)
Expected Output: [[5, 8]]
Explanation: Edge case: there are more devices than rows and columns, so several shards are empty but the final answer is still correct.
Input: ([[2, 3], [4, 5]], [[6], [7]], 1)
Expected Output: [[33], [59]]
Explanation: Edge case: with one device, the sharded algorithm becomes standard matrix multiplication.