PracHub

Quick Overview

This question evaluates distributed systems and parallel algorithm skills, specifically distributed matrix multiplication, data partitioning (row- and column-wise sharding), inter-device communication and synchronization, and practical NumPy/thread-based implementation competency.

Implement Distributed Matrix Multiplication

Company: xAI

Role: Machine Learning Engineer

Category: Coding & Algorithms

Difficulty: hard

Interview Round: Technical Screen

You are given a simplified distributed-computing setup that simulates communication between devices using Python threads and message queues. A helper class `Communicator` is available with the following behavior:

  • `send(src, dst, data)`: send a NumPy array from device `src` to device `dst`.
  • `recv(dst)`: block until device `dst` receives a message, then return `(src, data)`.

Implement matrix multiplication `C = A @ B` on `num_devices` simulated devices using two strategies.

Part 1: Data-Parallel Multiplication

Implement a per-device worker and a driver function for a data-parallel strategy:

  • Split matrix `A` row-wise into `num_devices` chunks.
  • Replicate the full matrix `B` to every device.
  • Each device computes its local output `A_chunk @ B`.
  • Gather all partial outputs on device `0` and concatenate them in the correct row order.
  • Return the full result matrix.

Function signatures:

  • `compute_fn(comm, rank, a_chunk, b, result)`
  • `dp_mat_mul(a, b, num_devices)`

Part 2: Fully Sharded Multiplication

Implement `fsdp_mat_mul(a, b, num_devices)` from scratch using a sharded strategy:

  • Split `A` row-wise across devices.
  • Split `B` column-wise across devices.
  • Each device initially owns one row shard of `A` and one column shard of `B`.
  • Rotate the `B` shards all-gather style: in each round, devices exchange `B` shards so that, after all rounds, every device has multiplied its local `A` shard by every column shard of `B`.
  • Each device accumulates the output rows that correspond to its shard of `A`.
  • Finally, gather the row shards and return the full matrix `C`.

Requirements:

  • Use NumPy for matrix operations.
  • Use Python threads to simulate devices.
  • Use the provided `Communicator` for cross-device communication.
  • The final output must match `A @ B`.

You may assume:

  • `A` has shape `(M, K)`.
  • `B` has shape `(K, N)`.
  • `num_devices >= 1`.
  • The row split of `A` and the column split of `B` follow `numpy.array_split` semantics.

Your implementation should correctly handle uneven splits as long as the shapes remain compatible.
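The threaded version of both parts might look like the sketch below. It rests on several assumptions: the provided `Communicator` is not shown, so the class here is a hypothetical stand-in backed by one `queue.Queue` inbox per device, and its `num_devices` attribute is an assumption (the worker signature has no device count). `_gather_rows`, `fsdp_compute_fn` and `_run` are hypothetical helper names. Because every device shares one inbox, device 0 can receive an early gather message while it is still rotating shards, so the sketch filters rotation messages by sender.

```python
import queue
import threading

import numpy as np


class Communicator:
    """Hypothetical stand-in for the provided helper: one FIFO inbox per device."""

    def __init__(self, num_devices):
        self.num_devices = num_devices  # assumed attribute, not in the problem spec
        self.inboxes = [queue.Queue() for _ in range(num_devices)]

    def send(self, src, dst, data):
        # Copy so the receiver never aliases the sender's buffer.
        self.inboxes[dst].put((src, np.array(data, copy=True)))

    def recv(self, dst):
        return self.inboxes[dst].get()  # blocks until a message for dst arrives


def _gather_rows(comm, rank, local, result, pending=()):
    """Non-zero ranks send their rows to device 0, which concatenates by rank."""
    if rank != 0:
        comm.send(rank, 0, local)
        return
    parts = dict(pending)  # gather messages device 0 already received early
    parts[0] = local
    while len(parts) < comm.num_devices:
        src, data = comm.recv(0)
        parts[src] = data  # arrival order is arbitrary; key by sender rank
    result.append(np.concatenate([parts[r] for r in range(comm.num_devices)], axis=0))


def compute_fn(comm, rank, a_chunk, b, result):
    """Data-parallel worker: B is replicated, so one local matmul suffices."""
    _gather_rows(comm, rank, a_chunk @ b, result)


def fsdp_compute_fn(comm, rank, a_shard, b_shard, col_bounds, result):
    """Sharded worker: pass B shards around a ring for num_devices rounds."""
    p = comm.num_devices
    prev, nxt = (rank - 1) % p, (rank + 1) % p
    out = np.zeros((a_shard.shape[0], col_bounds[-1]),
                   dtype=np.result_type(a_shard, b_shard))
    pending, shard = [], b_shard
    for step in range(p):
        idx = (rank - step) % p  # original index of the shard held this round
        out[:, col_bounds[idx]:col_bounds[idx + 1]] = a_shard @ shard
        if step == p - 1:
            break
        comm.send(rank, nxt, shard)
        while True:
            src, data = comm.recv(rank)
            if src == prev:  # FIFO inbox: prev's next message is its shard
                shard = data
                break
            pending.append((src, data))  # early gather message (device 0 only)
    _gather_rows(comm, rank, out, result, pending)


def _run(num_devices, target, args_for_rank):
    """Start one thread per simulated device and return device 0's result."""
    comm, result = Communicator(num_devices), []
    threads = [threading.Thread(target=target,
                                args=(comm, r, *args_for_rank(r), result))
               for r in range(num_devices)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return result[0]


def dp_mat_mul(a, b, num_devices):
    a, b = np.asarray(a), np.asarray(b)
    a_chunks = np.array_split(a, num_devices, axis=0)
    return _run(num_devices, compute_fn, lambda r: (a_chunks[r], b))


def fsdp_mat_mul(a, b, num_devices):
    a, b = np.asarray(a), np.asarray(b)
    a_shards = np.array_split(a, num_devices, axis=0)
    b_shards = np.array_split(b, num_devices, axis=1)
    col_bounds = np.cumsum([0] + [s.shape[1] for s in b_shards])
    return _run(num_devices, fsdp_compute_fn,
                lambda r: (a_shards[r], b_shards[r], col_bounds))
```

The shard index is derived from the ring schedule (device `rank` holds shard `(rank - step) % p` in round `step`) rather than sent alongside the array, since `send` is described as carrying only a NumPy array. Unbounded queues mean `send` never blocks, which rules out ring deadlocks in this simulation.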


Part 1: Data-Parallel Matrix Multiplication

You are given two compatible matrices A and B and an integer num_devices. Simulate a data-parallel matrix multiplication strategy for C = A x B. Split A row-wise into num_devices chunks using numpy.array_split semantics: if M = len(A), every chunk gets M // num_devices rows, the first M % num_devices chunks get one extra row, and some chunks are empty when num_devices > M. Each simulated device receives one row chunk of A and the full matrix B and computes its local product; the final answer is the row-wise concatenation of the local outputs in device order. For this standalone problem, you only need to reproduce the same partitioning and gathering behavior; real threads and message queues are not required.
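A minimal sketch of the standalone version, assuming list-of-lists input and output as in the examples: it partitions with `numpy.array_split`, multiplies each chunk by the full B, and concatenates in device order.

```python
import numpy as np


def dp_mat_mul(a, b, num_devices):
    """Standalone simulation: partition rows, multiply locally, gather in order."""
    a, b = np.asarray(a), np.asarray(b)
    row_chunks = np.array_split(a, num_devices, axis=0)   # may include empty chunks
    local_outputs = [chunk @ b for chunk in row_chunks]    # one product per device
    return np.concatenate(local_outputs, axis=0).tolist()  # gather in device order
```

An empty chunk of shape `(0, K)` multiplied by B yields a `(0, N)` block, so the edge case with more devices than rows needs no special handling.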

Constraints

  • 1 <= M, K, N <= 100
  • 1 <= num_devices <= 100
  • a is an M x K rectangular integer matrix
  • b is a K x N rectangular integer matrix
  • The row split of A must follow numpy.array_split semantics

Examples

Input: ([[1, 2], [3, 4], [5, 6], [7, 8]], [[1, 0, 2], [0, 1, 3]], 2)

Expected Output: [[1, 2, 8], [3, 4, 18], [5, 6, 28], [7, 8, 38]]

Explanation: A is split into two row chunks: [[1, 2], [3, 4]] and [[5, 6], [7, 8]]. Each chunk is multiplied by the full B, then the results are concatenated.

Input: ([[2, 0, 1], [1, 3, 2], [4, 1, 0]], [[1, 2], [0, 1], [3, 4]], 2)

Expected Output: [[5, 8], [7, 13], [4, 9]]

Explanation: The row split is uneven: the first device gets 2 rows and the second gets 1 row.

Input: ([[2, 3]], [[4], [5]], 4)

Expected Output: [[23]]

Explanation: Edge case: there are more devices than rows, so several devices get empty chunks.

Input: ([[1, 2, 3], [4, 5, 6]], [[7], [8], [9]], 1)

Expected Output: [[50], [122]]

Explanation: Edge case: with one device, the algorithm reduces to ordinary matrix multiplication.

Hints

  1. Compute the row-chunk sizes with a quotient and remainder (divmod(M, num_devices)): every chunk gets the quotient, and the first `remainder` chunks get one extra row.
  2. After computing each chunk's product, gather the chunk results in the same device order used during the split.
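Both hints can be sketched without NumPy. `split_bounds` and `dp_mat_mul_pure` are hypothetical helper names; the sizing rule is the one `numpy.array_split` documents, where the first `length % num_parts` sections hold one extra element.

```python
def split_bounds(length, num_parts):
    """Start/stop offsets matching numpy.array_split sizes (first `rem` parts +1)."""
    base, rem = divmod(length, num_parts)
    bounds = [0]
    for i in range(num_parts):
        bounds.append(bounds[-1] + base + (1 if i < rem else 0))
    return bounds


def dp_mat_mul_pure(a, b, num_devices):
    """Slice rows per device, multiply each chunk by full B, gather in device order."""
    bounds = split_bounds(len(a), num_devices)
    cols = list(zip(*b))  # columns of B, reused by every device
    out = []
    for dev in range(num_devices):
        for row in a[bounds[dev]:bounds[dev + 1]]:  # empty range -> empty chunk
            out.append([sum(x * y for x, y in zip(row, col)) for col in cols])
    return out
```

For example, `split_bounds(3, 2)` gives `[0, 2, 3]` (chunks of 2 and 1 rows) and `split_bounds(1, 4)` gives `[0, 1, 1, 1, 1]` (one row, then three empty chunks).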

Part 2: Fully Sharded Matrix Multiplication

You are given two compatible matrices A and B and an integer num_devices. Simulate a fully sharded matrix multiplication strategy for C = A x B. Split A row-wise into num_devices shards using numpy.array_split semantics. Split B column-wise into num_devices shards using the same rule on columns. Device i initially owns row shard A_i and column shard B_i. Now simulate num_devices rounds of rotation. In each round, every device multiplies its fixed row shard of A with the B shard it currently holds, writes that partial product into the correct global column positions, and then passes its current B shard to the next device in a circular fashion. After all rounds, gather the row shards in device order and return the full matrix. For this standalone problem, you only need to simulate the sharding and rotation logic. Real threads or message queues are not required.
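One way to simulate the rotation without threads is sketched below, assuming list-of-lists input and output. Each device holds an `(original_index, shard)` pair so it knows which global columns to fill, and each round ends with every device handing its pair to the next device on the ring.

```python
import numpy as np


def fsdp_mat_mul(a, b, num_devices):
    """Standalone simulation of ring-rotated B shards (no threads or queues)."""
    a, b = np.asarray(a), np.asarray(b)
    a_shards = np.array_split(a, num_devices, axis=0)
    b_shards = np.array_split(b, num_devices, axis=1)
    starts = np.cumsum([0] + [s.shape[1] for s in b_shards])  # global column offsets
    dtype = np.result_type(a, b)
    # Full-width local result block per device; rows follow that device's A shard.
    local = [np.zeros((s.shape[0], b.shape[1]), dtype=dtype) for s in a_shards]
    held = [(i, b_shards[i]) for i in range(num_devices)]  # (original index, shard)

    for _ in range(num_devices):
        for dev in range(num_devices):
            idx, shard = held[dev]
            local[dev][:, starts[idx]:starts[idx + 1]] = a_shards[dev] @ shard
        # Every device passes its current shard to the next device on the ring.
        held = [held[(dev - 1) % num_devices] for dev in range(num_devices)]

    return np.concatenate(local, axis=0).tolist()
```

The final rotation after the last round is redundant but harmless; it is kept so each round follows the "multiply, then pass" order described above.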

Constraints

  • 1 <= M, K, N <= 100
  • 1 <= num_devices <= 100
  • a is an M x K rectangular integer matrix
  • b is a K x N rectangular integer matrix
  • The row split of A and the column split of B must follow numpy.array_split semantics
  • Some row shards or column shards may be empty when num_devices is larger than M or N

Examples

Input: ([[1, 2], [3, 4]], [[1, 2, 3, 4], [5, 6, 7, 8]], 2)

Expected Output: [[11, 14, 17, 20], [23, 30, 37, 44]]

Explanation: B is split into two column shards: the first two columns and the last two columns. After one rotation, each row shard has seen both column shards.

Input: ([[1, 0], [2, 1], [0, 3]], [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], 3)

Expected Output: [[1, 2, 3, 4, 5], [8, 11, 14, 17, 20], [18, 21, 24, 27, 30]]

Explanation: The column shards are uneven: 5 columns split across 3 devices give shard sizes 2, 2, and 1.

Input: ([[2, 1, 0]], [[1, 2], [3, 4], [5, 6]], 4)

Expected Output: [[5, 8]]

Explanation: Edge case: there are more devices than rows and columns, so several shards are empty but the final answer is still correct.

Input: ([[2, 3], [4, 5]], [[6], [7]], 1)

Expected Output: [[33], [59]]

Explanation: Edge case: with one device, the sharded algorithm becomes standard matrix multiplication.

Hints

  1. Each rotating B shard needs to carry its original column-shard index so you know where to place the partial product.
  2. Let each device build a full-width local result block and fill only the columns covered by the shard it currently holds.
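The index that hint 1 asks each shard to carry can also be checked against the ring schedule: if every device passes its shard to device `(dev + 1) % num_devices` after each round, device `dev` holds original shard `(dev - r) % num_devices` in round `r`, so each device sees every shard exactly once. `rotation_schedule` is a hypothetical helper that makes this explicit.

```python
def rotation_schedule(num_devices):
    """schedule[r][dev] = original index of the B shard device `dev` holds in round r."""
    held = list(range(num_devices))  # round 0: device i holds shard i
    schedule = []
    for _ in range(num_devices):
        schedule.append(held[:])
        # Device dev receives the shard that device dev - 1 held last round.
        held = [held[(dev - 1) % num_devices] for dev in range(num_devices)]
    return schedule
```

For example, `rotation_schedule(3)` returns `[[0, 1, 2], [2, 0, 1], [1, 2, 0]]`: reading down any column lists the shards one device visits, each exactly once.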
Last updated: Apr 22, 2026

© 2026 PracHub. All rights reserved.

Related Coding Questions

  • Flatten and unflatten nested Python structures - xAI
  • Compute dasher pay from order events - xAI
  • Compute total active time per Twitter Space - xAI (medium)
  • Design a Recoverable Iterator - xAI (medium)
  • Find kth element and sliding-window kth in stream - xAI (hard)