Diagnose a Diverging PyTorch Transformer Training Run
You are given a PyTorch Transformer training pipeline whose loss diverges and validation accuracy remains near random. The repository already includes:
- Tokenization, padding, and masking
- An encoder–decoder Transformer with multi-head attention
- AdamW with weight decay and gradient clipping
- A cosine LR scheduler with warmup
- Mixed precision (torch.cuda.amp.GradScaler)
- Teacher forcing in the decoder
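For concreteness, the sketch below shows how these pieces typically fit together in one training step (AMP with GradScaler, gradient clipping, a per-step cosine warmup schedule). It is a minimal illustration under common conventions; the model, batch layout, pad id, and hyperparameter names are assumptions, not the repository's actual code.

```python
import torch
from torch.cuda.amp import GradScaler, autocast

def train_step(model, batch, optimizer, scheduler, scaler: GradScaler, max_norm: float = 1.0):
    """One AMP training step: forward, scaled backward, clip, step, schedule.
    All names and the batch layout are illustrative, not the repository's code."""
    model.train()
    optimizer.zero_grad(set_to_none=True)          # gradients must be cleared every step

    src, tgt_in, tgt_out, src_mask, tgt_mask = batch
    with autocast():                               # mixed-precision forward + loss
        logits = model(src, tgt_in, src_mask=src_mask, tgt_mask=tgt_mask)
        loss = torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            tgt_out.reshape(-1),
            ignore_index=0,                        # assumes pad id 0; adjust to the tokenizer
        )

    scaler.scale(loss).backward()                  # backward on the *scaled* loss
    scaler.unscale_(optimizer)                     # unscale before clipping, not after
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)                         # skips the step on inf/NaN gradients
    scaler.update()
    scheduler.step()                               # per-step cosine schedule with warmup
    return loss.detach()
```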
Task
- Identify four distinct bugs that could plausibly cause divergence or random validation accuracy (e.g., incorrect attention mask construction or dtype/device, an off-by-one label shift, missing gradient zeroing, misplaced gradient scaling, a wrong positional-encoding shape/broadcast, an incorrect LayerNorm/bias exclusion from weight decay, or data leakage in batching); illustrative sketches of the mask-construction and label-shift cases appear after this list.
- For each bug:
  - Explain the failure mode (why it breaks learning).
  - Show a minimal code fix.
  - Provide a unit test or runtime assertion to catch it in the future.
- Propose a debugging plan including sanity checks (e.g., a copy task), gradient/activation statistics, and NaN detection; define which metrics to monitor; and outline a small experiment to verify each fix. A sketch of such runtime instrumentation also appears after this list.
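As an illustration of the kind of fix plus check the task asks for, the sketch below builds a causal decoder mask and the shifted teacher-forcing labels, with unit-test-style assertions that would catch the mask-dtype and off-by-one bugs named above. It assumes the torch.nn.Transformer convention that a boolean mask marks disallowed positions with True, and that labels are the decoder inputs shifted left by one; the repository may use a different convention.

```python
import torch

def causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Boolean [seq_len, seq_len] mask; True marks positions the decoder may NOT attend to
    (the convention torch.nn.Transformer expects for boolean attention masks)."""
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)

def shift_for_teacher_forcing(tgt: torch.Tensor):
    """Decoder input is tgt[:, :-1]; labels are tgt[:, 1:] (shifted left by one)."""
    return tgt[:, :-1], tgt[:, 1:].contiguous()

# Minimal unit-test-style checks (hypothetical shapes and token ids, not the repository's data).
if __name__ == "__main__":
    m = causal_mask(4, torch.device("cpu"))
    assert m.dtype == torch.bool and m.shape == (4, 4)
    assert not m[2, 1] and m[1, 2]          # past is visible, future is masked
    assert not m.diagonal().any()           # a token may attend to itself

    tgt = torch.tensor([[1, 5, 6, 7, 2]])   # e.g. <bos> a b c <eos>
    dec_in, labels = shift_for_teacher_forcing(tgt)
    assert dec_in.shape == labels.shape
    assert torch.equal(dec_in[:, 1:], labels[:, :-1])  # off-by-one alignment check
```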
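Similarly, the debugging plan's NaN detection and gradient statistics can be wired in with a few lines of instrumentation. The sketch below is one possible approach under standard PyTorch hooks; function and metric names are illustrative rather than prescribed.

```python
import math
import torch

def register_nan_hooks(model: torch.nn.Module):
    """Forward hooks that flag the first module whose output contains NaN/Inf."""
    def _check(name):
        def hook(module, inputs, output):
            out = output[0] if isinstance(output, tuple) else output
            if torch.is_tensor(out) and not torch.isfinite(out).all():
                raise RuntimeError(f"Non-finite activation in module: {name}")
        return hook
    for name, module in model.named_modules():
        module.register_forward_hook(_check(name))

def grad_stats(model: torch.nn.Module) -> dict:
    """Global gradient norm plus the largest per-parameter gradient norm, for per-step logging."""
    total, worst_name, worst = 0.0, None, 0.0
    for name, p in model.named_parameters():
        if p.grad is not None:
            n = p.grad.detach().norm().item()
            total += n * n
            if n > worst:
                worst_name, worst = name, n
    return {"grad_norm": math.sqrt(total), "max_param_grad": worst, "max_param": worst_name}
```

Metrics worth logging each step would then include training loss, grad_norm, the learning rate, and the GradScaler's current scale; on a copy-task sanity check, loss should drop toward zero within a few hundred steps if masking and label shifting are correct.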