56 changes: 55 additions & 1 deletion acestep/training/lora_utils.py
@@ -8,6 +8,7 @@
import os
from typing import Optional, List, Dict, Any, Tuple
from loguru import logger
import types

import torch
import torch.nn as nn
@@ -95,8 +96,61 @@ def inject_lora_into_dit(
    if not PEFT_AVAILABLE:
        raise ImportError("PEFT library is required for LoRA training. Install with: pip install peft")

    # Get the decoder (DiT model). Previous failed training runs may leave
    # Fabric/PEFT wrappers attached; unwrap to a clean base module first.
    decoder = model.decoder
    while hasattr(decoder, "_forward_module"):
        decoder = decoder._forward_module
    if hasattr(decoder, "base_model"):
        base_model = decoder.base_model
        if hasattr(base_model, "model"):
            decoder = base_model.model
        else:
            decoder = base_model
    if hasattr(decoder, "model") and isinstance(decoder.model, nn.Module):
        decoder = decoder.model
    model.decoder = decoder

    # PEFT may call enable_input_require_grads() when is_gradient_checkpointing
    # is true. AceStepDiTModel doesn't implement get_input_embeddings, so the
    # default implementation raises NotImplementedError. Guard this path.
    if hasattr(decoder, "enable_input_require_grads"):
        orig_enable_input_require_grads = decoder.enable_input_require_grads

        def _safe_enable_input_require_grads(self):
            try:
                result = orig_enable_input_require_grads()
                try:
                    self._acestep_input_grads_hook_enabled = True
                except Exception:
                    pass
                return result
            except NotImplementedError:
                try:
                    self._acestep_input_grads_hook_enabled = False
                except Exception:
                    pass
                if not getattr(self, "_acestep_input_grads_warning_emitted", False):
                    logger.info(
                        "Skipping enable_input_require_grads for decoder: "
                        "get_input_embeddings is not implemented (expected for DiT)"
                    )
                    try:
                        self._acestep_input_grads_warning_emitted = True
                    except Exception:
                        pass
                return None

        decoder.enable_input_require_grads = types.MethodType(
            _safe_enable_input_require_grads, decoder
        )

    # Avoid PEFT auto-prep path on non-embedding diffusion decoder.
    if hasattr(decoder, "is_gradient_checkpointing"):
        try:
            decoder.is_gradient_checkpointing = False
        except Exception:
            pass

    # Create PEFT LoRA config
    peft_lora_config = LoraConfig(
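
Note: the `types.MethodType` guard above is the crux of this file's change. A minimal standalone sketch of the same pattern, with `Decoder` as a hypothetical stand-in for the real model class:

import types

class Decoder:
    def enable_input_require_grads(self):
        # Stand-in for the default HF implementation, which needs
        # get_input_embeddings() and raises on models without one.
        raise NotImplementedError

decoder = Decoder()
orig = decoder.enable_input_require_grads

def _safe_enable(self):
    try:
        return orig()
    except NotImplementedError:
        self._acestep_input_grads_hook_enabled = False
        return None

# Bind the wrapper to this instance only; the class stays untouched.
decoder.enable_input_require_grads = types.MethodType(_safe_enable, decoder)
assert decoder.enable_input_require_grads() is None  # no longer raises
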
209 changes: 193 additions & 16 deletions acestep/training/trainer.py
@@ -65,10 +65,130 @@ def _select_fabric_precision(device_type: str) -> str:
    if device_type in ("cuda", "xpu"):
        return "bf16-mixed"
    if device_type == "mps":
        # Use AMP on MPS for better throughput. Trainable LoRA parameters are
        # explicitly forced to fp32 before optimizer/Fabric setup.
        return "16-mixed"
    return "32-true"


def _ensure_trainable_params_fp32(module: nn.Module) -> Tuple[int, int]:
    """Force trainable floating-point parameters to fp32."""
    casted = 0
    total = 0
    for p in module.parameters():
        if not p.requires_grad:
            continue
        total += 1
        if p.is_floating_point() and p.dtype != torch.float32:
            with torch.no_grad():
                p.data = p.data.float()
            casted += 1
    return casted, total
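
A quick self-contained check of the casting behaviour, using a plain `nn.Linear` as a stand-in for LoRA adapter weights:

import torch
import torch.nn as nn

layer = nn.Linear(4, 4).to(torch.bfloat16)        # params start in bf16
casted, total = _ensure_trainable_params_fp32(layer)
assert (casted, total) == (2, 2)                  # weight and bias both cast
assert all(p.dtype == torch.float32 for p in layer.parameters())
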


def _count_nonfinite_grads(params: List[torch.nn.Parameter]) -> Tuple[int, int]:
    """Count non-finite gradient tensors among params with gradients."""
    nonfinite = 0
    total_with_grad = 0
    for p in params:
        g = p.grad
        if g is None:
            continue
        total_with_grad += 1
        if not torch.isfinite(g).all():
            nonfinite += 1
    return nonfinite, total_with_grad
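
And a sketch of what the detector flags, with hand-built gradients:

import torch

p = torch.nn.Parameter(torch.ones(2))
p.grad = torch.tensor([float("inf"), 0.0])   # simulate an fp16 overflow
q = torch.nn.Parameter(torch.ones(2))
q.grad = torch.zeros(2)
r = torch.nn.Parameter(torch.ones(2))        # no grad yet: not counted
assert _count_nonfinite_grads([p, q, r]) == (1, 2)
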


def _ensure_optimizer_params_fp32(optimizer: torch.optim.Optimizer) -> Tuple[int, int]:
    """Force optimizer parameter tensors to fp32 when trainable."""
    casted = 0
    total = 0
    for group in optimizer.param_groups:
        for p in group.get("params", []):
            if p is None:
                continue
            total += 1
            if p.is_floating_point() and p.dtype != torch.float32:
                with torch.no_grad():
                    p.data = p.data.float()
                casted += 1
    return casted, total


def _iter_module_wrappers(module: nn.Module) -> List[nn.Module]:
    """Collect wrapper chain modules (Fabric/PEFT/compile/base-model wrappers)."""
    modules: List[nn.Module] = []
    stack = [module]
    visited = set()

    while stack:
        current = stack.pop()
        if not isinstance(current, nn.Module):
            continue
        module_id = id(current)
        if module_id in visited:
            continue
        visited.add(module_id)
        modules.append(current)

        for attr_name in ("_forward_module", "_orig_mod", "base_model", "model", "module"):
            child = getattr(current, attr_name, None)
            if isinstance(child, nn.Module):
                stack.append(child)

    return modules
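
A sketch of the traversal on a toy two-deep wrapper chain (the `Wrapper` class is illustrative, not a real Fabric/PEFT type):

import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.module = inner  # one of the attribute names the walker follows

base = nn.Linear(2, 2)
wrapped = Wrapper(Wrapper(base))
chain = _iter_module_wrappers(wrapped)
assert len(chain) == 3 and base in chain  # both wrappers plus the base
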


def _configure_training_memory_features(decoder: nn.Module) -> Tuple[bool, bool, bool]:
    """
    Enable gradient checkpointing and disable use_cache across wrapped decoder modules.

    Returns:
        Tuple[checkpointing_enabled, cache_disabled, input_grads_enabled]
    """
    checkpointing_enabled = False
    cache_disabled = False
    input_grads_enabled = False

    for mod in _iter_module_wrappers(decoder):
        if hasattr(mod, "gradient_checkpointing_enable"):
            try:
                mod.gradient_checkpointing_enable()
                checkpointing_enabled = True
            except Exception:
                pass
        elif hasattr(mod, "gradient_checkpointing"):
            try:
                mod.gradient_checkpointing = True
                checkpointing_enabled = True
            except Exception:
                pass

        # PEFT + gradient checkpointing can require input embeddings to have
        # gradients enabled; otherwise the loss may be detached (no grad_fn).
        if hasattr(mod, "enable_input_require_grads"):
            try:
                mod.enable_input_require_grads()
                hook_enabled = bool(getattr(mod, "_acestep_input_grads_hook_enabled", False))
                has_require_hook = getattr(mod, "_require_grads_hook", None) is not None
                if hook_enabled or has_require_hook:
                    input_grads_enabled = True
            except Exception:
                pass

        cfg = getattr(mod, "config", None)
        if cfg is not None and hasattr(cfg, "use_cache"):
            try:
                if getattr(cfg, "use_cache", None) is not False:
                    cfg.use_cache = False
                    cache_disabled = True
            except Exception:
                pass

    return checkpointing_enabled, cache_disabled, input_grads_enabled
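
On a bare module that exposes none of these hooks, all three flags come back False, which is what the caller's log line and fallback path key off; a tiny check:

import torch.nn as nn

flags = _configure_training_memory_features(nn.Linear(2, 2))
assert flags == (False, False, False)  # no checkpointing, cache, or input-grad hook
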


def sample_discrete_timestep(bsz, timesteps_tensor):
    """Sample timesteps from discrete turbo shift=3 schedule.

@@ -132,6 +252,10 @@ def __init__(
        self.dtype = _select_compute_dtype(self.device_type)
        self.transfer_non_blocking = self.device_type in ("cuda", "xpu")
        self.timesteps_tensor = torch.tensor(TURBO_SHIFT3_TIMESTEPS, device=self.device, dtype=self.dtype)
        # When gradient checkpointing is enabled via wrapper layers that don't expose
        # enable_input_require_grads(), force at least one forward input to require grad
        # so checkpointed segments keep a valid autograd graph.
        self.force_input_grads_for_checkpointing = False

        # Inject LoRA into the decoder only
        if check_peft_available():
@@ -219,6 +343,8 @@ def training_step(

# Interpolate: x_t = t * x1 + (1 - t) * x0
xt = t_ * x1 + (1.0 - t_) * x0
if self.force_input_grads_for_checkpointing:
xt = xt.requires_grad_(True)

# Forward through decoder (distilled turbo model, no CFG)
decoder_outputs = self.model.decoder(
@@ -327,6 +453,15 @@ def train_from_preprocessed(
            device=self.dit_handler.device,
            dtype=self.dit_handler.dtype,
        )
        ckpt_enabled, cache_disabled, input_grads_enabled = _configure_training_memory_features(self.module.model.decoder)
        # DiT decoder does not expose token embeddings like causal LMs.
        # Force grad-carrying inputs for checkpointed segments to avoid
        # detached losses regardless of wrapper hook availability.
        self.module.force_input_grads_for_checkpointing = ckpt_enabled
        logger.info(
            f"Training memory features: gradient_checkpointing={ckpt_enabled}, "
            f"use_cache_disabled={cache_disabled}, input_grads_enabled={input_grads_enabled}"
        )

        # Create data module
        data_module = PreprocessedDataModule(
@@ -348,6 +483,12 @@
            return

        yield 0, 0.0, f"📂 Loaded {len(data_module.train_dataset)} preprocessed samples"
        if ckpt_enabled:
            yield 0, 0.0, "🧠 Gradient checkpointing enabled for decoder"
        else:
            yield 0, 0.0, "⚠️ Gradient checkpointing not enabled (model wrapper did not expose it)"
        if not input_grads_enabled:
            yield 0, 0.0, "ℹ️ Input-grad hook not available on this DiT; using explicit checkpointing fallback"

        if LIGHTNING_AVAILABLE:
            yield from self._train_with_fabric(data_module, training_state, resume_from)
@@ -397,7 +538,19 @@ def _train_with_fabric(

yield 0, 0.0, f"🚀 Starting training (device: {device_type}, precision: {precision})..."

# Get dataloaders
# Keep decoder weights in a stable dtype before optimizer/Fabric setup.
# MPS stays in fp32 weights for stability; computation still uses fp16
# autocast inside training_step.
if device_type == "mps" or precision.endswith("-mixed"):
self.module.model.decoder = self.module.model.decoder.to(dtype=torch.float32)
else:
self.module.model.decoder = self.module.model.decoder.to(dtype=self.module.dtype)
casted_trainable, total_trainable_tensors = _ensure_trainable_params_fp32(self.module.model.decoder)
logger.info(
f"Trainable tensor dtype fixup: casted {casted_trainable}/{total_trainable_tensors} to fp32"
)

# Get dataloader
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader() if hasattr(data_module, "val_dataloader") else None

@@ -456,11 +609,12 @@ def _train_with_fabric(
            milestones=[warmup_steps],
        )

        # Setup with Fabric - only the decoder (which has LoRA)
        self.module.model.decoder, optimizer = self.fabric.setup(self.module.model.decoder, optimizer)
        casted_opt_params, total_opt_params = _ensure_optimizer_params_fp32(optimizer)
        logger.info(
            f"Optimizer param dtype fixup: casted {casted_opt_params}/{total_opt_params} to fp32"
        )
        train_loader = self.fabric.setup_dataloaders(train_loader)

        # Handle resume from checkpoint (load AFTER Fabric setup)
@@ -552,10 +706,22 @@

# Optimizer step
if accumulation_step >= self.training_config.gradient_accumulation_steps:
nonfinite_grads, grad_tensors = _count_nonfinite_grads(trainable_params)
if nonfinite_grads > 0:
optimizer.zero_grad(set_to_none=True)
yield global_step, float("nan"), (
f"⚠️ Non-finite gradients ({nonfinite_grads}/{grad_tensors}); "
"skipping optimizer step"
)
accumulated_loss = 0.0
accumulation_step = 0
continue

self.fabric.clip_gradients(
self.module.model.decoder,
optimizer,
max_norm=self.training_config.max_grad_norm,
error_if_nonfinite=False,
)

optimizer.step()
@@ -587,15 +753,26 @@
            # Flush remainder to avoid dropping gradients when epoch length is not
            # divisible by gradient_accumulation_steps.
            if accumulation_step > 0:
                nonfinite_grads, grad_tensors = _count_nonfinite_grads(trainable_params)
                if nonfinite_grads > 0:
                    optimizer.zero_grad(set_to_none=True)
                    yield global_step, float("nan"), (
                        f"⚠️ Non-finite gradients ({nonfinite_grads}/{grad_tensors}); "
                        "skipping optimizer remainder step"
                    )
                    accumulated_loss = 0.0
                    accumulation_step = 0
                else:
                    self.fabric.clip_gradients(
                        self.module.model.decoder,
                        optimizer,
                        max_norm=self.training_config.max_grad_norm,
                        error_if_nonfinite=False,
                    )

                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad(set_to_none=True)

                    global_step += 1
                    avg_loss = accumulated_loss / accumulation_step
@@ -612,10 +789,10 @@
                    self.fabric.log("train/lr", scheduler.get_last_lr()[0], step=global_step)
                    yield global_step, avg_loss, f"Epoch {epoch+1}/{self.training_config.max_epochs}, Step {global_step}, Loss: {avg_loss:.4f}"

                    epoch_loss += avg_loss
                    num_updates += 1
                    accumulated_loss = 0.0
                    accumulation_step = 0

            # End of epoch
            epoch_time = time.time() - epoch_start_time
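
For reviewers, the accumulation-with-remainder pattern this last hunk restructures, reduced to a standalone sketch (model, optimizer, and batch names are stand-ins):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
accum_steps, in_accum = 4, 0

for batch in [torch.randn(2, 4) for _ in range(6)]:  # 6 % 4 != 0 -> remainder
    loss = model(batch).pow(2).mean() / accum_steps
    loss.backward()
    in_accum += 1
    if in_accum == accum_steps:
        opt.step()
        opt.zero_grad(set_to_none=True)
        in_accum = 0

if in_accum > 0:  # flush leftover gradients instead of silently dropping them
    opt.step()
    opt.zero_grad(set_to_none=True)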