Debugging a Transformer That Intermittently Throws Shape/Type Errors and Fails to Converge
You are given a Transformer-based sequence model that:
- Intermittently raises shape or dtype mismatch errors during training.
- Fails to converge after several thousand steps (loss stalls or diverges).
Describe your end-to-end debugging approach. Include concrete checks, metrics, and code-level instrumentation you would add.
Cover These Areas
- Reproducibility and failure capture
  - Seed control, logging, anomaly detection, and making the error reproducible.
- Tokenization, padding, and special tokens
  - Validate tokenizer round-trips, PAD/BOS/EOS/UNK handling, truncation rules, and padding policy.
- Attention masks
  - Verify causal vs. bidirectional masks, mask shapes/dtypes/values, and mask application order.
- Positional embeddings
  - Check absolute vs. rotary/sinusoidal positions, max context length, and KV-cache position handling.
- Assertions and unit tests
  - Add assertions for tensor shapes/dtypes/sequence lengths; create unit tests for collate functions and model forward.
- Language modeling targets
  - Verify label shifting for decoder-only LM, ignore_index for PAD, and BOS/EOS conventions.
- Gradients and optimization
  - Detect exploding/vanishing gradients, gradient norms, clipping, optimizer hyperparameters, LR schedule, and mixed precision.
- Numerical stability and NaNs
  - Diagnose NaN sources, softmax masking in FP16, loss scaling, epsilon settings, and safe reductions.
- Data bugs
  - Spot dataset corruption, inconsistent tokenizers, unexpected truncation, and distribution shifts.
- Checkpoint/resume and randomness
  - Validate checkpoint integrity, optimizer/scheduler/scaler state, sampler state, and RNG across workers/ranks.
- Multi-GPU/DP/ZeRO/AMP issues
  - Synchronization bugs, gradient scaling, unused parameters, grad clipping across shards, and per-rank logging.
- Throughput and memory profiling
  - Tokens/sec, dataloader stalls, operator-level hotspots, and OOM sources.
- Minimal reproducible example (MRE)
  - Build a tiny, deterministic setup with targeted logging to localize the fault.
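As a starting point for the collate-function, padding, and label-shifting checks listed above, here is a minimal framework-free sketch. It assumes a decoder-only LM with right padding. The names `collate_causal_lm` and `check_batch`, the pad id `0`, and the ignore value `-100` are illustrative assumptions and are not taken from any particular codebase.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # assumption: must match the ignore_index passed to the loss
PAD_ID = 0           # assumption: hypothetical pad token id for this sketch


def collate_causal_lm(batch: List[List[int]], max_len: int) -> Dict[str, List[List[int]]]:
    """Truncate, right-pad, and build shifted labels for a decoder-only LM."""
    assert batch, "empty batch"
    assert max_len >= 2, "need at least two tokens to form an input/target pair"
    seqs = [seq[:max_len] for seq in batch]
    assert all(len(s) >= 2 for s in seqs), "sequence too short to shift"
    width = max(len(s) for s in seqs) - 1  # inputs drop the final token

    input_ids, labels, attention_mask = [], [], []
    for s in seqs:
        inp, tgt = s[:-1], s[1:]  # position t is trained to predict token t + 1
        pad = width - len(inp)
        input_ids.append(inp + [PAD_ID] * pad)
        labels.append(tgt + [IGNORE_INDEX] * pad)  # padded positions carry no loss
        attention_mask.append([1] * len(inp) + [0] * pad)
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}


def check_batch(out: Dict[str, List[List[int]]], vocab_size: int) -> None:
    """Raise AssertionError on ragged rows, out-of-range ids, or unmasked pad labels."""
    rows = len(out["input_ids"])
    width = len(out["input_ids"][0])
    for key in ("input_ids", "labels", "attention_mask"):
        assert len(out[key]) == rows, f"{key}: wrong batch size"
        assert all(len(r) == width for r in out[key]), f"{key}: ragged rows"
    for ids, labs, mask in zip(out["input_ids"], out["labels"], out["attention_mask"]):
        for tok, lab, m in zip(ids, labs, mask):
            assert m in (0, 1), "mask must be binary"
            assert 0 <= tok < vocab_size, f"token id {tok} outside vocabulary"
            if m == 0:
                assert lab == IGNORE_INDEX, "padded position would contribute to the loss"
            else:
                assert 0 <= lab < vocab_size, f"label {lab} outside vocabulary"


batch = collate_causal_lm([[5, 6, 7, 8], [9, 10]], max_len=8)
check_batch(batch, vocab_size=32)
```

Running `check_batch` on every batch for the first few hundred steps, and then only on a sample, is one way to catch intermittent shape errors at the data boundary before they surface inside the model.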
Provide concrete code snippets, assertions, metrics, and logs you would implement.
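As one illustration of the metrics and logs requested here, the sketch below tracks per-step loss and global gradient norm and flags non-finite values, zero gradients, and sudden spikes. `StepMonitor` and its thresholds are hypothetical. In a PyTorch training loop the gradient norm could be taken from the value returned by `torch.nn.utils.clip_grad_norm_`, but the class only needs plain floats.

```python
import math
from collections import deque
from typing import List


class StepMonitor:
    """Flag non-finite values, zero gradients, and spikes in per-step scalars."""

    def __init__(self, window: int = 50, spike_factor: float = 5.0):
        self.spike_factor = spike_factor
        self.history = deque(maxlen=window)  # recent finite gradient norms

    def update(self, step: int, loss: float, grad_norm: float) -> List[str]:
        issues = []
        if not math.isfinite(loss):
            issues.append(f"step {step}: non-finite loss {loss}")
        if not math.isfinite(grad_norm):
            issues.append(f"step {step}: non-finite grad norm {grad_norm}")
            return issues  # nothing further can be compared against history
        if grad_norm == 0.0:
            issues.append(f"step {step}: zero grad norm (detached graph or all labels ignored?)")
        if len(self.history) >= 5:
            # The median is used as a baseline so that one earlier spike does not hide the next.
            baseline = sorted(self.history)[len(self.history) // 2]
            if grad_norm > self.spike_factor * baseline:
                issues.append(
                    f"step {step}: grad norm spike {grad_norm:.3g} vs median {baseline:.3g}"
                )
        self.history.append(grad_norm)
        return issues


monitor = StepMonitor(window=10, spike_factor=5.0)
for step, (loss, gnorm) in enumerate([(2.0, 1.0)] * 6 + [(2.1, 50.0)]):
    for message in monitor.update(step, loss, gnorm):
        print(message)
```

When a message fires, saving the offending batch and the RNG state alongside it gives a concrete input for replaying the failure in a small deterministic setup.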