You are given a small decoder-only transformer implementation for autoregressive language modeling.
Part 1: Debugging
The training code contains four bugs. Fix them so that training loss decreases and generated text becomes reasonable. The known bug categories are:
- The language-model loss uses the wrong label shift.
- Positional embeddings are initialized incorrectly.
- The attention mask is wrong.
- There is one additional typo-related bug somewhere in the code.
After your fixes, the model should train normally, the loss should go down, and sample outputs should look plausible.
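As a reference point for the first and third bug categories, here is a minimal sketch (not the handout's actual code) of a correct next-token label shift and a correct causal mask in PyTorch; the function names are illustrative:

```python
import torch
import torch.nn.functional as F

def lm_loss(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy with the correct label shift.

    logits: (B, T, vocab), tokens: (B, T). Position t predicts token t + 1;
    a classic bug is scoring logits[:, t] against tokens[:, t] instead.
    """
    pred = logits[:, :-1, :]   # predictions for positions 0 .. T-2
    target = tokens[:, 1:]     # the *next* token at each position
    return F.cross_entropy(pred.reshape(-1, pred.size(-1)), target.reshape(-1))

def causal_mask(T: int) -> torch.Tensor:
    # True where attention is allowed: query i may attend to keys j <= i.
    # A common bug is using torch.triu here, which masks the past instead
    # of the future.
    return torch.tril(torch.ones(T, T, dtype=torch.bool))
```

If your fixed loss matches this shift and your mask is lower-triangular, the loss should decrease steadily on any reasonable dataset.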
Part 2: KV cache
A cache class for attention keys and values is provided. Extend the model to support cached autoregressive decoding.
Update the implementation so that:
- The attention module reads from and appends to the cache correctly.
- Causal masking is only applied when there is no cache or the cache is empty.
- Positional encodings use the correct offset based on the current cache length.
- The `generate` function performs incremental decoding with the cache.
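The requirements above can be sketched as follows. This is an illustrative stand-in, not the provided cache class: the `KVCache` name, its methods, and `cached_attention` are assumptions, but the offset-based mask shows why causal masking only matters when the cache is empty:

```python
import torch

class KVCache:
    """Minimal per-layer cache sketch (illustrative, not the handout's class)."""
    def __init__(self):
        self.k = None  # (B, H, T_cached, head_dim)
        self.v = None

    def __len__(self):
        return 0 if self.k is None else self.k.size(2)

    def append(self, k_new, v_new):
        # Append new keys/values along the time dimension.
        self.k = k_new if self.k is None else torch.cat([self.k, k_new], dim=2)
        self.v = v_new if self.v is None else torch.cat([self.v, v_new], dim=2)
        return self.k, self.v

def cached_attention(q, k_new, v_new, cache=None):
    # `offset` is also the value the positional encoding should use for the
    # new tokens: their absolute positions start at len(cache).
    offset = len(cache) if cache is not None else 0
    if cache is not None:
        k, v = cache.append(k_new, v_new)
    else:
        k, v = k_new, v_new
    att = (q @ k.transpose(-2, -1)) / q.size(-1) ** 0.5
    # Query i sits at absolute position offset + i and may attend to keys at
    # positions j <= offset + i. When decoding a single new token against a
    # non-empty cache this mask allows everything, which is why explicit
    # causal masking only matters in the no-cache / empty-cache case.
    i = torch.arange(q.size(2)).unsqueeze(1) + offset
    j = torch.arange(k.size(2)).unsqueeze(0)
    att = att.masked_fill(j > i, float("-inf"))
    return torch.softmax(att, dim=-1) @ v
```

The offset-based mask unifies prefill (many queries, empty cache) and incremental decoding (one query, populated cache) in a single code path.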
Finally, verify that deterministic generation with and without KV cache produces the same token sequence.
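One way to structure that verification is shown below on a self-contained toy: a single-head attention map with fixed random weights (no positional encoding), decoded greedily with and without a cache. The toy model is purely illustrative; in your solution the same check applies to the handout model's `generate`:

```python
import torch

torch.manual_seed(0)

# Toy single-head attention "model" with fixed random weights.
D, V = 8, 11
emb = torch.randn(V, D)
w_qkv = torch.randn(D, 3 * D)
w_out = torch.randn(D, V)

def logits_full(ids):
    # Full (uncached) forward pass over the whole sequence.
    x = emb[ids]                                   # (T, D)
    q, k, v = (x @ w_qkv).split(D, dim=-1)
    att = (q @ k.T) / D ** 0.5
    T = ids.numel()
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
    att = att.masked_fill(~mask, float("-inf"))
    return (torch.softmax(att, -1) @ v) @ w_out    # (T, V)

def generate(ids, n_new, use_cache):
    ids = ids.clone()
    if not use_cache:
        # Recompute the whole sequence for every new token.
        for _ in range(n_new):
            nxt = logits_full(ids)[-1].argmax().view(1)
            ids = torch.cat([ids, nxt])
        return ids
    # Cached path: prefill on the prompt, then feed one token per step.
    k_cache, v_cache = torch.empty(0, D), torch.empty(0, D)
    pending = ids
    for _ in range(n_new):
        x = emb[pending]
        q, k, v = (x @ w_qkv).split(D, dim=-1)
        offset = k_cache.size(0)
        k_cache = torch.cat([k_cache, k])
        v_cache = torch.cat([v_cache, v])
        att = (q @ k_cache.T) / D ** 0.5
        i = torch.arange(pending.numel()).unsqueeze(1) + offset
        j = torch.arange(k_cache.size(0)).unsqueeze(0)
        att = att.masked_fill(j > i, float("-inf"))
        out = (torch.softmax(att, -1) @ v_cache) @ w_out
        nxt = out[-1].argmax().view(1)
        ids = torch.cat([ids, nxt])
        pending = nxt                              # only the new token next step
    return ids

prompt = torch.tensor([1, 2, 3])
a = generate(prompt, 5, use_cache=False)
b = generate(prompt, 5, use_cache=True)
assert torch.equal(a, b)   # greedy decoding is deterministic, so the paths must match
```

Greedy (argmax) decoding makes the comparison exact; with sampling you would instead fix the random seed before each run.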