
Conversation

@riversedge (Contributor) commented Feb 9, 2026

This should get peer review before merge because it touches core training control flow and wrapper interactions.

It appears to run correctly on my M4 Max (64 GB) and the MPS improvements look promising, but the area of change makes me a little cautious. The reduced memory pressure, improved GPU utilization, and higher throughput are great, provided nothing else broke.

Primary change:
Add explicit grad-carrying fallback input for checkpointed forward (xt.requires_grad_(True)), since this DiT model does not implement the LM-style embedding hook path used by PEFT (get_input_embeddings).
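
For context, a minimal sketch of the idea, assuming the checkpointed forward receives the latent tensor `xt` directly; the function name and signature below are illustrative, not the exact code in this PR:

```python
import torch
from torch.utils.checkpoint import checkpoint


def checkpointed_decoder_forward(decoder, xt, timestep, cond, force_input_grads=True):
    """Run the DiT decoder under activation checkpointing.

    PEFT's usual trick (enable_input_require_grads via get_input_embeddings)
    does not apply here because the decoder consumes continuous latents rather
    than token embeddings. Making xt itself require grad gives the checkpointed
    region at least one grad-carrying input, so recomputation still produces
    gradients for the LoRA weights inside it.
    """
    if force_input_grads and torch.is_floating_point(xt) and not xt.requires_grad:
        # Hypothetical fallback mirroring the PR description; assumes xt is a
        # leaf tensor (e.g. taken straight from the preprocessed batch).
        xt = xt.requires_grad_(True)
    return checkpoint(decoder, xt, timestep, cond, use_reentrant=False)
```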

Supporting changes required for stability:
PEFT/DiT compatibility guard around enable_input_require_grads (see the sketch after this list).
Decoder unwrap before reinjecting LoRA after failed/partial wrapped states.
MPS AMP + dtype stabilization (ensure trainable/optimizer params remain fp32 where required).
Non-finite gradient guard to skip unstable steps instead of crashing.
Logging cleanup for expected DiT-specific checkpointing behavior.
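
A rough sketch of two of these guards, with hypothetical helper names (`safe_enable_input_require_grads`, `count_nonfinite_grads`) standing in for whatever the PR actually calls them:

```python
import logging

import torch

logger = logging.getLogger(__name__)


def safe_enable_input_require_grads(model) -> bool:
    """Guarded call: DiT models without get_input_embeddings can make PEFT's
    enable_input_require_grads path fail, so the failure is logged and
    tolerated instead of aborting training."""
    try:
        model.enable_input_require_grads()
        return True
    except (AttributeError, NotImplementedError) as exc:
        logger.info("enable_input_require_grads unavailable (%s); falling back to xt.requires_grad_", exc)
        return False


def count_nonfinite_grads(parameters) -> int:
    """Count parameters whose gradients contain NaN/Inf after backward()."""
    return sum(
        1
        for p in parameters
        if p.grad is not None and not torch.isfinite(p.grad).all()
    )


# Inside the training loop (illustrative): skip the step rather than crash.
#   if count_nonfinite_grads(model.parameters()) > 0:
#       optimizer.zero_grad(set_to_none=True)
#       continue  # drop this unstable step and move on
```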
Notes:

Intended behavior change is primarily for MPS training robustness; non-MPS paths should be functionally unchanged except for safer guards.

Summary by CodeRabbit

  • Bug Fixes

    • Improved handling of model wrappers and gradient operations during training.
    • Added safeguards against non-finite gradients.
  • New Features

    • Enhanced gradient checkpointing and memory optimization configuration.
    • Implemented FP32 precision enforcement for model parameters.
    • Added automatic detection and reporting of training memory features, including gradient checkpointing and cache usage.

…2 trainable params, and non-finite grad guards
@coderabbitai bot commented Feb 9, 2026

📝 Walkthrough


The changes enhance LoRA training robustness by unwrapping DiT decoders from PEFT wrapper layers and adding FP32 stability mechanisms to the training pipeline. New utilities enforce gradient handling, detect non-finite gradients, and configure memory optimization features with safe error handling across wrapped modules.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **LoRA Decoder Unwrapping**<br>`acestep/training/lora_utils.py` | Unwrap PEFT-wrapped DiT decoders by traversing `_forward_module` and `base_model` attributes; collapse nested structures to expose a clean `nn.Module` (see the traversal sketch after this table). Add guards around `enable_input_require_grads` and `is_gradient_checkpointing` with safe error handling and logging. |
| **Trainer FP32 & Gradient Utilities**<br>`acestep/training/trainer.py` (helper functions) | Introduce five new helper functions: cast trainable/optimizer params to FP32, count non-finite gradients, collect wrapper chains, and configure memory features (gradient checkpointing, cache disabling, input-grad capability detection). |
| **Training Flow & Memory Configuration**<br>`acestep/training/trainer.py` (core flow) | Extend initialization and training loop to configure memory features via the new utilities; enforce FP32 stability before/after Fabric setup; detect and skip optimizer steps on non-finite gradients; add status yields for gradient checkpointing, cache, and input-grad availability. Add `force_input_grads_for_checkpointing` attribute to `PreprocessedLoRAModule`. |
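
A hedged sketch of the unwrap traversal described in the first row above; the attribute names follow the summary (Fabric's `_forward_module`, PEFT's `base_model`), while the function name and depth limit are illustrative:

```python
import torch.nn as nn


def unwrap_dit_decoder(module: nn.Module, max_depth: int = 8) -> nn.Module:
    """Peel wrapper layers off a DiT decoder before re-injecting LoRA.

    Lightning Fabric exposes the wrapped module as `_forward_module`, and PEFT
    wrappers expose theirs as `base_model`. Collapsing these nested structures
    yields the plain nn.Module decoder even after a failed or partial wrap.
    """
    current = module
    for _ in range(max_depth):  # bounded loop avoids cycling on odd wrapper chains
        if hasattr(current, "_forward_module") and isinstance(current._forward_module, nn.Module):
            current = current._forward_module  # Fabric wrapper
        elif hasattr(current, "base_model") and isinstance(current.base_model, nn.Module):
            current = current.base_model       # PEFT wrapper
        else:
            break
    return current
```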

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Poem

🐰 Unwrapping PEFT's tangled layers,
We cast to FP32 with care,
Non-finite gradients beware!
Memory optimized, training's fair, 🌟
Robust decoders, stable prayers.

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped – CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The pull request title clearly and specifically summarizes the main change: enabling gradient checkpointing in the DiT training path for MPS LoRA training. It accurately reflects the primary objective and is neither vague nor misleading. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 92.31%, which is above the required threshold of 80.00%. |



Warning

Review ran into problems

🔥 Problems

Git: Failed to clone repository. Please run the @coderabbitai full review command to re-trigger a full review. If the issue persists, set path_filters to include or exclude specific files.



@coderabbitai bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
acestep/training/trainer.py (1)

579-640: ⚠️ Potential issue | 🟠 Major

Re-apply FP32 optimizer fix after checkpoint resume.

_ensure_optimizer_params_fp32 converts parameter tensors to fp32 but not optimizer state tensors. When load_training_checkpoint restores optimizer state saved during mixed-precision training, the state tensors (momentum, second moments) remain in fp16, undoing the initial dtype fix and risking training instability.

After a successful optimizer load from checkpoint, re-run the dtype fix and explicitly cast optimizer state tensors to fp32:

🔧 Suggested fix (post-resume fp32 enforcement)
                 checkpoint_info = load_training_checkpoint(
                     resume_from,
                     optimizer=optimizer,
                     scheduler=scheduler,
                     device=self.module.device,
                 )
+                if checkpoint_info.get("loaded_optimizer"):
+                    casted_opt_params, total_opt_params = _ensure_optimizer_params_fp32(optimizer)
+                    logger.info(
+                        f"Optimizer param dtype fixup (post-resume): casted {casted_opt_params}/{total_opt_params} to fp32"
+                    )
+                    for state in optimizer.state.values():
+                        for key, value in state.items():
+                            if torch.is_tensor(value) and value.is_floating_point() and value.dtype != torch.float32:
+                                with torch.no_grad():
+                                    state[key] = value.float()
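
The fix above re-uses `_ensure_optimizer_params_fp32`, which the walkthrough lists among the new trainer helpers but does not show. A plausible minimal version, offered here only as a sketch (hence the `_sketch` suffix) rather than the PR's actual implementation, would cast only the parameter tensors, which is why the diff above has to cast optimizer state separately:

```python
import torch


def _ensure_optimizer_params_fp32_sketch(optimizer):
    """Sketch: cast every floating-point parameter tracked by the optimizer to
    fp32 and report (casted, total) counts. The PR's real helper may differ."""
    casted, total = 0, 0
    for group in optimizer.param_groups:
        for param in group["params"]:
            total += 1
            if param.is_floating_point() and param.dtype != torch.float32:
                with torch.no_grad():
                    param.data = param.data.float()  # in-place dtype swap keeps the Parameter object
                casted += 1
    return casted, total
```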

@riversedge (Contributor, Author) commented:

FWIW, I successfully completed a 30-minute source-audio LoRA training run with this code and tested several song prompts. While parts of the output quality from that run were lower than desired, it was clear that training had worked and had picked up the source song's style. Each epoch took approximately 34 s on MPS (Mac M4 Max), the memory footprint was stable, and training converged nicely (see below). The total run took about 8 hours for 1000 epochs (vs. roughly 80 hours on the mainline branch).

[Screenshot, 2026-02-09: training run results referenced above]
