Add torch compile logic to LoRA trainer #422
base: main
Conversation
📝 Walkthrough
PreprocessedLoRAModule now force-clones parameter tensors before LoRA injection to avoid inference-locked states, and the DiT decoder is conditionally wrapped with torch.compile(mode="default") when torch.compile exists and the device is CUDA; previously a warning was logged instead.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
🚥 Pre-merge checks: ✅ 3 passed
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@acestep/training/trainer.py`:
- Around line 137-141: The unconditional param.data.clone() can spike VRAM;
instead, perform cloning only for tensors that are in inference mode and do it
inside a no-grad context: inside the block guarded by check_peft_available(),
wrap the loop with torch.no_grad() (or torch.inference_mode()) and for each
param in model.parameters() check the tensor’s inference state (e.g.,
torch.is_inference(param.data) or equivalent predicate) before cloning — only
call param.data.clone() for those inference-mode tensors to avoid doubling
memory for all parameters. Ensure you reference check_peft_available(),
model.parameters(), param.data.clone(), and use
torch.no_grad()/torch.inference_mode() in the updated code.
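A minimal sketch of what that guarded clone could look like, assuming check_peft_available() and model are in scope as referenced in the prompt above (the surrounding trainer method is not reproduced here):

if check_peft_available():
    # Only rewrap tensors that were created under torch.inference_mode();
    # ordinary parameters are left untouched to avoid doubling memory.
    with torch.no_grad():
        for param in model.parameters():
            if param.data.is_inference():
                param.data = param.data.clone()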
acestep/training/trainer.py (Outdated)

if check_peft_available():
    # Fix: Force tensors out of inference mode before injection
    for param in model.parameters():
        param.data = param.data.clone()
🧩 Analysis chain
🏁 Script executed:
# First, check the actual code at the specified location
cat -n acestep/training/trainer.py | sed -n '130,155p'

Repository: ace-step/ACE-Step-1.5
Length of output: 1640
🏁 Script executed:
# Check PyTorch version requirements in the project
find . -maxdepth 2 \( -name "requirements*.txt" -o -name "pyproject.toml" -o -name "setup.py" -o -name "setup.cfg" \) -exec cat {} \;

Repository: ace-step/ACE-Step-1.5
Length of output: 7880
🌐 Web query:
torch.is_inference PyTorch API detection inference mode tensors
💡 Result:
To detect inference-mode tensors in PyTorch, use the tensor method:
t.is_inference() → bool: returns whether a specific tensor is an inference tensor. (docs.pytorch.org)
To detect whether your code is currently running under inference mode (regardless of a particular tensor), use:
torch.is_inference_mode_enabled() → bool (docs.pytorch.org)
Key semantics (why a tensor might be “inference”):
- Tensors newly allocated inside torch.inference_mode() are marked as inference tensors, and views of them are also inference tensors. (docs.pytorch.org)
- torch.inference_mode() is like torch.no_grad() but goes further for performance by disabling view tracking and version counter bumps, and is more restrictive (inference tensors shouldn't be used in autograd-recorded computations). (docs.pytorch.org)
Minimal example:

import torch

x = torch.ones(3, requires_grad=True)
print(torch.is_inference_mode_enabled())  # False
with torch.inference_mode():
    y = x * 2
    print(torch.is_inference_mode_enabled())  # True
    print(y.is_inference())  # True
print(torch.is_inference_mode_enabled())  # False
print(y.is_inference())  # True (still an inference tensor)

Citations:
- 1: https://docs.pytorch.org/docs/stable/generated/torch.Tensor.is_inference.html?utm_source=openai
- 2: https://docs.pytorch.org/docs/stable/generated/torch.is_inference_mode_enabled.html?utm_source=openai
- 3: https://docs.pytorch.org/cppdocs/notes/inference_mode.html?utm_source=openai
- 4: https://docs.pytorch.org/docs/stable/generated/torch.autograd.grad_mode.inference_mode.html?utm_source=openai
Avoid unconditional parameter cloning to prevent VRAM spikes/OOM.
Cloning every parameter can briefly double memory for large DiT models and may OOM even when no inference-mode tensors exist. Gate the clone to inference-mode tensors only and do it under no_grad to minimize risk.
🔧 Suggested fix (guarded clone)
  for param in model.parameters():
-     param.data = param.data.clone()
+     if param.is_inference():
+         with torch.no_grad():
+             param.data = param.data.clone()
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@acestep/training/trainer.py`:
- Around line 152-158: The current logic logs a misleading warning when
compilation is skipped for non-CUDA devices and doesn't handle runtime failures
from torch.compile; change the flow to first check for torch.compile
availability and second check device (self.device_type == "cuda"), log distinct
messages for "torch.compile not available" vs "skipping compile on non-CUDA
device", and wrap the call to torch.compile(self.model.decoder, mode="default")
in a try/except that catches exceptions, logs the exception via logger.error
with context, and restores the original self.model.decoder (or leaves it
unchanged) as a safe fallback so runtime compile failures won't break training;
reference torch.compile, self.device_type, self.model.decoder, and logger in
your changes.
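One way the flow described in that prompt could be arranged (a sketch under the assumptions that self.model, self.device_type, and logger exist as referenced above; this is not the committed change):

original_decoder = self.model.decoder
if not hasattr(torch, "compile"):
    logger.warning("torch.compile is not available on this PyTorch version; skipping compilation.")
elif self.device_type != "cuda":
    logger.info("Skipping torch.compile on non-CUDA device.")
else:
    logger.info("Compiling DiT decoder...")
    try:
        self.model.decoder = torch.compile(self.model.decoder, mode="default")
        logger.info("torch.compile successful")
    except Exception as exc:
        logger.error(f"torch.compile failed, keeping eager decoder: {exc}")
        self.model.decoder = original_decoder  # safe fallback so training continues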
# Added Torch Compile Logic
if hasattr(torch, "compile") and self.device_type == "cuda":
    logger.info("Compiling DiT decoder...")
    self.model.decoder = torch.compile(self.model.decoder, mode="default")  # 'default' is more stable for LoRA
    logger.info("torch.compile successful")
else:
    logger.warning("torch.compile is not available on this PyTorch version.")
🧩 Analysis chain
🏁 Script executed:
# Read the specific file and lines mentioned in the review
sed -n '140,165p' acestep/training/trainer.py

Repository: ace-step/ACE-Step-1.5
Length of output: 1198
🌐 Web query:
torch.compile failure modes exceptions runtime errors fallback pattern
💡 Result:
Common torch.compile failure modes (what you’ll see)
- Graph breaks (not always an "error")
  - With default fullgraph=False, Dynamo will graph-break, run the unsupported part in eager Python, then resume tracing (often just causing slowness, not failure). Typical causes: data-dependent control flow, many Python builtins/C functions, printing/logging, etc. [1]
  - If you force fullgraph=True, then any graph break becomes a hard failure (you'll get an error instead of falling back to eager for that region). [2]
- Dynamo "unsupported" / tracing-time exceptions
  - You'll commonly see torch._dynamo.exc.Unsupported: ... when Dynamo can't trace something (examples include .item()-driven data dependence, dynamic slicing with Tensor arguments, unsupported Python constructs, skipped/marked functions, etc.). [1][3]
- Backend compiler failures (Inductor/Triton/etc.)
  - Tracing may succeed, but the backend codegen/compile fails and Dynamo raises torch._dynamo.exc.BackendCompilerFailed wrapping the real underlying error (e.g., toolchain/Triton compilation problems). [4]
- Recompilations / guard failures (perf pathology)
  - Not necessarily a crash, but you can get lots of recompiles when guards fail (often due to shape/stride/dtype changes). This can look like "it works but is extremely slow / keeps recompiling." [1]
“Fail fast” vs “fallback” knobs (practical patterns)
A) Prefer partial compilation (default) to avoid hard failures
opt = torch.compile(model, fullgraph=False)  # default behavior: graph breaks allowed

Graph breaks will run those regions eagerly instead of erroring. [1][2]
B) Explicitly exclude known-problem regions (most robust)
Use @torch.compiler.disable to force eager execution for specific functions (and keep the rest compiled). [1]
import torch

@torch.compiler.disable
def troublesome_bits(x):
    ...

@torch.compile
def f(x):
    x = ...
    x = troublesome_bits(x)  # graph break + eager for this call
    return x

C) Last-resort "don't crash, just run eager"
If you hit compiler crashes and want to keep training/inference running, set:
import torch._dynamo
torch._dynamo.config.suppress_errors = True

This suppresses compiler exceptions and falls back to eager for the failing compile attempt (not considered best practice long-term). [1][4]
D) Force errors on graph breaks (good for debugging / CI)
- fullgraph=True (strongest) errors on any graph break. [2]
- torch._dynamo.error_on_graph_break(...) controls whether graph breaks error, but fullgraph=True takes precedence. [3]
E) Escape hatch for “I know this is safe to trace non-strictly”
Use torch._dynamo.nonstrict_trace when a function graph-breaks but is still traceable under non-strict rules. [5]
How to identify what’s happening
- Use TORCH_LOGS="graph_breaks" to see where and why graph breaks occur; TORCH_LOGS="recompiles" to diagnose recompilation churn. [1]
- Use tlparse for a higher-level compilation report (graph breaks, recompiles, errors). [1]
Sources
[1] PyTorch docs — torch.compile Troubleshooting (graph breaks, suppress_errors, disable, recompiles, logging)
[2] PyTorch docs — Use fullgraph=True to Identify and Eliminate Graph Breaks
[3] PyTorch docs — Toggling error_on_graph_break
[4] PyTorch GitHub issue — example torch._dynamo.exc.BackendCompilerFailed + note about suppress_errors fallback
[5] PyTorch docs — Use torch._dynamo.nonstrict_trace
Fix misleading warning message and handle torch.compile edge cases safely.
The warning message is inaccurate—it claims torch.compile is unavailable when the check actually skips compilation due to non-CUDA devices, not missing PyTorch support. Additionally, while torch.compile with default fullgraph=False gracefully handles most graph breaks by falling back to eager execution, backend or dynamic compilation failures can still occur at runtime without warning.
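A related caveat: torch.compile() itself only wraps the module, and most tracing or backend errors surface at the first forward call rather than at the wrapping call, so a try/except around the wrap mainly guards against immediate errors. A tiny standalone sketch (not taken from this PR) showing where failures would actually appear:

import torch

model = torch.nn.Linear(8, 8)

# Wrapping alone builds no graph; compilation is deferred.
compiled = torch.compile(model, mode="default")

try:
    # Tracing and backend codegen run here, on the first call, which is
    # where Unsupported / BackendCompilerFailed errors would surface.
    out = compiled(torch.randn(2, 8))
except Exception as exc:
    # Fall back to the eager module if compilation fails at call time.
    print(f"torch.compile failed at first call, falling back to eager: {exc}")
    out = model(torch.randn(2, 8))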
🔧 Suggested fix (clearer log + safe fallback)

- if hasattr(torch, "compile") and self.device_type == "cuda":
-     logger.info("Compiling DiT decoder...")
-     self.model.decoder = torch.compile(self.model.decoder, mode="default")  # 'default' is more stable for LoRA
-     logger.info("torch.compile successful")
- else:
-     logger.warning("torch.compile is not available on this PyTorch version.")
+ if hasattr(torch, "compile") and self.device_type == "cuda":
+     logger.info("Compiling DiT decoder...")
+     try:
+         self.model.decoder = torch.compile(self.model.decoder, mode="default")  # 'default' is more stable for LoRA
+     except Exception as exc:
+         logger.warning(f"torch.compile failed; continuing without compilation: {exc}")
+     else:
+         logger.info("torch.compile successful")
+ else:
+     logger.info("Skipping torch.compile (requires CUDA and torch.compile support).")

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
# Added Torch Compile Logic
if hasattr(torch, "compile") and self.device_type == "cuda":
    logger.info("Compiling DiT decoder...")
    try:
        self.model.decoder = torch.compile(self.model.decoder, mode="default")  # 'default' is more stable for LoRA
    except Exception as exc:
        logger.warning(f"torch.compile failed; continuing without compilation: {exc}")
    else:
        logger.info("torch.compile successful")
else:
    logger.info("Skipping torch.compile (requires CUDA and torch.compile support).")
For faster LoRA compilation.
Summary by CodeRabbit
Bug Fixes
- Parameter tensors are force-cloned out of inference mode before LoRA injection, avoiding inference-locked tensor errors during training.
Performance
- The DiT decoder is wrapped with torch.compile (mode="default") on CUDA devices for faster LoRA training.