
Conversation

fidel1234xdd commented Feb 10, 2026

For faster LoRA compilation

Summary by CodeRabbit

  • Bug Fixes

    • Improved parameter stability during training by adding safeguards to ensure parameters are in a usable state before adaptation.
  • Performance

    • Enabled optional runtime compilation for CUDA GPUs to accelerate decoding paths when supported.
    • Improved fallback behavior on systems without compilation support (no impact on stability).

coderabbitai bot commented Feb 10, 2026

📝 Walkthrough

PreprocessedLoRAModule now force-clones parameter tensors before LoRA injection to avoid inference-locked states, and the DiT decoder is conditionally wrapped with torch.compile(mode="default") when torch.compile exists and the device is CUDA; previously a warning was logged instead.

Changes

Cohort: LoRA Injection Safeguard & Torch Compile Optimization
File(s): acestep/training/trainer.py
Summary: Force-clone parameter tensors out of inference mode prior to LoRA injection; after injection, wrap the DiT decoder with torch.compile(mode="default") on CUDA when available, replacing the prior warning-only behavior.
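
For orientation, here is a minimal self-contained sketch of the two safeguards summarized above, using a toy nn.Linear as a hypothetical stand-in for the DiT decoder (none of these names come from trainer.py):

import torch
import torch.nn as nn

# Toy stand-in for the DiT decoder; illustrative only, not the project's model class.
decoder = nn.Linear(16, 16)

# Simulate a weight that ended up as an inference tensor (e.g. created under inference_mode()).
with torch.inference_mode():
    frozen_weight = torch.randn(16, 16)

# Safeguard 1: clone inference tensors back into normal tensors before adapter injection.
if frozen_weight.is_inference():
    with torch.no_grad():
        frozen_weight = frozen_weight.clone()  # cloning outside inference_mode yields a normal tensor

# Safeguard 2: compile only when torch.compile exists and a CUDA device is present.
if hasattr(torch, "compile") and torch.cuda.is_available():
    decoder = torch.compile(decoder, mode="default")

print(frozen_weight.is_inference())        # False after the clone
print(decoder(torch.randn(2, 16)).shape)   # torch.Size([2, 16])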

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Poem

🐰 I nudged the tensors, soft and light,
Cloned their hum before LoRA's bite.
On CUDA nights I call compile's tune,
Decoders hum beneath the moon.
Hooray for safer, speedier flight!

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title accurately reflects the main change: adding torch.compile logic to the LoRA trainer for performance improvement.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@acestep/training/trainer.py`:
- Around line 137-141: The unconditional param.data.clone() can spike VRAM;
instead, perform cloning only for tensors that are in inference mode and do it
inside a no-grad context: inside the block guarded by check_peft_available(),
wrap the loop with torch.no_grad() (or torch.inference_mode()) and for each
param in model.parameters() check the tensor’s inference state (e.g.,
torch.is_inference(param.data) or equivalent predicate) before cloning — only
call param.data.clone() for those inference-mode tensors to avoid doubling
memory for all parameters. Ensure you reference check_peft_available(),
model.parameters(), param.data.clone(), and use
torch.no_grad()/torch.inference_mode() in the updated code.

Comment on lines 137 to 141
        if check_peft_available():
            # Fix: Force tensors out of inference mode before injection
            for param in model.parameters():
                param.data = param.data.clone()


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, check the actual code at the specified location
cat -n acestep/training/trainer.py | sed -n '130,155p'

Repository: ace-step/ACE-Step-1.5

Length of output: 1640


🏁 Script executed:

# Check PyTorch version requirements in the project
find . -maxdepth 2 \( -name "requirements*.txt" -o -name "pyproject.toml" -o -name "setup.py" -o -name "setup.cfg" \) -exec cat {} \;

Repository: ace-step/ACE-Step-1.5

Length of output: 7880


🌐 Web query:

torch.is_inference PyTorch API detection inference mode tensors

💡 Result:

To detect inference-mode tensors in PyTorch, use the tensor method:

  • t.is_inference() → bool: returns whether a specific tensor is an inference tensor. (docs.pytorch.org)

To detect whether your code is currently running under inference mode (regardless of a particular tensor), use torch.is_inference_mode_enabled(), as in the example below.

Key semantics (why a tensor might be “inference”):

  • Tensors newly allocated inside torch.inference_mode() are marked as inference tensors, and views of them are also inference tensors. (docs.pytorch.org)
  • torch.inference_mode() is like torch.no_grad() but goes further for performance by disabling view tracking and version counter bumps, and is more restrictive (inference tensors shouldn’t be used in autograd-recorded computations). (docs.pytorch.org)

Minimal example:

import torch

x = torch.ones(3, requires_grad=True)
print(torch.is_inference_mode_enabled())  # False

with torch.inference_mode():
    y = x * 2
    print(torch.is_inference_mode_enabled())  # True
    print(y.is_inference())                   # True

print(torch.is_inference_mode_enabled())  # False
print(y.is_inference())                   # True (still an inference tensor)


Avoid unconditional parameter cloning to prevent VRAM spikes/OOM.

Cloning every parameter can briefly double memory for large DiT models and may OOM even when no inference-mode tensors exist. Gate the clone to inference-mode tensors only and do it under no_grad to minimize risk.

🔧 Suggested fix (guarded clone)
            for param in model.parameters():
-               param.data = param.data.clone() 
+               if param.is_inference():
+                   with torch.no_grad():
+                       param.data = param.data.clone()
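
As a standalone illustration of the guarded clone suggested above, here is a minimal runnable sketch; the helper name and the toy module are hypothetical and not part of trainer.py:

import torch
import torch.nn as nn

def unlock_inference_params(model: nn.Module) -> int:
    """Clone only parameters whose data are inference tensors; return how many were cloned."""
    cloned = 0
    with torch.no_grad():
        for param in model.parameters():
            if param.data.is_inference():        # same predicate as torch.is_inference(param.data)
                param.data = param.data.clone()  # clone only the offending tensors
                cloned += 1
    return cloned

# Toy usage: a freshly built module has no inference tensors, so nothing is cloned.
print(unlock_inference_params(nn.Linear(4, 4)))  # 0

Compared with cloning every parameter unconditionally, this touches only the tensors that actually need it, so the transient memory overhead is bounded by the largest offending parameter rather than the whole model.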

coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@acestep/training/trainer.py`:
- Around line 152-158: The current logic logs a misleading warning when
compilation is skipped for non-CUDA devices and doesn't handle runtime failures
from torch.compile; change the flow to first check for torch.compile
availability and second check device (self.device_type == "cuda"), log distinct
messages for "torch.compile not available" vs "skipping compile on non-CUDA
device", and wrap the call to torch.compile(self.model.decoder, mode="default")
in a try/except that catches exceptions, logs the exception via logger.error
with context, and restores the original self.model.decoder (or leaves it
unchanged) as a safe fallback so runtime compile failures won't break training;
reference torch.compile, self.device_type, self.model.decoder, and logger in
your changes.

Comment on lines +152 to +158
        # Added Torch Compile Logic
        if hasattr(torch, "compile") and self.device_type == "cuda":
            logger.info("Compiling DiT decoder...")
            self.model.decoder = torch.compile(self.model.decoder, mode="default")  # 'default' is more stable for LoRA
            logger.info("torch.compile successful")
        else:
            logger.warning("torch.compile is not available on this PyTorch version.")

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Read the specific file and lines mentioned in the review
sed -n '140,165p' acestep/training/trainer.py

Repository: ace-step/ACE-Step-1.5

Length of output: 1198


🌐 Web query:

torch.compile failure modes exceptions runtime errors fallback pattern

💡 Result:

Common torch.compile failure modes (what you’ll see)

  1. Graph breaks (not always an “error”)
     • With default fullgraph=False, Dynamo will graph-break, run the unsupported part in eager Python, then resume tracing (often just causing slowness, not failure). Typical causes: data-dependent control flow, many Python builtins/C functions, printing/logging, etc. [1]
     • If you force fullgraph=True, then any graph break becomes a hard failure (you’ll get an error instead of falling back to eager for that region). [2]
  2. Dynamo “unsupported” / tracing-time exceptions
     • You’ll commonly see torch._dynamo.exc.Unsupported: ... when Dynamo can’t trace something (examples include .item()-driven data dependence, dynamic slicing with Tensor arguments, unsupported Python constructs, skipped/marked functions, etc.). [1][3]
  3. Backend compiler failures (Inductor/Triton/etc.)
     • Tracing may succeed, but the backend codegen/compile fails and Dynamo raises torch._dynamo.exc.BackendCompilerFailed wrapping the real underlying error (e.g., toolchain/Triton compilation problems). [4]
  4. Recompilations / guard failures (perf pathology)
     • Not necessarily a crash, but you can get lots of recompiles when guards fail (often due to shape/stride/dtype changes). This can look like “it works but is extremely slow / keeps recompiling.” [1]

“Fail fast” vs “fallback” knobs (practical patterns)

A) Prefer partial compilation (default) to avoid hard failures

opt = torch.compile(model, fullgraph=False)  # default behavior: graph breaks allowed

Graph breaks will run those regions eagerly instead of erroring. [1][2]

B) Explicitly exclude known-problem regions (most robust)

Use @torch.compiler.disable to force eager execution for specific functions (and keep the rest compiled). [1]

import torch

@torch.compiler.disable
def troublesome_bits(x):
    ...

@torch.compile
def f(x):
    x = ...
    x = troublesome_bits(x)  # graph break + eager for this call
    return x

C) Last-resort “don’t crash, just run eager”

If you hit compiler crashes and want to keep training/inference running, set:

import torch._dynamo
torch._dynamo.config.suppress_errors = True

This suppresses compiler exceptions and falls back to eager for the failing compile attempt (not considered best practice long-term). [1][4]

D) Force errors on graph breaks (good for debugging / CI)

  • fullgraph=True (strongest) errors on any graph break. [2]
  • torch._dynamo.error_on_graph_break(...) controls whether graph breaks error, but fullgraph=True takes precedence. [3]

E) Escape hatch for “I know this is safe to trace non-strictly”

Use torch._dynamo.nonstrict_trace when a function graph-breaks but is still traceable under non-strict rules. [5]


How to identify what’s happening

  • Use TORCH_LOGS="graph_breaks" to see where and why graph breaks occur; TORCH_LOGS="recompiles" to diagnose recompilation churn. [1]
  • Use tlparse for a higher-level compilation report (graph breaks, recompiles, errors). [1]

Sources

[1] PyTorch docs — torch.compile Troubleshooting (graph breaks, suppress_errors, disable, recompiles, logging)
[2] PyTorch docs — Use fullgraph=True to Identify and Eliminate Graph Breaks
[3] PyTorch docs — Toggling error_on_graph_break
[4] PyTorch GitHub issue — example torch._dynamo.exc.BackendCompilerFailed + note about suppress_errors fallback
[5] PyTorch docs — Use torch._dynamo.nonstrict_trace
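
To make the fallback patterns above concrete, here is a small self-contained helper that attempts compilation and keeps the original module on any failure; the helper name is illustrative, not an existing torch API:

import torch
import torch.nn as nn

def compile_or_fallback(module: nn.Module, **compile_kwargs) -> nn.Module:
    """Try torch.compile; on any failure, return the original module so the run is never blocked."""
    if not hasattr(torch, "compile"):
        return module  # very old PyTorch: nothing to compile with
    try:
        return torch.compile(module, **compile_kwargs)
    except Exception:
        # Wrapping rarely fails here; backend errors usually surface on the first forward pass,
        # where torch._dynamo.config.suppress_errors = True is the documented last-resort eager fallback.
        return module

decoder = compile_or_fallback(nn.Linear(8, 8), mode="default")
print(decoder(torch.randn(2, 8)).shape)  # torch.Size([2, 8])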


Fix misleading warning message and handle torch.compile edge cases safely.

The warning message is inaccurate—it claims torch.compile is unavailable when the check actually skips compilation due to non-CUDA devices, not missing PyTorch support. Additionally, while torch.compile with default fullgraph=False gracefully handles most graph breaks by falling back to eager execution, backend or dynamic compilation failures can still occur at runtime without warning.

🔧 Suggested fix (clearer log + safe fallback)
-        if hasattr(torch, "compile") and self.device_type == "cuda":
-            logger.info("Compiling DiT decoder...")
-            self.model.decoder = torch.compile(self.model.decoder, mode="default") # 'default' is more stable for LoRA
-            logger.info("torch.compile successful")
-        else:
-            logger.warning("torch.compile is not available on this PyTorch version.")
+        if hasattr(torch, "compile") and self.device_type == "cuda":
+            logger.info("Compiling DiT decoder...")
+            try:
+                self.model.decoder = torch.compile(self.model.decoder, mode="default")  # 'default' is more stable for LoRA
+            except Exception as exc:
+                logger.warning(f"torch.compile failed; continuing without compilation: {exc}")
+            else:
+                logger.info("torch.compile successful")
+        else:
+            logger.info("Skipping torch.compile (requires CUDA and torch.compile support).")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
         # Added Torch Compile Logic
         if hasattr(torch, "compile") and self.device_type == "cuda":
             logger.info("Compiling DiT decoder...")
-            self.model.decoder = torch.compile(self.model.decoder, mode="default")  # 'default' is more stable for LoRA
-            logger.info("torch.compile successful")
+            try:
+                self.model.decoder = torch.compile(self.model.decoder, mode="default")  # 'default' is more stable for LoRA
+            except Exception as exc:
+                logger.warning(f"torch.compile failed; continuing without compilation: {exc}")
+            else:
+                logger.info("torch.compile successful")
         else:
-            logger.warning("torch.compile is not available on this PyTorch version.")
+            logger.info("Skipping torch.compile (requires CUDA and torch.compile support).")
