Distributed GPU Computation And Parallel ML
Asked of: Machine Learning Engineer
What's being tested
Candidates must show practical mastery of distributing dense linear algebra across multiple accelerators: partitioning matrices, minimizing inter-GPU communication, and overlapping compute with transfers. Interviewers probe the candidate’s ability to trade compute vs communication, pick appropriate collective primitives, and reason about memory, numerical precision, and throughput bottlenecks in an ML training-style workload. Google cares because inefficient parallelization destroys wall-clock performance and resource utilization on large GPU clusters used for model training and inference kernels.
Core knowledge
-
GPU memory hierarchy and interconnects: device DRAM, pinned memory, PCIe, NVLink/NVSwitch, and host NICs; bandwidth and latency differences determine communication strategy and whether to use host staging or GPUDirect.
-
Matrix multiplication cost: for multiplying two n×n matrices, compute ≈ 2n^3 FLOPs; communication lower bounds matter when partitioning—minimize data movement relative to FLOPs.
-
Latency-bandwidth model: model the cost of sending an m-byte message as T(m) = α + β·m, where α is per-message latency and β is inverse bandwidth (seconds per byte); many small messages are α-dominated, while large bulk transfers are β-dominated.
-
Partitioning families: 1D row/column block (easy, high communication), 2D block-cyclic / SUMMA (balance compute and communication), and Cannon’s algorithm (works well on torus with neighbor sends). Know tradeoffs in replication, memory, and message count.
-
Collectives and primitives: use AllReduce for reductions, AllGather for assembling blocks, and Broadcast for distributing tiles; prefer the optimized primitives in NCCL/MPI to implement these with minimal overhead.
-
Overlap strategy: use CUDA streams, cudaMemcpyAsync, and asynchronous collectives to overlap compute kernels with transfers; schedule a pipeline of tiles to hide α latency.
-
Memory and tiling: choose tile size to fit L2 / shared memory and saturate tensor cores; tune block size to get high arithmetic intensity and avoid exceeding per-GPU memory limits.
-
Precision and stability: use mixed precision (FP16 inputs with FP32 accumulation) and bfloat16 where supported; guard against accumulated rounding error by keeping accumulators in higher precision.
-
Kernel libraries: leverage optimized vendor libraries such as cuBLAS/cuBLASLt for per-GPU GEMM, and fall back to custom kernels only when communication patterns require special tiling.
-
Scalability tipping points: as GPU count grows, especially once a job spans multiple nodes, network topology and collectives become dominant; prefer hierarchical schemes (intra-node NVLink + inter-node RDMA/NCCL) and reduce synchronization frequency (gradient accumulation).
-
Profiling and metrics: measure p99 transfer times, device utilization, achieved TFLOPS, PCIe vs NVLink utilization (with nvprof or its successors, Nsight Systems and Nsight Compute), and the overlap efficiency of collectives to find hotspots.
-
Practical constraints: account for batch size when mapping to GEMM shapes (rectangular matrices), support batched-GEMM for many small multiplies, and consider operator fusion to reduce memory traffic.
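The cost terms in the list above can be combined into a rough back-of-envelope estimate. The sketch below is a simplified model, not a measurement. It assumes the α-β message cost T(m) = α + β·m, a ring AllGather of B for 1D partitioning, and one received A block plus one received B block per SUMMA step on a square grid, ignoring broadcast tree depth and any overlap. The function names and the sample values for α, β, and per-GPU throughput are hypothetical.

```python
import math

def msg_cost(alpha, beta, nbytes):
    """Alpha-beta model: seconds to deliver one message of nbytes."""
    return alpha + beta * nbytes

def comm_1d(n, p, elem_bytes, alpha, beta):
    """1D row partition: ring AllGather of B, P-1 steps of n*n/P elements each."""
    step_bytes = n * n // p * elem_bytes
    return (p - 1) * msg_cost(alpha, beta, step_bytes)

def comm_2d(n, p, elem_bytes, alpha, beta):
    """SUMMA on a sqrt(P) x sqrt(P) grid: each of sqrt(P) steps delivers
    one A block and one B block to every GPU."""
    q = math.isqrt(p)
    assert q * q == p, "this sketch assumes a square process grid"
    block_bytes = (n // q) ** 2 * elem_bytes
    return q * 2 * msg_cost(alpha, beta, block_bytes)

def gemm_time(n, p, flops_per_gpu):
    """2*n^3 FLOPs split evenly over P GPUs."""
    return 2 * n**3 / p / flops_per_gpu

# Hypothetical inputs: 5 us latency, 100 GB/s links, 100 TFLOP/s per GPU, FP16 elements.
ALPHA, BETA, FLOPS, ELEM = 5e-6, 1 / 100e9, 100e12, 2

if __name__ == "__main__":
    n = 16384
    for p in (4, 16, 64):
        print(p,
              f"1D={comm_1d(n, p, ELEM, ALPHA, BETA):.2e}s",
              f"2D={comm_2d(n, p, ELEM, ALPHA, BETA):.2e}s",
              f"compute={gemm_time(n, p, FLOPS):.2e}s")
```

Under these assumptions the 2D layout only wins once P is large enough that 2/√P drops below (P-1)/P, which is why the partitioning choice depends on GPU count.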
Worked example — Design multi-GPU matrix multiplication
First 30s: ask about matrix sizes (square or rectangular), GPU count P and topology (intra-node NVLink? multi-node with RDMA?), memory per device, and whether the output must be replicated on every GPU or can stay sharded.
Skeleton: (1) choose the partitioning (1D vs 2D) based on P and per-GPU memory; (2) pick the communication pattern (SUMMA with AllGather/Broadcast, or Cannon with neighbor sends); (3) implement per-tile compute with cuBLAS and overlap it with cudaMemcpyAsync and asynchronous collectives; (4) tune tile size and precision.
A common tradeoff: 2D blocking (SUMMA) lowers total communication per GPU but requires a synchronized broadcast per panel. If α is high, prefer larger tiles and fewer rounds, or use neighbor-based Cannon to reduce message count.
Close by proposing measurements: profile with nvprof or Nsight Systems, report achieved TFLOPS against peak, then iterate on tile size and collective library choice (NCCL vs MPI), and consider a hierarchical reduce for multi-node runs. If more time: add fault tolerance for node failures and extend the design to pipelined model-parallel layers.
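To make the SUMMA data flow concrete, here is a minimal single-process simulation in plain Python. It is a sketch under stated assumptions: a q × q grid where rank (i, j) owns block (i, j) of A, B, and C, with dictionary lookups standing in for the row and column broadcasts that NCCL or MPI would perform. All function names are hypothetical, and nothing here runs on a GPU.

```python
def matmul(a, b):
    """Plain triple-loop product of two list-of-lists matrices."""
    rows, inner, cols = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(cols)]
            for i in range(rows)]

def add_inplace(c, d):
    """c += d, elementwise."""
    for row_c, row_d in zip(c, d):
        for j, v in enumerate(row_d):
            row_c[j] += v

def block(m, bi, bj, bs):
    """Copy the bs x bs block at block coordinates (bi, bj)."""
    return [row[bj * bs:(bj + 1) * bs] for row in m[bi * bs:(bi + 1) * bs]]

def summa(a, b, q):
    """Simulate SUMMA for square n x n inputs on a q x q grid of ranks."""
    n = len(a)
    assert n % q == 0, "this sketch assumes q divides n"
    bs = n // q
    ranks = [(i, j) for i in range(q) for j in range(q)]
    local_a = {r: block(a, r[0], r[1], bs) for r in ranks}
    local_b = {r: block(b, r[0], r[1], bs) for r in ranks}
    local_c = {r: [[0] * bs for _ in range(bs)] for r in ranks}
    for k in range(q):                      # one round per block column of A
        for (i, j) in ranks:
            a_panel = local_a[(i, k)]       # stands in for Broadcast along grid row i
            b_panel = local_b[(k, j)]       # stands in for Broadcast along grid column j
            add_inplace(local_c[(i, j)], matmul(a_panel, b_panel))
    # Reassemble the sharded result only to check it; a real job could leave C sharded.
    c = [[0] * n for _ in range(n)]
    for (i, j), blk in local_c.items():
        for r in range(bs):
            c[i * bs + r][j * bs:(j + 1) * bs] = blk[r]
    return c

if __name__ == "__main__":
    n = 4
    a = [[i * n + j for j in range(n)] for i in range(n)]
    b = [[(i + 2 * j) % 5 for j in range(n)] for i in range(n)]
    print(summa(a, b, 2) == matmul(a, b))
```

In a real implementation each inner `matmul` would be a cuBLAS call, and the broadcast for round k+1 would be issued on a separate stream while round k computes.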
A second angle
Consider many small independent multiplies (batched GEMM) instead of one giant matrix multiply: the communication/computation balance flips because kernel launch overhead and small-matrix inefficiency dominate. The strategy changes to aggregating many small multiplies into larger batched kernels (use cuBLAS batched routines such as cublasGemmBatchedEx or cublasGemmStridedBatchedEx), preferring 1D partitioning to avoid many tiny messages, and using kernel fusion to reduce host-device round trips. When shapes are highly rectangular (e.g., tall-and-skinny), use panel-oriented algorithms and reduce the frequency of global synchronization, for example by accumulating partial results locally and finalizing with a single AllReduce.
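The "accumulate locally, then one AllReduce" idea for tall-and-skinny shapes can be sketched as follows. This assumes the product has the form C = AᵀB, with A and B both tall (many rows, few columns) and their rows split across ranks. The `allreduce_sum` function is a hypothetical in-process stand-in for a real AllReduce, and all names are illustrative.

```python
def local_product(a_rows, b_rows, m, n):
    """Partial A_p^T @ B_p (m x n) from one rank's slice of rows."""
    out = [[0] * n for _ in range(m)]
    for a_row, b_row in zip(a_rows, b_rows):
        for i in range(m):
            for j in range(n):
                out[i][j] += a_row[i] * b_row[j]
    return out

def allreduce_sum(partials):
    """Stand-in for AllReduce: every rank receives the elementwise sum."""
    m, n = len(partials[0]), len(partials[0][0])
    total = [[sum(p[i][j] for p in partials) for j in range(n)] for i in range(m)]
    return [[row[:] for row in total] for _ in partials]

def tall_skinny_product(a, b, num_ranks):
    """Split the long dimension across ranks, reduce once at the end."""
    m, n = len(a[0]), len(b[0])
    rows_per_rank = (len(a) + num_ranks - 1) // num_ranks
    partials = []
    for r in range(num_ranks):
        lo, hi = r * rows_per_rank, (r + 1) * rows_per_rank
        # No communication inside this loop: each rank only touches its own rows.
        partials.append(local_product(a[lo:hi], b[lo:hi], m, n))
    return allreduce_sum(partials)  # the single global synchronization point

if __name__ == "__main__":
    a = [[k, k + 1] for k in range(10)]          # 10 x 2
    b = [[1, k % 3, 2 * k] for k in range(10)]   # 10 x 3
    per_rank = tall_skinny_product(a, b, 4)
    print(per_rank[0] == local_product(a, b, 2, 3))
```

The message size of the reduction is only m × n, independent of the long dimension, which is why a single AllReduce is cheap for these shapes.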
Common pitfalls
Pitfall: Underestimating communication latency — designing many small broadcasts yields α-dominated cost; instead aggregate tiles or increase tile size to amortize latency. Message count matters as much as bytes; prefer fewer larger transfers.
Pitfall: Ignoring memory replication cost — naive 1D replication reduces messages but may exceed per-GPU memory; quantify memory per tile and enforce strict memory budget in your design.
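A quick way to quantify the memory point is to compute the per-GPU footprint of each layout before committing to one. The sketch below assumes square n × n operands, a 1D layout that keeps a full copy of B on every GPU, and a 2D layout that holds one block each of A, B, and C plus double-buffered receive space for incoming A and B panels. The budget value and function names are hypothetical.

```python
import math

GIB = 2**30

def mem_1d_replicated(n, p, elem_bytes):
    """1D row partition: n/P rows of A and of C, plus a full replica of B."""
    return (2 * n * n // p + n * n) * elem_bytes

def mem_2d(n, p, elem_bytes, buffers_per_operand=2):
    """2D block partition: blocks of A, B, C plus receive buffers for A and B panels."""
    q = math.isqrt(p)
    assert q * q == p, "this sketch assumes a square process grid"
    block_elems = (n // q) ** 2
    return (3 + 2 * buffers_per_operand) * block_elems * elem_bytes

def fits(nbytes, budget_bytes):
    """True if the layout stays within the per-GPU memory budget."""
    return nbytes <= budget_bytes

if __name__ == "__main__":
    n, p, elem = 131072, 16, 2          # FP16 elements, 16 GPUs
    budget = 24 * GIB                   # hypothetical per-GPU budget
    for name, used in (("1D replicated", mem_1d_replicated(n, p, elem)),
                       ("2D blocks", mem_2d(n, p, elem))):
        print(f"{name}: {used / GIB:.1f} GiB, fits={fits(used, budget)}")
```

Stating numbers like these up front shows the interviewer that the layout was chosen against a budget, not by habit.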
Pitfall: Treating mixed precision as free — using FP16 without FP32 accumulation or loss scaling risks training divergence; explicitly state accumulation precision and scaling strategy when proposing precision optimizations.
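The accumulation-precision risk is easy to demonstrate without a GPU. The sketch below emulates FP16 and FP32 rounding with Python's `struct` module (format codes `e` and `f`) and sums the same FP16 inputs with each accumulator type. It is an illustration of rounding behavior under these assumptions, not a model of any particular tensor-core pipeline, and the helper names are hypothetical.

```python
import math
import struct

def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]

def to_fp32(x):
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def accumulate(values, round_fn):
    """Running sum where the accumulator is rounded after every add."""
    acc = 0.0
    for v in values:
        acc = round_fn(acc + v)
    return acc

# Inputs are stored in FP16 in both cases; only the accumulator type differs.
values = [to_fp16(0.01)] * 10_000
reference = math.fsum(values)
acc16 = accumulate(values, to_fp16)
acc32 = accumulate(values, to_fp32)

if __name__ == "__main__":
    print(f"reference={reference:.4f} fp16_acc={acc16:.4f} fp32_acc={acc32:.4f}")
```

The FP16 accumulator stalls once the spacing between representable values exceeds twice the addend, so further terms are rounded away, while the FP32 accumulator stays close to the reference.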
Connections
This topic commonly pivots to distributed training (data vs model parallelism) and communication-avoiding algorithms for linear algebra. Interviewers may also ask about profiling/perf tooling (nvprof/nsight) and hierarchical collectives (intra-node NVLink then inter-node RDMA with NCCL).
Further reading
-
Communication-Avoiding Algorithms for Linear Algebra (Demmel et al.): foundational analysis of communication lower bounds and algorithms like SUMMA/Cannon.
-
NVIDIA Collective Communications Library (NCCL) Guide: practical primitives and best practices for multi-GPU collectives.
-
cuBLAS and cuBLASLt Documentation: tuned GEMM implementations and mixed-precision guidance.