Autoregressive Transformer: Correct Attention Masking with Padding
Context: You are implementing decoder self-attention for an autoregressive Transformer where input sequences in a batch are right-padded to a common length. You must prevent attention to future tokens and exclude padded positions.
Task:
- Implement a causal mask that prevents attending to future tokens.
- Implement a padding mask that prevents attending to padded positions.
- Show how to combine and apply these masks in multi-head attention (a combined sketch follows this list).
- Explain common bugs, how they appear in loss/accuracy, and how to test for them (a test sketch also follows below).
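A minimal sketch of the first three items, assuming PyTorch 2.x, the shapes stated in the assumptions below, and the boolean convention "True = may attend". The helper names (`build_causal_mask`, `build_padding_mask`, `combine_masks`, `masked_attention`) are illustrative, not part of any required API.

```python
import torch
import torch.nn.functional as F


def build_causal_mask(T: int, device=None) -> torch.Tensor:
    # (T, T) boolean mask, True where attention is allowed:
    # query position i may attend only to key positions j <= i.
    return torch.tril(torch.ones(T, T, dtype=torch.bool, device=device))


def build_padding_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: (B, T) with 1 for real tokens, 0 for padding.
    # Returns (B, 1, 1, T): padded key positions are disallowed for every
    # query and every head (broadcasts over heads and query length).
    return attention_mask.bool()[:, None, None, :]


def combine_masks(causal: torch.Tensor, padding: torch.Tensor) -> torch.Tensor:
    # causal: (T, T), padding: (B, 1, 1, T) -> combined (B, 1, T, T).
    # A key is attendable only if BOTH masks allow it. With right padding,
    # every query row keeps at least one allowed key (the real prefix),
    # so no row is fully masked and softmax cannot produce NaNs.
    return causal[None, None, :, :] & padding


def masked_attention(q, k, v, combined_mask):
    # q, k, v: (B, num_heads, T, head_dim); combined_mask broadcasts to
    # (B, num_heads, T, T). F.scaled_dot_product_attention interprets a
    # boolean attn_mask as "True = keep, False = mask out", matching the
    # convention used above.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=combined_mask)
```

If `nn.MultiheadAttention` is used instead of `F.scaled_dot_product_attention`, note that its boolean `attn_mask` and `key_padding_mask` use the opposite convention (True marks positions that are masked out); mixing up these conventions is itself one of the common bugs the last task item asks about.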
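For the testing item, one useful property is that with correct masking, the output at position t is invariant to changes at future or padded positions, so a perturbation test catches both causal and padding leaks (leaked future tokens typically show up as an implausibly low training loss and high teacher-forced accuracy that collapse at generation time). A minimal sketch, reusing the illustrative helpers from the previous block:

```python
import torch


def test_no_future_or_padding_leakage():
    torch.manual_seed(0)
    B, H, T, Dh = 2, 4, 8, 16
    real_len = 5  # first 5 positions are real tokens, the rest is right padding

    attention_mask = torch.zeros(B, T)
    attention_mask[:, :real_len] = 1

    q = torch.randn(B, H, T, Dh)
    k = torch.randn(B, H, T, Dh)
    v = torch.randn(B, H, T, Dh)

    combined = combine_masks(build_causal_mask(T),
                             build_padding_mask(attention_mask))
    out_ref = masked_attention(q, k, v, combined)

    # Perturb one future position (4) and one padded position (6).
    k2, v2 = k.clone(), v.clone()
    k2[:, :, 4] = torch.randn(B, H, Dh)  # future w.r.t. queries 0..3
    v2[:, :, 4] = torch.randn(B, H, Dh)
    k2[:, :, 6] = torch.randn(B, H, Dh)  # padded position
    v2[:, :, 6] = torch.randn(B, H, Dh)
    out_pert = masked_attention(q, k2, v2, combined)

    # Queries 0..3 may never attend to position 4 (causal) or 6 (padding),
    # so their outputs must be unchanged by the perturbation.
    assert torch.allclose(out_ref[:, :, :4], out_pert[:, :, :4], atol=1e-6)


test_no_future_or_padding_leakage()
print("no future/padding leakage detected")
```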
Assumptions:
- Batch-first tensors: shape (B, T, D).
- Right padding. An attention_mask of shape (B, T) uses 1 for real tokens and 0 for padding (or equivalently, you have pad_token_id and input_ids; a derivation sketch follows this list).
- PyTorch is available.
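If only input_ids and pad_token_id are given, the (B, T) attention_mask above can be derived with a single comparison; a minimal sketch (pad_token_id is whatever id the tokenizer reserves for padding):

```python
import torch


def attention_mask_from_input_ids(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # input_ids: (B, T) integer tensor.
    # Returns (B, T) with 1 for real tokens and 0 for padding.
    return (input_ids != pad_token_id).long()


# Example with pad_token_id = 0:
ids = torch.tensor([[5, 7, 9, 0, 0],
                    [3, 4, 0, 0, 0]])
print(attention_mask_from_input_ids(ids, pad_token_id=0))
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 0, 0, 0]])
```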