15 changes: 15 additions & 0 deletions acestep/training/trainer.py
@@ -135,12 +135,27 @@ def __init__(

# Inject LoRA into the decoder only
if check_peft_available():
# Fix: Force tensors out of inference mode before injection
for param in model.parameters():
param.data = param.data.clone()
if param.is_inference():
with torch.no_grad():
param.data = param.data.clone()

self.model, self.lora_info = inject_lora_into_dit(model, lora_config)
logger.info(f"LoRA injected: {self.lora_info['trainable_params']:,} trainable params")
else:
self.model = model
self.lora_info = {}
logger.warning("PEFT not available, training without LoRA adapters")

# Added Torch Compile Logic
if hasattr(torch, "compile") and self.device_type == "cuda":
logger.info("Compiling DiT decoder...")
self.model.decoder = torch.compile(self.model.decoder, mode="default") # 'default' is more stable for LoRA
logger.info("torch.compile successful")
else:
logger.warning("torch.compile is not available on this PyTorch version.")
Comment on lines +152 to +158

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Read the specific file and lines mentioned in the review
sed -n '140,165p' acestep/training/trainer.py

Repository: ace-step/ACE-Step-1.5

Length of output: 1198


🌐 Web query:

torch.compile failure modes exceptions runtime errors fallback pattern

💡 Result:

Common torch.compile failure modes (what you’ll see)

  1. Graph breaks (not always an “error”)
  • With the default fullgraph=False, Dynamo will graph-break, run the unsupported part in eager Python, then resume tracing (often just causing slowness, not failure). Typical causes: data-dependent control flow, many Python builtins/C functions, printing/logging, etc. [1]
  • If you force fullgraph=True, then any graph break becomes a hard failure (you’ll get an error instead of falling back to eager for that region). [2]
  2. Dynamo “unsupported” / tracing-time exceptions
  • You’ll commonly see torch._dynamo.exc.Unsupported: ... when Dynamo can’t trace something (examples include .item()-driven data dependence, dynamic slicing with Tensor arguments, unsupported Python constructs, skipped/marked functions, etc.); see the sketch after this list. [1][3]
  3. Backend compiler failures (Inductor/Triton/etc.)
  • Tracing may succeed, but the backend codegen/compile fails and Dynamo raises torch._dynamo.exc.BackendCompilerFailed wrapping the real underlying error (e.g., toolchain/Triton compilation problems). [4]
  4. Recompilations / guard failures (perf pathology)
  • Not necessarily a crash, but you can get lots of recompiles when guards fail (often due to shape/stride/dtype changes). This can look like “it works but is extremely slow / keeps recompiling.” [1]
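
A minimal sketch of modes 1 and 2 above (not from the PR), assuming a recent PyTorch 2.x with the default Dynamo config: the toy function f graph-breaks on .item()-driven control flow, still runs under the default fullgraph=False, and surfaces a hard error under fullgraph=True.

import torch

def f(x):
    # .item() introduces data-dependent control flow, a classic graph-break trigger
    if x.sum().item() > 0:
        return x + 1
    return x - 1

x = torch.randn(4)

# Default (fullgraph=False): the graph break falls back to eager and the call still succeeds
torch.compile(f)(x)

# fullgraph=True: the same graph break is reported as a hard error instead
try:
    torch.compile(f, fullgraph=True)(x)
except Exception as exc:  # typically a torch._dynamo.exc.Unsupported-style error
    print(f"fullgraph=True surfaced the graph break: {type(exc).__name__}")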

“Fail fast” vs “fallback” knobs (practical patterns)

A) Prefer partial compilation (default) to avoid hard failures

opt = torch.compile(model, fullgraph=False)  # default behavior: graph breaks allowed

Graph breaks will run those regions eagerly instead of erroring. [1][2]

B) Explicitly exclude known-problem regions (most robust)

Use @torch.compiler.disable to force eager execution for specific functions (and keep the rest compiled). [1]

import torch

@torch.compiler.disable
def troublesome_bits(x):
    ...

@torch.compile
def f(x):
    x = ...
    x = troublesome_bits(x)  # graph break + eager for this call
    return x

C) Last-resort “don’t crash, just run eager”

If you hit compiler crashes and want to keep training/inference running, set:

import torch._dynamo
torch._dynamo.config.suppress_errors = True

This suppresses compiler exceptions and falls back to eager for the failing compile attempt (not considered best practice long-term). [1][4]

D) Force errors on graph breaks (good for debugging / CI)

  • fullgraph=True (strongest) errors on any graph break. [2]
  • torch._dynamo.error_on_graph_break(...) controls whether graph breaks error, but fullgraph=True takes precedence. [3]

E) Escape hatch for “I know this is safe to trace non-strictly”

Use torch._dynamo.nonstrict_trace when a function graph-breaks but is still traceable under non-strict rules. [5]


How to identify what’s happening

  • Use TORCH_LOGS="graph_breaks" to see where and why graph breaks occur; TORCH_LOGS="recompiles" to diagnose recompilation churn (a minimal invocation example follows this list). [1]
  • Use tlparse for a higher-level compilation report (graph breaks, recompiles, errors). [1]
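
A quick illustration of those logging knobs (an assumption-level sketch, not project tooling): train.py is a hypothetical entry point, and torch._logging.set_logs is the documented programmatic counterpart of the TORCH_LOGS environment variable.

# Shell: enable graph-break and recompile logging for a single run
#   TORCH_LOGS="graph_breaks,recompiles" python train.py

# Python: the programmatic equivalent, set before the first compiled call
import torch._logging
torch._logging.set_logs(graph_breaks=True, recompiles=True)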

Sources

[1] PyTorch docs — torch.compile Troubleshooting (graph breaks, suppress_errors, disable, recompiles, logging)
[2] PyTorch docs — Use fullgraph=True to Identify and Eliminate Graph Breaks
[3] PyTorch docs — Toggling error_on_graph_break
[4] PyTorch GitHub issue — example torch._dynamo.exc.BackendCompilerFailed + note about suppress_errors fallback
[5] PyTorch docs — Use torch._dynamo.nonstrict_trace


Fix misleading warning message and handle torch.compile edge cases safely.

The warning message is inaccurate—it claims torch.compile is unavailable when the check actually skips compilation due to non-CUDA devices, not missing PyTorch support. Additionally, while torch.compile with default fullgraph=False gracefully handles most graph breaks by falling back to eager execution, backend or dynamic compilation failures can still occur at runtime without warning.

🔧 Suggested fix (clearer log + safe fallback)
-        if hasattr(torch, "compile") and self.device_type == "cuda":
-            logger.info("Compiling DiT decoder...")
-            self.model.decoder = torch.compile(self.model.decoder, mode="default") # 'default' is more stable for LoRA
-            logger.info("torch.compile successful")
-        else:
-            logger.warning("torch.compile is not available on this PyTorch version.")
+        if hasattr(torch, "compile") and self.device_type == "cuda":
+            logger.info("Compiling DiT decoder...")
+            try:
+                self.model.decoder = torch.compile(self.model.decoder, mode="default")  # 'default' is more stable for LoRA
+            except Exception as exc:
+                logger.warning(f"torch.compile failed; continuing without compilation: {exc}")
+            else:
+                logger.info("torch.compile successful")
+        else:
+            logger.info("Skipping torch.compile (requires CUDA and torch.compile support).")
🤖 Prompt for AI Agents
In `acestep/training/trainer.py` around lines 152-158, the current logic logs a
misleading warning when compilation is skipped on non-CUDA devices and doesn't
handle runtime failures from torch.compile. Change the flow to first check for
torch.compile availability and then check the device (self.device_type == "cuda"),
logging distinct messages for "torch.compile not available" vs "skipping compile on
non-CUDA device". Wrap the call to torch.compile(self.model.decoder,
mode="default") in a try/except that catches exceptions, logs the exception via
logger.error with context, and restores the original self.model.decoder (or
leaves it unchanged) as a safe fallback so runtime compile failures won't break
training; reference torch.compile, self.device_type, self.model.decoder, and
logger in your changes.
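
A minimal sketch of the flow that prompt describes, not a committed patch: it assumes the same names used above (torch, self.device_type, self.model.decoder, logger) and keeps the eager decoder as the fallback when compilation raises.

# Added Torch Compile Logic (sketch of the refactor described above)
if not hasattr(torch, "compile"):
    logger.warning("torch.compile is not available in this PyTorch build; skipping compilation.")
elif self.device_type != "cuda":
    logger.info("Skipping torch.compile on non-CUDA device.")
else:
    logger.info("Compiling DiT decoder...")
    original_decoder = self.model.decoder
    try:
        self.model.decoder = torch.compile(original_decoder, mode="default")  # 'default' is more stable for LoRA
    except Exception as exc:
        logger.error(f"torch.compile failed; falling back to the eager decoder: {exc}")
        self.model.decoder = original_decoder
    else:
        logger.info("torch.compile successful")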


# Model config for flow matching
self.config = model.config