Design Task: Key–Value Cache for Transformer Decoder Inference
Context
You are building an autoregressive inference engine for a decoder-only Transformer model. To avoid recomputing self-attention over the full prefix at each decoding step, implement per-layer key–value (K/V) caching.
Assume a standard multi-head self-attention decoder with:
- Batch size B (or effective batch B_eff when using beams)
- Model dimension d_model, heads H, head dimension d_k = d_model / H
- Max generation length L_max
- Mixed precision (float16/bfloat16) support
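For orientation, the dominant memory term is the per-layer K/V cache itself: two tensors (K and V), each of shape [B_eff, H, L_max, d_k], at 2 bytes per element in float16/bfloat16. A minimal back-of-the-envelope sketch in Python (the concrete numbers are illustrative, not part of the task):

```python
def kv_cache_bytes_per_layer(batch, heads, max_len, head_dim, bytes_per_elem=2):
    """Per-layer K/V footprint: 2 tensors (K and V), each shaped
    [batch, heads, max_len, head_dim], stored in float16/bfloat16."""
    return 2 * batch * heads * max_len * head_dim * bytes_per_elem

# Illustrative only: B_eff=8, H=32, L_max=4096, d_k=128, 32 layers, fp16.
per_layer = kv_cache_bytes_per_layer(batch=8, heads=32, max_len=4096, head_dim=128)
print(f"{per_layer / 2**20:.0f} MiB per layer, {32 * per_layer / 2**30:.1f} GiB total")
```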
Requirements
Design and implement K/V caching that:
- Specifies clear tensor shapes for prefill (prompt) and per-step decode for both greedy and beam search.
- Defines an API to maintain per-layer caches across decoding steps and batched inputs.
- Handles cache growth and memory limits (preallocation and an option for paged/block allocation).
- Supports greedy search and beam search (expand/reorder caches per step, EOS handling).
- Includes complexity analysis and expected speedups versus recomputation.
- Includes tests for correctness, including edge cases (e.g., EOS in the middle of a batch, variable prompt lengths, beam reorder correctness).
Provide a teaching-oriented solution with step-by-step reasoning, formulas where helpful, and code-style pseudocode for the API and critical paths. Rough, non-authoritative sketches of several of these pieces follow as reference points.
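As a starting point for the API and preallocation requirements above, here is one possible per-layer cache skeleton in PyTorch-flavored Python. It is a hedged sketch, not a prescribed implementation: the class and method names (KVCache, prefill, append, reorder) and the [B_eff, H, L, d_k] layout are assumptions, and EOS handling and paged allocation are only noted in comments.

```python
import torch

class KVCache:
    """Preallocated K/V cache for one decoder layer.

    Shape convention (one common choice, not the only one):
      prefill: K, V written as [B_eff, H, T_prompt, d_k]
      decode:  each step appends a [B_eff, H, 1, d_k] slice
    For beam search, B_eff = B * num_beams and rows are reordered every step.
    """

    def __init__(self, batch, heads, max_len, head_dim,
                 dtype=torch.float16, device=None):
        shape = (batch, heads, max_len, head_dim)
        self.k = torch.empty(shape, dtype=dtype, device=device)
        self.v = torch.empty(shape, dtype=dtype, device=device)
        self.length = 0  # number of valid positions currently stored

    def prefill(self, k, v):
        """Write the prompt's keys/values; k, v: [B_eff, H, T_prompt, d_k].
        Variable prompt lengths can be handled by right-padding plus an
        attention mask, or by tracking per-sequence lengths."""
        t = k.shape[2]
        self.k[:, :, :t] = k
        self.v[:, :, :t] = v
        self.length = t

    def append(self, k_step, v_step):
        """Append one decode step (k_step, v_step: [B_eff, H, 1, d_k]) and
        return views over the valid prefix for the attention kernel.
        Sequences that already emitted EOS can keep writing padded steps
        as long as their logits are ignored downstream."""
        t = self.length
        self.k[:, :, t:t + 1] = k_step
        self.v[:, :, t:t + 1] = v_step
        self.length = t + 1
        return self.k[:, :, :self.length], self.v[:, :, :self.length]

    def reorder(self, beam_idx):
        """Beam search: permute cache rows so row i holds the history of the
        parent beam selected into slot i (beam_idx: LongTensor of shape [B_eff]).
        index_select allocates a copy; an allocation-free variant would permute
        in place or use a block/page table instead of dense tensors."""
        self.k = self.k.index_select(0, beam_idx)
        self.v = self.v.index_select(0, beam_idx)
```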
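The paged/block allocation option can be sketched as a shared pool of fixed-size blocks plus a per-sequence block table, so memory grows in block-sized increments instead of being preallocated to L_max for every sequence. Everything here (PagedKVCache, BLOCK_SIZE, the pool layout) is an illustrative assumption in the spirit of paged attention, not a required design:

```python
import torch

BLOCK_SIZE = 16  # tokens per block; illustrative value

class PagedKVCache:
    """Per-layer paged K/V cache: a shared pool of fixed-size blocks plus a
    per-sequence block table."""

    def __init__(self, heads, head_dim, num_blocks, dtype=torch.float16):
        # Pool layout: [num_blocks, 2 (K/V), heads, BLOCK_SIZE, head_dim]
        self.pool = torch.empty(num_blocks, 2, heads, BLOCK_SIZE, head_dim,
                                dtype=dtype)
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}  # seq_id -> list of block indices, in order
        self.lengths = {}       # seq_id -> number of tokens written

    def append(self, seq_id, k_step, v_step):
        """Write one new token for one sequence; k_step, v_step: [heads, head_dim].
        Runs out of memory (IndexError) when no free blocks remain; a real
        engine would evict or reject requests at that point."""
        n = self.lengths.get(seq_id, 0)
        if n % BLOCK_SIZE == 0:  # current block full (or none yet): grab a new one
            self.block_tables.setdefault(seq_id, []).append(self.free_blocks.pop())
        block = self.block_tables[seq_id][n // BLOCK_SIZE]
        self.pool[block, 0, :, n % BLOCK_SIZE] = k_step
        self.pool[block, 1, :, n % BLOCK_SIZE] = v_step
        self.lengths[seq_id] = n + 1

    def release(self, seq_id):
        """Return a finished sequence's blocks to the free list (e.g., on EOS)."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)
```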
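For the complexity analysis, the key comparison is per decode step at prefix length t: without a cache, the layer re-projects all t prefix tokens and recomputes full t x t attention, roughly O(t·d_model² + t²·d_model) FLOPs; with a cache, only the new token is projected and it attends over t cached keys, roughly O(d_model² + t·d_model). Summed over L steps this is O(L²·d_model² + L³·d_model) versus O(L·d_model² + L²·d_model). A back-of-the-envelope sketch (constants and the MLP block are ignored; the numbers are illustrative):

```python
def attn_flops_per_step(t, d_model, cached):
    """Rough attention-layer FLOPs at a decode step with prefix length t.
    cached=False: re-project all t tokens and recompute t x t attention.
    cached=True:  project only the new token and attend over t cached keys.
    Constant factors and the MLP block are ignored."""
    if cached:
        return d_model ** 2 + t * d_model
    return t * d_model ** 2 + t ** 2 * d_model

# Illustrative: d_model = 4096, prefix length t = 2048.
t, d = 2048, 4096
speedup = attn_flops_per_step(t, d, cached=False) / attn_flops_per_step(t, d, cached=True)
print(f"~{speedup:.0f}x fewer FLOPs at this step")  # grows roughly linearly in t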
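For the correctness tests, the core invariant is that incremental decoding against the cache reproduces full recomputation. The self-contained pytest-style sketch below checks that invariant for plain scaled dot-product attention; the edge cases named in the requirements (EOS mid-batch, variable prompt lengths, beam reorder) would be layered on top of the same comparison.

```python
import torch

def full_attention(q, k, v):
    """Reference: every query position attends over the full causal prefix."""
    T = q.shape[2]
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

def test_cached_decode_matches_recompute():
    torch.manual_seed(0)
    B, H, T, d_k = 2, 4, 10, 16
    q = torch.randn(B, H, T, d_k)
    k = torch.randn(B, H, T, d_k)
    v = torch.randn(B, H, T, d_k)
    ref = full_attention(q, k, v)
    # Incremental path: at step t, the single new query attends over the
    # "cached" keys/values k[:, :, :t+1], v[:, :, :t+1].
    for t in range(T):
        scores = q[:, :, t:t + 1] @ k[:, :, :t + 1].transpose(-2, -1) / d_k ** 0.5
        step = torch.softmax(scores, dim=-1) @ v[:, :, :t + 1]
        torch.testing.assert_close(step, ref[:, :, t:t + 1], rtol=1e-5, atol=1e-5)
```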