Enhance MPS LoRA training by enabling gradient checkpointing in DiT training path #401
base: main
Conversation
…2 trainable params, and non-finite grad guards
📝 Walkthrough
The changes enhance LoRA training robustness by unwrapping DiT decoders from PEFT wrapper layers and adding FP32 stability mechanisms to the training pipeline. New utilities enforce gradient handling, detect non-finite gradients, and configure memory optimization features with safe error handling across wrapped modules.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 3 passed
⚠️ Warning: Review ran into problems. Git: Failed to clone repository.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
acestep/training/trainer.py (1)
Lines 579-640: ⚠️ Potential issue | 🟠 Major: Re-apply FP32 optimizer fix after checkpoint resume.
`_ensure_optimizer_params_fp32` converts parameter tensors to fp32 but not optimizer state tensors. When `load_training_checkpoint` restores optimizer state saved during mixed-precision training, the state tensors (momentum, second moments) remain in fp16, undoing the initial dtype fix and risking training instability. After a successful optimizer load from checkpoint, re-run the dtype fix and explicitly cast optimizer state tensors to fp32:
🔧 Suggested fix (post-resume fp32 enforcement)
```diff
 checkpoint_info = load_training_checkpoint(
     resume_from,
     optimizer=optimizer,
     scheduler=scheduler,
     device=self.module.device,
 )
+if checkpoint_info.get("loaded_optimizer"):
+    casted_opt_params, total_opt_params = _ensure_optimizer_params_fp32(optimizer)
+    logger.info(
+        f"Optimizer param dtype fixup (post-resume): casted {casted_opt_params}/{total_opt_params} to fp32"
+    )
+    for state in optimizer.state.values():
+        for key, value in state.items():
+            if torch.is_tensor(value) and value.is_floating_point() and value.dtype != torch.float32:
+                with torch.no_grad():
+                    state[key] = value.float()
```

This should get peer review before merge because it touches core training control flow and wrapper interactions.
It appears to run correctly on my M4 Max (64 GB) and the MPS improvements look promising, but the scope of the change makes me a little cautious. The reduced memory pressure, higher GPU utilization, and throughput gains are great, provided nothing else is broken.
Primary change:
Add explicit grad-carrying fallback input for checkpointed forward (xt.requires_grad_(True)), since this DiT model does not implement the LM-style embedding hook path used by PEFT (get_input_embeddings).
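As a rough sketch (the helper name, extra-argument handling, and the non-reentrant checkpoint call are assumptions; only `xt` and the missing `get_input_embeddings` hook come from this PR), the fallback looks something like:

```python
import torch
from torch.utils.checkpoint import checkpoint


def checkpointed_decoder_forward(decoder_block, xt, *extra_inputs):
    # This DiT decoder exposes no get_input_embeddings(), so PEFT's usual
    # enable_input_require_grads() hook cannot make the checkpointed segment
    # see a grad-carrying input. Mark xt explicitly so autograd records the
    # segment and recomputes activations during backward.
    if torch.is_floating_point(xt) and not xt.requires_grad:
        xt = xt.requires_grad_(True)
    # The non-reentrant variant tolerates inputs that do not all require grad.
    return checkpoint(decoder_block, xt, *extra_inputs, use_reentrant=False)
```

The essential point is only that at least one floating-point input entering the checkpointed region requires grad; without that, the checkpointed segment produces no gradients for the LoRA adapters.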
Supporting changes required for stability:
PEFT/DiT compatibility guard around enable_input_require_grads (this guard and the non-finite gradient guard are sketched after this list).
Decoder unwrap before reinjecting LoRA after failed/partial wrapped states.
MPS AMP + dtype stabilization (ensure trainable/optimizer params remain fp32 where required).
Non-finite gradient guard to skip unstable steps instead of crashing.
Logging cleanup for expected DiT-specific checkpointing behavior.
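A minimal sketch of those two guards, assuming hypothetical helper names and skip/logging behavior rather than the exact code in trainer.py:

```python
import logging

import torch

logger = logging.getLogger(__name__)


def safe_enable_input_require_grads(model) -> None:
    """Call PEFT's hook only if the wrapped model actually supports it."""
    try:
        model.enable_input_require_grads()
    except (AttributeError, NotImplementedError):
        # DiT decoders have no get_input_embeddings(); fall back to the
        # explicit xt.requires_grad_(True) path instead of crashing.
        logger.info("enable_input_require_grads unsupported; using explicit grad-carrying input")


def gradients_are_finite(model: torch.nn.Module) -> bool:
    """Return False if any parameter holds a NaN/Inf gradient."""
    for param in model.parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            return False
    return True


# Hypothetical use inside the training step:
#   loss.backward()
#   if gradients_are_finite(model):
#       optimizer.step()
#   else:
#       logger.warning("Non-finite gradients detected; skipping optimizer step")
#   optimizer.zero_grad(set_to_none=True)
```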
Notes:
Intended behavior change is primarily for MPS training robustness; non-MPS paths should be functionally unchanged except for safer guards.