LLM Inference Optimization And KV Cache
Asked of: ML Engineer

What's being tested
Candidates must demonstrate practical knowledge of how Transformer inference is implemented and optimized for production: how an LLM maintains and uses cached attention states to decode tokens efficiently, how that cache affects memory/latency tradeoffs, and how to design around long input contexts (retrieval, chunking, offload). Interviewers probe ability to quantify cost, choose engineering patterns (sharding, batching, dtype), and diagnose correctness/performance bugs under realistic constraints.
Core knowledge
- KV cache purpose: store per-layer keys and values so each autoregressive step reuses earlier attention projections instead of recomputing them. For a single new token, attention cost drops from O(L^2) (recomputing the whole sequence) to O(L) (one query against L cached keys); this is essential for low-latency streaming generation.
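- Memory formula: for model dim D, L cached tokens, and N layers in float32, KV bytes per sequence ≈ 2 * N * L * D * sizeof(float), assuming standard multi-head attention where the K/V width equals D; multiply by batch size for concurrent sequences. FP16 halves this; INT8 quantization brings it to roughly a quarter of the FP32 figure.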
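- Per-step compute cost with cache: per new token, attention cost per layer is O(L * D), where L is the current context length. Total attention cost to decode T tokens is O(N * D * L_avg * T), where L_avg is the mean context length during decoding (about P + T/2 for a prompt of P tokens); the constant depends on batching and prefix sharing.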
- Positional embedding interactions: absolute positions require the correct position index when appending tokens; with RoPE (rotary embeddings) and relative schemes, check how positions are incremented or offset when truncating or offloading the cache.
- Layout and dtype matter: choose a contiguous per-layer memory layout (e.g., [N, L, D] or [L, N, D]) to enable fused kernels; using FP16 or INT8 reduces memory but demands numerical checks for underflow/NaN.
- Incremental attention primitives: some fast kernels (e.g., flash attention) accelerate batched full attention but may not support easy incremental caching; validate support for single-step incremental decode or use a specialized incremental implementation.
- Sharding strategies: tensor parallelism shards weights within each layer, pipeline parallelism assigns whole layers to different devices, and KV sharding splits cached states across devices to stay within GPU memory. The choice of shard determines cross-device latency and communication patterns.
- Offload patterns: keep recent tokens' KV on GPU and older parts on CPU/NVMe with async prefetch; bound memory by evicting via simple truncation or by summarizing/compacting older context (embedding compression).
- Beam search and cache duplication: beams with a shared prefix can share KV; divergent beams require separate copies. Plan memory reservations accordingly and reuse where possible to reduce duplication.
- Correctness checks: compare logits from cached and uncached runs on the same input, within a tolerance suited to the dtype; assert KV shapes, position indices, and dtype; checksum or hash per-layer KV to detect corruption.
- Common optimizations: fused attention+softmax kernels, batched decoding for many concurrent requests, micro-batching of tokens, pinned host memory and zero-copy transfers between CPU and GPU (cudaHostAlloc), and model conversion to TensorRT or ONNX for optimized kernels.
- Tradeoff quantification: doubling context length L doubles KV memory and per-step attention work; beyond a certain L, prefer retrieval/RAG or sparse attention models (e.g., Longformer) to keep p99 latency bounded.
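The memory and decode-cost formulas in this list can be sanity-checked with plain arithmetic. The following is a minimal sketch, assuming standard multi-head attention (K/V width equal to D) and purely illustrative sizes (N=48, L=2048, D=12288); the function names kv_cache_bytes and attended_positions are hypothetical helpers, not part of any library.

```python
GIB = 1024 ** 3


def kv_cache_bytes(n_layers, n_tokens, d_model, bytes_per_elem, batch_size=1):
    """Bytes needed to cache keys and values.

    Assumes standard multi-head attention, so the K (or V) width summed
    over heads equals d_model. The leading 2 counts keys plus values.
    """
    return 2 * n_layers * n_tokens * d_model * bytes_per_elem * batch_size


def attended_positions(prompt_len, new_tokens):
    """Key positions read per layer while decoding `new_tokens` tokens.

    Decode step t (0-based) feeds one token that attends to the prompt,
    the t tokens fed before it, and itself: prompt_len + t + 1 positions.
    """
    return sum(prompt_len + t + 1 for t in range(new_tokens))


if __name__ == "__main__":
    # Illustrative sizes only: N=48 layers, L=2048 tokens, D=12288.
    for name, size in [("fp32", 4), ("fp16", 2), ("int8", 1)]:
        gib = kv_cache_bytes(48, 2048, 12288, size) / GIB
        print(f"{name}: {gib:.2f} GiB per sequence")

    prompt, new = 2048, 256
    total = attended_positions(prompt, new)
    print(f"mean context while decoding: {total / new:.1f} tokens")
```

With these sizes the cache comes to 9.00 GiB in FP32, 4.50 GiB in FP16, and 2.25 GiB in INT8 per sequence, before any batch-size multiplier. Models that share K/V heads across query heads would need the K/V width substituted for d_model.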
Worked example — Explain KV cache in Transformer inference
Start by clarifying: "Do you expect single-token incremental decoding (streaming), and which decoding strategy: sampling or beam search?" Then outline three pillars: (1) describe what the KV cache stores and why it reduces recomputation, (2) give the memory and compute formulas and a numeric example for a typical model (e.g., N=48, D=12288, L=2048, which comes to 2 * 48 * 2048 * 12288 * 2 bytes ≈ 4.5 GiB per sequence in FP16), and (3) list implementation choices (dtype, layout, device placement, and sharding). Show the tradeoffs: using FP16 halves memory but requires numeric validation; sharding KV across GPUs reduces per-GPU memory but adds all-reduce/communication overhead that increases latency. Flag edge cases: positional embeddings (RoPE vs absolute) and beam reuse semantics. Close with testing/diagnostics: "I would implement per-layer checksums and a small unit test comparing logits with and without cache; with more time, I would instrument p50/p99 latency under different batch sizes and add a CPU offload path with async prefetch."
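The cached-versus-uncached unit test mentioned in the closing can be sketched on a toy scale. This is a minimal illustration, assuming a single layer and a single attention head in pure Python, and it compares attention outputs where a real test would compare logits; every name here (KVCache, uncached_last_output, check_equivalence) is hypothetical.

```python
import math
import random


def matvec(w, x):
    """Multiply matrix w (a list of rows) by vector x."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def attend(q, keys, values):
    """Softmax attention of one query over the given keys and values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(dim)]


def uncached_last_output(xs, wq, wk, wv):
    """Recompute K and V for the whole sequence; return the last position."""
    keys = [matvec(wk, x) for x in xs]
    values = [matvec(wv, x) for x in xs]
    return attend(matvec(wq, xs[-1]), keys, values)


class KVCache:
    """Append-only cache of keys and values for one layer and one head."""

    def __init__(self):
        self.keys = []
        self.values = []

    def step(self, x, wq, wk, wv):
        """Project only the new token, append its K/V, attend over the cache."""
        self.keys.append(matvec(wk, x))
        self.values.append(matvec(wv, x))
        return attend(matvec(wq, x), self.keys, self.values)


def check_equivalence(seq_len=6, dim=4, seed=0):
    """Return the largest cached-vs-uncached difference over all steps."""
    rng = random.Random(seed)

    def rand_mat():
        return [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(dim)]

    wq, wk, wv = rand_mat(), rand_mat(), rand_mat()
    xs = [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(seq_len)]
    cache = KVCache()
    worst = 0.0
    for t, x in enumerate(xs):
        cached = cache.step(x, wq, wk, wv)
        full = uncached_last_output(xs[: t + 1], wq, wk, wv)
        diff = max(abs(a - b) for a, b in zip(cached, full))
        worst = max(worst, diff)
    assert len(cache.keys) == seq_len  # shape check: one K/V entry per token
    return worst
```

No causal mask is needed here because only the last position's output is compared, and that position may attend to everything before it. In a real model the tolerance should be loosened to match the dtype, since FP16 kernels will not agree to the last bit.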
A second angle — Design LLM search handling long token inputs
The same KV/cache and offload principles apply, but the constraints shift: search must ingest long documents efficiently without exceeding context limits. The pillars become (1) chunking and retrieval, where only the top-K relevant chunks are prepended to the prompt, (2) summaries or compressed representations of older text to avoid unbounded KV growth, and (3) KV placement: store only the active prefix on GPU, and keep chunk embeddings and infrequently used KV on CPU/NVMe with async fetch. A key design decision is whether to extend context via a single long prompt (large KV) or iterative RAG (smaller KV, repeated retrieval). Tradeoffs: a longer single context means fewer retrieval round trips but much higher memory and p99 latency; iterative retrieval reduces memory but requires careful prompt engineering and incremental state management.
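The chunking and top-K retrieval pillar can be sketched as follows. This is a minimal illustration, assuming chunk embeddings are already available as plain lists of floats and that relevance is cosine similarity; the names chunk_tokens and select_chunks, the dictionary keys, and the greedy budget rule are all assumptions made for the example.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def chunk_tokens(tokens, chunk_size, overlap):
    """Split a token list into fixed-size chunks overlapping by `overlap`."""
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must be in [0, chunk_size)")
    if not tokens:
        return []
    step = chunk_size - overlap
    chunks = []
    start = 0
    while True:
        chunks.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # the last chunk already reaches the end of the input
        start += step
    return chunks


def select_chunks(query_vec, chunks, top_k, token_budget):
    """Pick up to top_k chunks by similarity without exceeding the budget.

    `chunks` is a list of dicts with keys "id", "embedding", "n_tokens".
    Returns (ids in rank order, tokens used).
    """
    ranked = sorted(chunks,
                    key=lambda c: cosine(query_vec, c["embedding"]),
                    reverse=True)
    picked, used = [], 0
    for chunk in ranked:
        if len(picked) == top_k:
            break
        if used + chunk["n_tokens"] > token_budget:
            continue  # skip chunks that would overflow the prompt budget
        picked.append(chunk["id"])
        used += chunk["n_tokens"]
    return picked, used
```

The token budget is what bounds KV growth: because the prompt never exceeds it, the cache size stays predictable no matter how long the source documents are.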
Common pitfalls
Pitfall: Underestimating KV memory by forgetting the factor of two for keys + values or confusing D (model dim) with per-head dim, leading to OOMs in prod. Always compute bytes = 2 * N * L * D * sizeof(dtype).
Pitfall: Claiming "use flash attention" as a blanket win without noting many flash-attention implementations don't support efficient single-step incremental decoding—this breaks streaming or requires different kernel logic.
Pitfall: Reusing KV across beams without handling divergence. It is a tempting optimization, but it yields deterministic yet incorrect outputs; instead, detect prefix divergence and duplicate only when necessary.
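One way to share a prefix across beams while still handling divergence is copy-on-write: beams point at the same storage until one of them appends. The sketch below is a host-side toy under that assumption, with strings standing in for per-token K/V tensors; the BeamKV class and its methods are hypothetical, and a production server would more likely manage shared GPU buffers than Python lists.

```python
class BeamKV:
    """KV entries for one beam, shared with siblings until a write occurs."""

    def __init__(self, entries=None, owners=None):
        self._entries = entries if entries is not None else []
        # One-element list used as a counter shared by all beams that
        # reference the same `_entries` storage.
        self._owners = owners if owners is not None else [1]

    def fork(self):
        """Create a sibling beam that shares this beam's storage."""
        self._owners[0] += 1
        return BeamKV(self._entries, self._owners)

    def append(self, kv):
        """Add one token's K/V, copying first if the storage is shared."""
        if self._owners[0] > 1:
            self._owners[0] -= 1           # leave the shared storage
            self._entries = list(self._entries)
            self._owners = [1]             # private storage from now on
        self._entries.append(kv)

    def entries(self):
        return tuple(self._entries)

    def shares_storage_with(self, other):
        return self._entries is other._entries


if __name__ == "__main__":
    root = BeamKV()
    root.append("kv(t0)")
    root.append("kv(t1)")
    left = root.fork()                     # no copy yet
    left.append("kv(left)")                # diverges: copies, then appends
    root.append("kv(right)")               # sole owner again: appends in place
    print(root.entries())
    print(left.entries())
```

Without the copy in append, both beams would write into the same list and each would attend to the other's tokens, which is exactly the deterministic-but-wrong behaviour this pitfall describes.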
Connections
Interviewers may pivot to model parallelism (how KV sharding interacts with tensor/pipeline parallelism), or to alternative attention architectures (sparse/linear attention, Longformer) for long contexts. They may also ask about serving-level metrics like p99 latency and throughput tradeoffs between batching and tail latency.
Further reading
- FlashAttention: Fast and Memory-Efficient Exact Attention (explains fast attention kernels and tradeoffs for incremental decoding).
- RoFormer / Rotary Positional Embeddings (describes position encoding behavior important for KV index management).
Practice questions
- Explain KV cache in Transformer inference (OpenAI · Machine Learning Engineer · Onsite · medium)
- Design and optimize a RAG system (OpenAI · Machine Learning Engineer · Onsite · hard)
- Diagnose Transformer training and inference bugs (OpenAI · Machine Learning Engineer · Technical Screen · hard)
- Design LLM search handling long token inputs (OpenAI · Machine Learning Engineer · Onsite · hard)