Debugging Plan: PyTorch Transformer Text Model with Mask Errors, Metric Plateau, AMP Crashes, and Nondeterminism
Context
You are training a Transformer-based text model in PyTorch for a sequence task (e.g., causal language modeling, sequence classification, or token classification). The model shows four symptoms:
- Occasional CUDA shape/index errors around attention masks.
- Validation metrics plateau near chance while training loss decreases.
- Intermittent crashes when using mixed precision (AMP) and gradient accumulation.
- Nondeterministic results across runs.
Assume a standard training stack: PyTorch, Hugging Face–style tokenization, DataLoader(s), CrossEntropy loss variants (ignore_index, label smoothing, class weights), AdamW + scheduler, optional DDP, AMP, and gradient clipping.
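For concreteness, a minimal sketch of such a stack; every dimension, hyperparameter, and the toy model below is an illustrative placeholder, not a prescription:

```python
import torch
from torch import nn

PAD_ID = 0  # assumed pad token id

# Toy stand-in for the real Transformer; batch_first matches HF-style tensors.
model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True),
    num_layers=2,
)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID, label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.1, total_iters=1000
)
```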
Task
Propose a systematic, end-to-end debugging plan to localize and resolve all four issues. For each area below, specify concrete checks/experiments, describe the failure signal(s), outline a minimal reproducible example or unit test, and state how you would implement and verify the fix:
- Data preprocessing: padding, truncation, label alignment
- Tokenization and attention/causal masks
- Positional encodings
- Loss computation: ignore_index, label smoothing, class weights
- Optimizer/scheduler/zero_grad/gradient clipping
- AMP/GradScaler settings and gradient accumulation
- Seed control and deterministic kernels
- DDP and Sampler configuration
Explain how each suspected bug would manifest, how you’d isolate it, and how to confirm the fix. Hedged starter sketches for each of the eight areas follow; all identifiers in them are illustrative, not prescribed.
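Data preprocessing: a minimal unit test for padding/label alignment. `collate_batch`, `PAD_ID`, and `IGNORE_IDX` are hypothetical stand-ins for the real pipeline; the invariant under test is that padded input positions and ignored label positions coincide exactly.

```python
import torch

PAD_ID = 0         # assumed pad token id
IGNORE_IDX = -100  # assumed ignore_index used by the loss

def collate_batch(sequences, labels):
    """Hypothetical collate: right-pad inputs with PAD_ID, labels with IGNORE_IDX."""
    max_len = max(len(s) for s in sequences)
    input_ids = torch.full((len(sequences), max_len), PAD_ID, dtype=torch.long)
    label_ids = torch.full((len(sequences), max_len), IGNORE_IDX, dtype=torch.long)
    for i, (s, l) in enumerate(zip(sequences, labels)):
        input_ids[i, : len(s)] = torch.tensor(s)
        label_ids[i, : len(l)] = torch.tensor(l)
    return input_ids, label_ids

def test_padding_label_alignment():
    input_ids, label_ids = collate_batch([[5, 6, 7], [8, 9]], [[1, 2, 3], [4, 5]])
    # Every padded input position must carry the ignore label, and vice versa.
    assert torch.equal(input_ids.eq(PAD_ID), label_ids.eq(IGNORE_IDX))

test_padding_label_alignment()
```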
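Tokenization and masks: a sketch of a shape/semantics test, assuming the Hugging Face convention (`attention_mask`: 1 = real token) and the `torch.nn` boolean convention (True = do not attend). Mixing the two is a frequent source of both the CUDA shape errors and chance-level validation metrics.

```python
import torch

def build_causal_mask(seq_len: int) -> torch.Tensor:
    # True above the diagonal = "do not attend", the boolean convention
    # used by torch.nn.Transformer's attn_mask.
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

def test_mask_shapes_and_semantics():
    batch, seq_len = 2, 5
    attention_mask = torch.tensor([[1, 1, 1, 0, 0],    # HF-style: 1 = real token
                                   [1, 1, 1, 1, 1]])
    assert attention_mask.shape == (batch, seq_len)

    causal = build_causal_mask(seq_len)
    q, k = torch.meshgrid(torch.arange(seq_len), torch.arange(seq_len), indexing="ij")
    assert torch.equal(causal, k > q)  # masked exactly where the key is in the future

    # nn.MultiheadAttention's key_padding_mask is True at PAD positions,
    # i.e. the INVERSE of the HF attention_mask -- a classic sign-flip bug.
    key_padding_mask = attention_mask.eq(0)
    assert key_padding_mask[0, 3:].all() and not key_padding_mask[1].any()

test_mask_shapes_and_semantics()
```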
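Positional encodings: a fail-fast check for out-of-range positions. `MAX_POSITIONS` stands in for whatever limit the model actually has (e.g., `config.max_position_embeddings` in HF models).

```python
import torch
from torch import nn

MAX_POSITIONS = 512  # assumed model limit
pos_embedding = nn.Embedding(MAX_POSITIONS, 64)

def embed_positions(input_ids: torch.Tensor) -> torch.Tensor:
    seq_len = input_ids.size(1)
    # Without this assert, an over-long batch raises an opaque CUDA index error
    # deep inside the embedding lookup instead of at the true cause.
    assert seq_len <= MAX_POSITIONS, f"seq_len {seq_len} exceeds table size {MAX_POSITIONS}"
    positions = torch.arange(seq_len, device=input_ids.device)
    return pos_embedding(positions)

embed_positions(torch.zeros(2, 512, dtype=torch.long))   # fine
# embed_positions(torch.zeros(2, 513, dtype=torch.long)) # fails fast with a clear message
```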
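Loss computation: a test that `ignore_index` really excludes padding, comparing `CrossEntropyLoss` against a manually masked mean NLL. The same harness extends to label smoothing and class weights by swapping the reference computation.

```python
import torch
from torch import nn

IGNORE_IDX, vocab = -100, 10
logits = torch.randn(4, 7, vocab)
labels = torch.randint(0, vocab, (4, 7))
labels[:, 5:] = IGNORE_IDX  # pretend the tail positions are padding

loss = nn.CrossEntropyLoss(ignore_index=IGNORE_IDX)(
    logits.view(-1, vocab), labels.view(-1)
)

# Reference: mean NLL over non-ignored positions only. If this disagrees with
# the loss above, padding is leaking into the objective.
log_probs = logits.log_softmax(dim=-1)
keep = labels.ne(IGNORE_IDX)
nll = -log_probs[keep].gather(1, labels[keep].unsqueeze(1)).squeeze(1)
assert torch.allclose(loss, nll.mean(), atol=1e-6)
```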
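Optimizer/scheduler ordering: a sketch of the canonical per-step order (backward → clip → optimizer step → scheduler step → zero_grad). The model and hyperparameters are toy placeholders; the ordering is the point.

```python
import torch
from torch import nn

model = nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=10)
criterion = nn.CrossEntropyLoss()

for step in range(10):
    x = torch.randn(4, 8)
    y = torch.randint(0, 2, (4,))
    loss = criterion(model(x), y)
    loss.backward()
    # Clip AFTER backward (grads exist) and BEFORE step (step consumes grads).
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()                       # per-step scheduler: step every batch
    optimizer.zero_grad(set_to_none=True)  # clear grads so batches don't leak
```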
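AMP with gradient accumulation: a sketch of the loop shape that avoids the usual crashes: divide the loss by the accumulation count, call `unscale_` exactly once per optimizer step (before clipping), and step/update the scaler only at accumulation boundaries. The `torch.amp.GradScaler` spelling assumes PyTorch ≥ 2.3; older releases use `torch.cuda.amp.GradScaler`.

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"
model = nn.Linear(8, 2).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler(device, enabled=use_amp)
criterion = nn.CrossEntropyLoss()
accum_steps = 4

optimizer.zero_grad(set_to_none=True)
for step in range(16):
    x = torch.randn(4, 8, device=device)
    y = torch.randint(0, 2, (4,), device=device)
    with torch.autocast(device_type=device, enabled=use_amp):
        loss = criterion(model(x), y) / accum_steps  # scale so grads match a large batch
    scaler.scale(loss).backward()                    # grads accumulate across micro-batches
    if (step + 1) % accum_steps == 0:
        scaler.unscale_(optimizer)                   # unscale ONCE, before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)                       # skips the step on inf/nan grads
        scaler.update()
        optimizer.zero_grad(set_to_none=True)        # clear only at the boundary
```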
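Seed control and determinism: a sketch of a seeding helper. `warn_only=True` logs nondeterministic ops instead of raising, which is convenient while bisecting; drop it to enforce. `CUBLAS_WORKSPACE_CONFIG` must be set before the first CUDA call.

```python
import os
import random

import numpy as np
import torch

def seed_everything(seed: int = 1234) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds CPU and all CUDA devices
    # Required by some deterministic cuBLAS kernels; must precede CUDA init.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True, warn_only=True)

def seed_worker(worker_id: int) -> None:
    # Give DataLoader workers distinct derived seeds; pass via worker_init_fn=seed_worker.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
```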
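DDP and samplers: a sketch of `DistributedSampler` usage. `num_replicas`/`rank` are passed explicitly only so the snippet runs standalone; under real DDP they come from the initialized process group.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)  # never shuffle= with a sampler

for epoch in range(3):
    # Without set_epoch, every epoch replays epoch-0's shuffle order on every
    # rank -- a quiet contributor to both metric plateaus and nondeterminism reports.
    sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass
```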