You are implementing autoregressive inference for a decoder-only Transformer.
- Explain what the KV cache is, what tensors are cached per layer, and how it changes computation during incremental decoding.
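As a concrete illustration of the point above, here is a minimal single-head sketch in NumPy (the shapes, weight names `Wq`/`Wk`/`Wv`, and cache variables are illustrative, not from any particular codebase): at each decoding step only the newest token's K/V are computed and appended, and attention runs over the full cache.

```python
import numpy as np

def attend(q, K, V):
    # Scaled dot-product attention for one query against all cached keys.
    scores = q @ K.T / np.sqrt(q.shape[-1])
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return w @ V

d = 4
rng = np.random.default_rng(0)
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))

K_cache = np.empty((0, d))  # cached keys,   shape (t, d)
V_cache = np.empty((0, d))  # cached values, shape (t, d)

for t in range(3):                       # incremental decoding steps
    x = rng.standard_normal(d)           # hidden state of the newest token
    k, v = x @ Wk, x @ Wv                # compute K/V for this token ONLY
    K_cache = np.vstack([K_cache, k])    # append to the cache
    V_cache = np.vstack([V_cache, v])
    out = attend(x @ Wq, K_cache, V_cache)  # attend over all t+1 positions
```

Without the cache, every step would recompute K/V for all previous tokens; with it, per-step attention cost drops from quadratic to linear in the sequence length (a real model keeps one K and one V cache per layer and per head).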
- Describe an implementation plan for KV caching that supports:
  - Variable sequence lengths in a batch
  - Beam search or speculative decoding (where sequences can branch)
  - Long contexts (e.g., 32k–128k tokens)
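One plan that covers all three requirements above is a paged cache: fixed-size blocks plus a per-sequence block table, in the spirit of PagedAttention. The toy class below is an illustrative sketch (class and method names are my own, and copy-on-write of a partially filled block on branch is deliberately omitted): variable lengths fall out because each sequence holds only the blocks it needs, and a beam/speculative fork copies only the small block table, not the K/V data.

```python
import numpy as np

BLOCK = 4   # tokens per block (illustrative)
D = 8       # head dim (illustrative)

class PagedKVCache:
    """Toy paged KV cache: fixed-size blocks + per-sequence block tables."""

    def __init__(self, num_blocks):
        self.k = np.zeros((num_blocks, BLOCK, D))
        self.v = np.zeros((num_blocks, BLOCK, D))
        self.free = list(range(num_blocks))
        self.tables = {}   # seq_id -> list of block ids
        self.lens = {}     # seq_id -> token count

    def append(self, seq, k, v):
        t = self.lens.get(seq, 0)
        if t % BLOCK == 0:   # current block full -> allocate a new one
            self.tables.setdefault(seq, []).append(self.free.pop())
        blk = self.tables[seq][t // BLOCK]
        self.k[blk, t % BLOCK] = k
        self.v[blk, t % BLOCK] = v
        self.lens[seq] = t + 1

    def fork(self, src, dst):
        # Branching is O(len / BLOCK): copy the table, share the K/V blocks.
        # (A real system would copy-on-write the last partial block.)
        self.tables[dst] = list(self.tables[src])
        self.lens[dst] = self.lens[src]

    def keys(self, seq):
        t = self.lens[seq]
        return self.k[self.tables[seq]].reshape(-1, D)[:t]
```

For long contexts (32k–128k tokens), the same indirection lets the allocator hand out blocks on demand instead of preallocating `max_len` per sequence, which is what makes large batches of long sequences fit in memory.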
- Discuss key performance considerations:
  - Memory layout and write/read patterns when appending new K/V
  - Avoiding reallocation and copies
  - Interaction with fused attention kernels (e.g., FlashAttention-style)
  - Precision choices (fp16/bf16/int8) for the cache
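The reallocation point above can be made concrete with a small NumPy sketch (sizes and function names are illustrative): growing the cache by concatenation copies the whole buffer every step (O(T²) traffic in total), while a preallocated buffer plus a length counter does one contiguous in-place row write per step. The same preallocated, contiguous layout is what fused FlashAttention-style kernels expect to read directly, and storing it in fp16/bf16 halves cache memory versus fp32 (int8 quarters it, at some accuracy cost).

```python
import numpy as np

MAX_T, D = 1024, 64

def naive():
    # Anti-pattern: reallocate + copy the whole cache on every append.
    K = np.empty((0, D), dtype=np.float16)
    for _ in range(MAX_T):
        K = np.concatenate([K, np.ones((1, D), np.float16)])
    return K

def prealloc():
    # Preallocate once; each step is one contiguous in-place row write.
    K = np.empty((MAX_T, D), dtype=np.float16)  # fp16: half of fp32 memory
    for t in range(MAX_T):
        K[t] = 1.0
    return K
```

Both produce the same cache contents; only the write pattern differs, and in a real serving stack that difference (plus keeping the layout the attention kernel wants, e.g. `[num_heads, max_len, head_dim]`) dominates decode throughput.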
- What are common bugs or correctness pitfalls when adding a KV cache (masking, position encodings/RoPE, shape mismatches, etc.)?
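One classic instance of the RoPE pitfall can be shown in a few lines (the `rope` helper below is a simplified illustrative implementation): once the cache is on and the model is fed one token at a time, the new token's query/key must be rotated with its absolute position `t`, not with position 0 as in a naive "sequence of length 1" call.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    # Simplified rotary embedding for one vector at absolute position `pos`.
    half = x.shape[-1] // 2
    freqs = base ** (-np.arange(half) / half)
    ang = pos * freqs
    x1, x2 = x[:half], x[half:]
    return np.concatenate([x1 * np.cos(ang) - x2 * np.sin(ang),
                           x1 * np.sin(ang) + x2 * np.cos(ang)])

q = np.ones(8)
wrong = rope(q, 0)  # bug: every decoded token treated as position 0
right = rope(q, 5)  # fix: token decoded at step 5 uses absolute position 5
```

The two results differ, so the bug silently degrades generations rather than crashing; the same "absolute vs. local index" mistake also shows up in causal-mask construction and in off-by-one shape mismatches between the cache length and the attention mask.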