New Feature: LoKr Training #425
base: main
Conversation
…2 trainable params, and non-finite grad guards
…addition to LoRA.
📝 Walkthrough
Adds LyCORIS-based LoKR adapter support across UI, event wiring, training backend, configs, utilities, handler adapter management, and dependency updates to enable LoKR training, checkpointing, saving, loading, and export alongside existing LoRA flows.
Sequence Diagram
```mermaid
sequenceDiagram
participant UI as UI (Gradio)
participant Events as EventHandler
participant Trainer as LoKRTrainer
participant Model as DiT Model
participant LyCORIS as LyCORIS Net
participant Disk as Disk/Storage
UI->>Events: start_lokr_training(tensor_dir, lokr_config, params)
Events->>Events: validate inputs, update training_state
Events->>Trainer: init(dit_handler, LoKRConfig, TrainingConfig)
Trainer->>Model: unwrap decoder wrappers
Trainer->>LyCORIS: inject_lokr_into_dit(model, lokr_config)
LyCORIS->>Model: attach LoKR adapters, freeze base params
Trainer->>Trainer: setup data module & memory features
loop epochs
Trainer->>Disk: load preprocessed tensors
loop batches
Trainer->>Model: forward (with LoKR)
Model->>LyCORIS: apply LoKR adaptations
Trainer->>Trainer: backward, accumulate, clip
Trainer->>Events: yield progress, loss, logs
end
Trainer->>Disk: save_lokr_training_checkpoint()
end
Trainer->>Disk: save_lokr_weights(output_dir)
Trainer->>Events: emit completion status
    Events->>UI: update progress/log/export status
```
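To make the flow above concrete, here is a rough, hypothetical driver sketch. The names (`LoKRConfig`, `TrainingConfig`, `LoKRTrainer`, `train_from_preprocessed`) follow the wiring described in this PR, but the import path, constructor arguments, and values shown are assumptions rather than the actual signatures.

```python
# Hypothetical usage sketch only; the real signatures in this PR may differ.
from acestep.training import LoKRConfig, TrainingConfig, LoKRTrainer  # assumed export path

# `dit_handler` is assumed to be an already-initialized DiT handler
# (acestep.handler) switched to the non-quantized training preset.
lokr_config = LoKRConfig(linear_dim=64, linear_alpha=32.0, factor=-1)
training_config = TrainingConfig(
    learning_rate=3e-3,
    max_epochs=50,
    output_dir="./lokr_output",
)

trainer = LoKRTrainer(dit_handler, lokr_config, training_config)

# The trainer yields (step, loss, message) tuples so the UI can stream progress.
for step, loss, message in trainer.train_from_preprocessed("./preprocessed_tensors", training_state=None):
    print(f"step={step} loss={loss:.4f} {message}")
```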
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 3 passed
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
acestep/gradio_ui/events/__init__.py (1)
1196-1336: ⚠️ Potential issue | 🟡 Minor
Remove redundant `logger` imports inside both wrapper functions.
The module already imports `logger` at the top level (line 7); the inner imports in both `training_wrapper` (line 1204) and `lokr_training_wrapper` (line 1262) redefine it unnecessarily and trigger F811 violations.
🛠️ Suggested fixes
```diff
 # Start training from preprocessed tensors
 def training_wrapper(tensor_dir, r, a, d, lr, ep, bs, ga, se, sh, sd, od, rc, ts):
-    from loguru import logger
     if not isinstance(ts, dict):
```
```diff
 # Start LoKR training from preprocessed tensors
 def lokr_training_wrapper(
     tensor_dir,
     ...
 ):
-    from loguru import logger
     if not isinstance(ts, dict):
```
🤖 Fix all issues with AI agents
In `@acestep/training/trainer.py`:
- Around line 1321-1406: The _train_basic function lacks checks for non-finite
gradients before stepping; add the same guard used in _train_with_fabric: after
torch.nn.utils.clip_grad_norm_(trainable_params,
self.training_config.max_grad_norm) inspect gradients for non-finite values
(e.g., any torch.isnan or torch.isinf across p.grad for p in trainable_params)
and if non-finite, call optimizer.zero_grad(set_to_none=True), skip
optimizer.step() and scheduler.step(), emit/ yield a log message about skipping
the step, and avoid updating global_step/accumulated counters; apply this change
in both the inner accumulation flush (where accumulation_step >=
gradient_accumulation_steps) and the final remainder block (if accumulation_step
> 0), referencing _train_basic, trainable_params, optimizer, scheduler,
accumulated_loss, accumulation_step, and global_step.
🧹 Nitpick comments (12)
requirements.txt (1)
34-39: Pin `lycoris-lora` to a tested version in both requirements.txt and pyproject.toml.
Unpinned training dependencies can drift and break reproducibility. Use `lycoris-lora>=3.4.0` (current stable) or a tighter constraint once you confirm a working version, keeping both configuration files in sync.
acestep/training/__init__.py (2)
1-6: Update module docstring to reflect LoKR support.
The module docstring only mentions LoRA training, but this module now also exposes LoKR training functionality. Consider updating for accuracy.
📝 Suggested docstring update
```diff
 """
 ACE-Step Training Module

-This module provides LoRA training functionality for ACE-Step models,
-including dataset building, audio labeling, and training utilities.
+This module provides LoRA and LoKr adapter training functionality for ACE-Step
+models, including dataset building, audio labeling, and training utilities.
 """
```
54-63: LoKR utilities grouped under LoRA comment section.
The LoKR utility exports (lines 60-63) are placed under the `# LoRA Utils` comment. Consider adding a separate comment section for clarity.
📝 Suggested organization
```diff
     # LoRA Utils
     "inject_lora_into_dit",
     "save_lora_weights",
     "load_lora_weights",
     "merge_lora_weights",
     "check_peft_available",
+    # LoKR Utils
     "inject_lokr_into_dit",
     "save_lokr_weights",
     "load_lokr_weights",
     "check_lycoris_available",
```
acestep/gradio_ui/interfaces/training.py (2)
653-660: LoKr tab missing resume checkpoint support.
The LoRA tab includes a resume checkpoint field (lines 465-469), but LoKr training has no equivalent. If this is intentional for the initial release, consider adding a TODO or noting it as a future enhancement.
📝 Suggested addition for feature parity
```diff
         with gr.Row():
             lokr_output_dir = gr.Textbox(
                 label="Output Directory",
                 value="./lokr_output",
                 placeholder="./lokr_output",
                 info="Where LoKr checkpoints and final weights will be written.",
             )
+
+        with gr.Row():
+            lokr_resume_checkpoint_dir = gr.Textbox(
+                label="Resume Checkpoint (optional)",
+                placeholder="./lokr_output/checkpoints/epoch_200",
+                info="Directory of a saved LoKr checkpoint to resume from",
+            )
```
661-698: LoKr tab missing export functionality.
The LoRA tab has an export section (lines 510-523) allowing users to export trained weights. The LoKr tab lacks this feature. If this is intentional for the initial release, consider documenting this limitation or adding it for parity.
acestep/gradio_ui/events/training_handlers.py (4)
801-802: Unused `progress` parameter.
The `progress` parameter is declared but never used in `start_lokr_training`. Either remove it or wire it up for progress callbacks similar to other functions.
📝 Option 1: Remove unused parameter
```diff
     lokr_output_dir: str,
     training_state: Dict,
-    progress=None,
 ):
```
829-833: Remove unused noqa directive.
The `# noqa: F401` comment is unnecessary since `Fabric` is used implicitly to verify the import succeeded.
📝 Suggested fix
```diff
     try:
-        from lightning.fabric import Fabric  # noqa: F401
+        from lightning.fabric import Fabric
     except ImportError as e:
```
817-827: LoKr missing low VRAM tier warnings.
The LoRA training path (lines 556-572) includes warnings for CPU-only and low VRAM tier scenarios. The LoKr path lacks these warnings, which could lead to users being unaware of suboptimal training conditions.
📝 Add VRAM tier warnings for LoKr
```diff
     if getattr(dit_handler, "quantization", None) is not None:
+        gpu_config = get_global_gpu_config()
+        if gpu_config.gpu_memory_gb <= 0:
+            yield (
+                "WARNING: CPU-only training detected. Using best-effort training path "
+                "(non-quantized DiT). Performance will be sub-optimal.",
+                "",
+                None,
+                training_state,
+            )
+        elif gpu_config.tier in {"tier1", "tier2", "tier3", "tier4"}:
+            yield (
+                f"WARNING: Low VRAM tier detected ({gpu_config.gpu_memory_gb:.1f} GB, {gpu_config.tier}). "
+                "Using best-effort training path (non-quantized DiT). Performance may be sub-optimal.",
+                "",
+                None,
+                training_state,
+            )
+
         yield "Switching model to training preset (disable quantization)...", "", None, training_state
```
782-981: Significant code duplication between LoRA and LoKr handlers.
The `start_lokr_training` function (~200 lines) is largely a copy of `start_training` with minor differences (config class, trainer class, messages). Consider extracting shared logic into a helper function to improve maintainability. This is acceptable for the initial implementation but should be considered for future refactoring.
acestep/training/trainer.py (3)
168-214: Silent exception swallowing in `_configure_training_memory_features`.
Multiple `try-except-pass` blocks silently swallow exceptions (lines 184-185, 190-191, 202-203, 211-212). While this is intentional for graceful degradation when features aren't available, consider logging at debug level to aid troubleshooting.
📝 Suggested logging improvement
```diff
             if hasattr(mod, "gradient_checkpointing_enable"):
                 try:
                     mod.gradient_checkpointing_enable()
                     checkpointing_enabled = True
                 except Exception:
-                    pass
+                    logger.debug(f"gradient_checkpointing_enable failed on {type(mod).__name__}")
```
Apply a similar pattern to the other except blocks.
914-921: Unused `dtype` parameter in `PreprocessedLoKRModule.__init__`.
The `dtype` parameter is declared but unused since compute dtype is derived from device type via `_select_compute_dtype`. Either remove the parameter or use it to allow explicit dtype override.
📝 Option 1: Remove unused parameter
```diff
     def __init__(
         self,
         model: nn.Module,
         lokr_config: LoKRConfig,
         training_config: TrainingConfig,
         device: torch.device,
-        dtype: torch.dtype,
     ):
```
Note: This would require updating the caller in `LoKRTrainer.train_from_preprocessed` (lines 1052-1058).
911-994: Significant code duplication between LoRA and LoKR implementations.
`PreprocessedLoKRModule` is nearly identical to `PreprocessedLoRAModule`, and `LoKRTrainer` is nearly identical to `LoRATrainer`. Consider extracting a base class or shared helper functions to reduce duplication and ease future maintenance. This is acceptable for the initial implementation but should be considered for future refactoring to prevent the classes from diverging in subtle ways.
Also applies to: 997-1435
```python
    def _train_basic(
        self,
        data_module: PreprocessedDataModule,
        training_state: Optional[Dict],
    ) -> Generator[Tuple[int, float, str], None, None]:
        yield 0, 0.0, "🚀 Starting basic training loop..."
        os.makedirs(self.training_config.output_dir, exist_ok=True)

        train_loader = data_module.train_dataloader()
        trainable_params = [p for p in self.module.model.parameters() if p.requires_grad]
        if not trainable_params:
            yield 0, 0.0, "❌ No trainable parameters found!"
            return

        optimizer = AdamW(
            trainable_params,
            lr=self.training_config.learning_rate,
            weight_decay=self.training_config.weight_decay,
        )
        steps_per_epoch = max(1, math.ceil(len(train_loader) / self.training_config.gradient_accumulation_steps))
        total_steps = steps_per_epoch * self.training_config.max_epochs
        warmup_steps = min(self.training_config.warmup_steps, max(1, total_steps // 10))

        warmup_scheduler = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps)
        main_scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=max(1, total_steps - warmup_steps),
            T_mult=1,
            eta_min=self.training_config.learning_rate * 0.01,
        )
        scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, main_scheduler], milestones=[warmup_steps])

        global_step = 0
        accumulation_step = 0
        accumulated_loss = 0.0
        optimizer.zero_grad(set_to_none=True)
        self.module.model.decoder.train()

        for epoch in range(self.training_config.max_epochs):
            epoch_loss = 0.0
            num_updates = 0
            epoch_start_time = time.time()

            for batch in train_loader:
                if training_state and training_state.get("should_stop", False):
                    yield global_step, accumulated_loss / max(accumulation_step, 1), "⏹️ Training stopped"
                    return

                loss = self.module.training_step(batch)
                loss = loss / self.training_config.gradient_accumulation_steps
                loss.backward()
                accumulated_loss += loss.item()
                accumulation_step += 1

                if accumulation_step >= self.training_config.gradient_accumulation_steps:
                    torch.nn.utils.clip_grad_norm_(trainable_params, self.training_config.max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
                    global_step += 1

                    avg_loss = accumulated_loss / accumulation_step
                    if global_step % self.training_config.log_every_n_steps == 0:
                        yield global_step, avg_loss, f"Epoch {epoch+1}, Step {global_step}, Loss: {avg_loss:.4f}"

                    epoch_loss += avg_loss
                    num_updates += 1
                    accumulated_loss = 0.0
                    accumulation_step = 0

            if accumulation_step > 0:
                torch.nn.utils.clip_grad_norm_(trainable_params, self.training_config.max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1

                avg_loss = accumulated_loss / accumulation_step
                if global_step % self.training_config.log_every_n_steps == 0:
                    yield global_step, avg_loss, f"Epoch {epoch+1}, Step {global_step}, Loss: {avg_loss:.4f}"

                epoch_loss += avg_loss
                num_updates += 1
                accumulated_loss = 0.0
                accumulation_step = 0
```
LoKRTrainer._train_basic missing non-finite gradient handling.
The _train_with_fabric method (lines 1216-1226, 1255-1263) includes non-finite gradient detection and skipping, but _train_basic (lines 1375-1376, 1391-1392) only uses clip_grad_norm_ without the guard. This could lead to NaN propagation when training without Fabric.
🐛 Suggested fix to add gradient guards
```diff
                 if accumulation_step >= self.training_config.gradient_accumulation_steps:
+                    nonfinite_grads, grad_tensors = _count_nonfinite_grads(trainable_params)
+                    if nonfinite_grads > 0:
+                        optimizer.zero_grad(set_to_none=True)
+                        yield global_step, float("nan"), (
+                            f"⚠️ Non-finite gradients ({nonfinite_grads}/{grad_tensors}); "
+                            "skipping optimizer step"
+                        )
+                        accumulated_loss = 0.0
+                        accumulation_step = 0
+                        continue
+
                     torch.nn.utils.clip_grad_norm_(trainable_params, self.training_config.max_grad_norm)
```
Apply a similar pattern to the epoch remainder block (lines 1391-1405).
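For reference, a helper with the `(nonfinite, total)` return shape assumed by the diff above could look like the following sketch; the actual `_count_nonfinite_grads` used by `_train_with_fabric` in trainer.py may be implemented differently.

```python
import torch

def _count_nonfinite_grads(params):
    """Return (num_nonfinite, num_grad_tensors) for the given parameters.

    Illustrative sketch only; it mirrors the return shape the suggested diff assumes.
    """
    nonfinite = 0
    total = 0
    for p in params:
        if p.grad is None:
            continue
        total += 1
        # A gradient tensor counts as non-finite if any element is NaN or +/-Inf.
        if not torch.isfinite(p.grad).all():
            nonfinite += 1
    return nonfinite, total
```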
🤖 Prompt for AI Agents
In `@acestep/training/trainer.py` around lines 1321 - 1406, The _train_basic
function lacks checks for non-finite gradients before stepping; add the same
guard used in _train_with_fabric: after
torch.nn.utils.clip_grad_norm_(trainable_params,
self.training_config.max_grad_norm) inspect gradients for non-finite values
(e.g., any torch.isnan or torch.isinf across p.grad for p in trainable_params)
and if non-finite, call optimizer.zero_grad(set_to_none=True), skip
optimizer.step() and scheduler.step(), emit/ yield a log message about skipping
the step, and avoid updating global_step/accumulated counters; apply this change
in both the inner accumulation flush (where accumulation_step >=
gradient_accumulation_steps) and the final remainder block (if accumulation_step
> 0), referencing _train_basic, trainable_params, optimizer, scheduler,
accumulated_loss, accumulation_step, and global_step.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@acestep/gradio_ui/events/__init__.py`:
- Around line 1243-1289: In lokr_training_wrapper remove the local "from loguru
import logger" import that shadows the module-level logger; use the module-level
logger (the one imported at top of the file) when calling logger.exception in
the except block and keep the rest of the function (train_h.start_lokr_training
loop, dit_handler usage, and ts handling) unchanged.
In `@acestep/gradio_ui/events/training_handlers.py`:
- Line 830: Remove the unnecessary linter suppression on the import of Fabric —
locate the line importing Fabric (from lightning.fabric import Fabric) in
training_handlers.py and delete the trailing "# noqa: F401" so the import line
is clean (leave the import itself intact if used elsewhere; only remove the
unused noqa directive).
- Around line 946-957: The LoKr training branch updates UI state but doesn't
write to the terminal; replicate the LoRA loop's terminal logging by
constructing a log message (e.g., combine status_text, time_info and/or log_text
into a single log_msg) and call logger.info(log_msg) before yielding
display_status, log_text, loss_data, training_state; ensure you use the same
logger identifier (logger) and that log_msg mirrors the content used in the LoRA
handler so terminal logs are consistent with UI updates.
🧹 Nitpick comments (6)
acestep/gradio_ui/events/__init__.py (1)
1291-1318: LoKr training lacks resume checkpoint support.
Unlike the LoRA training wiring (lines 1216-1240), which includes `resume_checkpoint_dir`, the LoKr training button wiring doesn't support resuming from checkpoints. This may be intentional for the initial release, but worth noting as a potential follow-up enhancement.
acestep/gradio_ui/events/training_handlers.py (2)
843-878: Consider extracting device configuration logic to reduce duplication.
The device detection and dataloader configuration logic (lines 843-878) is nearly identical to the LoRA training function (lines 608-644). Extracting this to a shared helper would improve maintainability.
♻️ Example helper function
```python
def _get_device_training_config(dit_handler) -> dict:
    """Get device-specific training configuration for DataLoader and mixed precision."""
    device_attr = getattr(dit_handler, "device", "")
    if hasattr(device_attr, "type"):
        device_type = str(device_attr.type).lower()
    else:
        device_type = str(device_attr).split(":", 1)[0].lower()
    if device_type == "cuda":
        return {
            "num_workers": 4,
            "pin_memory": True,
            "prefetch_factor": 2,
            "persistent_workers": True,
            "pin_memory_device": "cuda",
            "mixed_precision": "bf16",
        }
    elif device_type == "xpu":
        return {
            "num_workers": 4,
            "pin_memory": True,
            "prefetch_factor": 2,
            "persistent_workers": True,
            "pin_memory_device": None,
            "mixed_precision": "bf16",
        }
    elif device_type == "mps":
        return {
            "num_workers": 0,
            "pin_memory": False,
            "prefetch_factor": 2,
            "persistent_workers": False,
            "pin_memory_device": None,
            "mixed_precision": "fp16",
        }
    else:
        return {
            "num_workers": 0,
            "pin_memory": False,
            "prefetch_factor": 2,
            "persistent_workers": False,
            "pin_memory_device": None,
            "mixed_precision": "fp32",
        }
```
817-827: LoKr training missing VRAM tier warnings present in LoRA training.
The LoRA training path (lines 555-582) includes helpful warnings for CPU-only and low VRAM tier scenarios. The LoKr training path lacks these warnings, which could leave users without important context about expected performance.
acestep/gradio_ui/interfaces/training.py (3)
525-533: LoKr tab missing internationalization (i18n) support.
The LoKr tab uses hardcoded English strings (e.g., "📊 Preprocessed Tensors", "Dataset Info") while the LoRA tab uses the `t()` function for translations. This inconsistency means the LoKr tab won't be localized.
Consider adding i18n keys for the LoKr tab labels and using the `t("training.lokr_*")` pattern to match the LoRA tab's approach. This can be addressed in a follow-up PR if localization is a priority.
t("training.lokr_*")pattern to match the LoRA tab's approach. This can be addressed in a follow-up PR if localization is a priority.
605-611: LoKr epochs slider doesn't respect DEBUG_TRAINING flag.The LoRA tab uses
epoch_min,epoch_step, andepoch_defaultvariables (lines 29-31) which adjust based on theDEBUG_TRAININGflag. The LoKr epochs slider has hardcoded values (min=1, step=1, value=50), creating an inconsistency in debug behavior.♻️ Proposed fix for consistency
lokr_train_epochs = gr.Slider( - minimum=1, + minimum=epoch_min, maximum=4000, - step=1, - value=50, + step=epoch_step, + value=50 if debug_training_enabled else 50, # LoKr default stays at 50 label="Max Epochs", )Note: You may want a separate
lokr_epoch_defaultif the LoKr default should differ from LoRA in non-debug mode.
629-636: Save frequency may conflict with low default epoch count.
With `lokr_train_epochs` defaulting to 50 and `lokr_save_every_n_epochs` having minimum=50, users can only save checkpoints at the very end of training by default. This differs from LoRA, where higher epoch counts allow multiple checkpoint saves.
Consider either:
- Lowering the minimum for `lokr_save_every_n_epochs` (e.g., to 10)
- Adding a note in the UI info text about this behavior
```python
    def lokr_training_wrapper(
        tensor_dir,
        linear_dim,
        linear_alpha,
        factor,
        decompose_both,
        use_tucker,
        use_scalar,
        weight_decompose,
        lr,
        ep,
        bs,
        ga,
        se,
        sh,
        sd,
        od,
        ts,
    ):
        from loguru import logger
        if not isinstance(ts, dict):
            ts = {"is_training": False, "should_stop": False}
        try:
            for progress, log_msg, plot, state in train_h.start_lokr_training(
                tensor_dir,
                dit_handler,
                linear_dim,
                linear_alpha,
                factor,
                decompose_both,
                use_tucker,
                use_scalar,
                weight_decompose,
                lr,
                ep,
                bs,
                ga,
                se,
                sh,
                sd,
                od,
                ts,
            ):
                yield progress, log_msg, plot, state
        except Exception as e:
            logger.exception("LoKR training wrapper error")
            yield f"❌ Error: {str(e)}", str(e), None, ts
```
Remove redundant logger import inside function.
The logger import at line 1262 shadows the module-level import from line 7. This is unnecessary and flagged by static analysis (F811).
🔧 Proposed fix
```diff
 def lokr_training_wrapper(
     tensor_dir,
     linear_dim,
     linear_alpha,
     factor,
     decompose_both,
     use_tucker,
     use_scalar,
     weight_decompose,
     lr,
     ep,
     bs,
     ga,
     se,
     sh,
     sd,
     od,
     ts,
 ):
-    from loguru import logger
     if not isinstance(ts, dict):
         ts = {"is_training": False, "should_stop": False}
     try:
```
🧰 Tools
🪛 Ruff (0.14.14)
[error] 1262-1262: Redefinition of unused logger from line 7: logger redefined here
Remove definition: logger
(F811)
[warning] 1287-1287: Do not catch blind exception: Exception
(BLE001)
[warning] 1289-1289: Use explicit conversion flag
Replace with conversion flag
(RUF010)
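For context, RUF010 asks for the f-string conversion flag instead of an explicit `str()` call inside the interpolation; a minimal before/after on the flagged yield line:

```python
# Before (flagged by RUF010): explicit str() inside the f-string
yield f"❌ Error: {str(e)}", str(e), None, ts
# After: the !s conversion flag performs the same conversion
yield f"❌ Error: {e!s}", str(e), None, ts
```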
🤖 Prompt for AI Agents
In `@acestep/gradio_ui/events/__init__.py` around lines 1243 - 1289, In
lokr_training_wrapper remove the local "from loguru import logger" import that
shadows the module-level logger; use the module-level logger (the one imported
at top of the file) when calling logger.exception in the except block and keep
the rest of the function (train_h.start_lokr_training loop, dit_handler usage,
and ts handling) unchanged.
```python
        return

    try:
        from lightning.fabric import Fabric  # noqa: F401
```
Remove unused noqa directive.
The # noqa: F401 comment is unnecessary since F401 isn't enabled in the linter configuration.
🔧 Proposed fix
```diff
     try:
-        from lightning.fabric import Fabric  # noqa: F401
+        from lightning.fabric import Fabric
     except ImportError as e:
```
🧰 Tools
🪛 Ruff (0.14.14)
[warning] 830-830: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
🤖 Prompt for AI Agents
In `@acestep/gradio_ui/events/training_handlers.py` at line 830, Remove the
unnecessary linter suppression on the import of Fabric — locate the line
importing Fabric (from lightning.fabric import Fabric) in training_handlers.py
and delete the trailing "# noqa: F401" so the import line is clean (leave the
import itself intact if used elsewhere; only remove the unused noqa directive).
```python
                display_status = f"{status_text}\n{time_info}"
                log_lines.append(status_text)
                if len(log_lines) > 15:
                    log_lines = log_lines[-15:]
                log_text = "\n".join(log_lines)

                if step > 0 and loss is not None and loss == loss:
                    step_list.append(step)
                    loss_list.append(float(loss))
                    loss_data = pd.DataFrame({"step": step_list, "loss": loss_list})

                yield display_status, log_text, loss_data, training_state
```
LoKr training loop missing terminal logging.
Unlike the LoRA training loop (line 724) which logs to the terminal via logger.info(log_msg), the LoKr training loop only updates UI components. Adding terminal logging would help with debugging and monitoring long training runs.
🔧 Proposed fix
```diff
                 display_status = f"{status_text}\n{time_info}"
+                log_msg = f"[{_format_duration(elapsed_seconds)}] Step {step}: {status_text}"
+                logger.info(log_msg)
                 log_lines.append(status_text)
                 if len(log_lines) > 15:
```
In `@acestep/gradio_ui/events/training_handlers.py` around lines 946 - 957, The
LoKr training branch updates UI state but doesn't write to the terminal;
replicate the LoRA loop's terminal logging by constructing a log message (e.g.,
combine status_text, time_info and/or log_text into a single log_msg) and call
logger.info(log_msg) before yielding display_status, log_text, loss_data,
training_state; ensure you use the same logger identifier (logger) and that
log_msg mirrors the content used in the LoRA handler so terminal logs are
consistent with UI updates.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@acestep/gradio_ui/events/training_handlers.py`:
- Around line 1087-1090: The except block contains a no-op self-assignment
`latest_epoch = latest_epoch`; replace it with explicit handling: either remove
the except block entirely to preserve the previous value on parse failure, or
(preferably) set `latest_epoch = None` to explicitly mark that parsing failed.
Update the try/except around the `latest_epoch =
int(latest_checkpoint.split("_")[1])` expression accordingly so `latest_epoch`
is never silently left in an ambiguous state.
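A minimal sketch of the fix this prompt describes, assuming `latest_checkpoint` is a directory name like `epoch_200` (the exact exception types caught here are an assumption):

```python
# Parse the epoch index out of a checkpoint directory name such as "epoch_200".
try:
    latest_epoch = int(latest_checkpoint.split("_")[1])
except (IndexError, ValueError):
    # Mark the failure explicitly instead of the old no-op self-assignment.
    latest_epoch = None
```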
🧹 Nitpick comments (4)
acestep/training/lokr_utils.py (2)
64-99: Redundant `create_lycoris` call when `weight_decompose=True`.
When `weight_decompose` is enabled, `create_lycoris` is called twice: first without `dora_wd` (lines 64-78), then again with `dora_wd=True` (lines 82-97). The first network is discarded if the second succeeds. Consider restructuring to avoid the redundant creation:
♻️ Suggested refactor
```diff
-    lycoris_net = create_lycoris(
-        decoder,
-        multiplier,
-        linear_dim=lokr_config.linear_dim,
-        linear_alpha=lokr_config.linear_alpha,
-        algo="lokr",
-        factor=lokr_config.factor,
-        decompose_both=lokr_config.decompose_both,
-        use_tucker=lokr_config.use_tucker,
-        use_scalar=lokr_config.use_scalar,
-        full_matrix=lokr_config.full_matrix,
-        bypass_mode=lokr_config.bypass_mode,
-        rs_lora=lokr_config.rs_lora,
-        unbalanced_factorization=lokr_config.unbalanced_factorization,
-    )
-
-    if lokr_config.weight_decompose:
-        try:
-            lycoris_net = create_lycoris(
-                decoder,
-                multiplier,
-                linear_dim=lokr_config.linear_dim,
-                linear_alpha=lokr_config.linear_alpha,
-                algo="lokr",
-                factor=lokr_config.factor,
-                decompose_both=lokr_config.decompose_both,
-                use_tucker=lokr_config.use_tucker,
-                use_scalar=lokr_config.use_scalar,
-                full_matrix=lokr_config.full_matrix,
-                bypass_mode=lokr_config.bypass_mode,
-                rs_lora=lokr_config.rs_lora,
-                unbalanced_factorization=lokr_config.unbalanced_factorization,
-                dora_wd=True,
-            )
-        except Exception as exc:
-            logger.warning(f"DoRA mode not supported in current LyCORIS build: {exc}")
+    create_kwargs = dict(
+        linear_dim=lokr_config.linear_dim,
+        linear_alpha=lokr_config.linear_alpha,
+        algo="lokr",
+        factor=lokr_config.factor,
+        decompose_both=lokr_config.decompose_both,
+        use_tucker=lokr_config.use_tucker,
+        use_scalar=lokr_config.use_scalar,
+        full_matrix=lokr_config.full_matrix,
+        bypass_mode=lokr_config.bypass_mode,
+        rs_lora=lokr_config.rs_lora,
+        unbalanced_factorization=lokr_config.unbalanced_factorization,
+    )
+
+    if lokr_config.weight_decompose:
+        try:
+            lycoris_net = create_lycoris(decoder, multiplier, **create_kwargs, dora_wd=True)
+        except Exception as exc:
+            logger.warning(f"DoRA mode not supported in current LyCORIS build: {exc}")
+            lycoris_net = create_lycoris(decoder, multiplier, **create_kwargs)
+    else:
+        lycoris_net = create_lycoris(decoder, multiplier, **create_kwargs)
```
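For background, the LyCORIS wrapper pattern that `inject_lokr_into_dit` builds on looks roughly like the sketch below. It is based on the `create_lycoris` calls shown above, with the keyword set trimmed for brevity; the freezing logic is an assumption, not a copy of `lokr_utils.py`.

```python
from lycoris import create_lycoris

def inject_lokr_sketch(decoder, lokr_config, multiplier=1.0):
    # Build a LoKr network over the decoder; kwargs mirror the config fields above.
    lycoris_net = create_lycoris(
        decoder,
        multiplier,
        linear_dim=lokr_config.linear_dim,
        linear_alpha=lokr_config.linear_alpha,
        algo="lokr",
        factor=lokr_config.factor,
    )
    # Attach the LoKr modules so the decoder's forward pass routes through them.
    lycoris_net.apply_to()

    # Freeze the base model; only the LoKr parameters stay trainable.
    for p in decoder.parameters():
        p.requires_grad_(False)
    for p in lycoris_net.parameters():
        p.requires_grad_(True)
    return lycoris_net
```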
190-196: Redundant assignment on line 195.
`metadata = metadata or None` has no effect here when `metadata` (built up from `{}` on line 190) is non-empty, since `metadata or None` then evaluates to `metadata` itself. If the intent was to pass `None` to `save_lokr_weights` when metadata is empty, use an explicit check:
♻️ Suggested fix
```diff
-    metadata = metadata or None
-    save_lokr_weights(lycoris_net, output_dir, metadata=metadata)
+    save_lokr_weights(lycoris_net, output_dir, metadata=metadata if metadata else None)
```
acestep/gradio_ui/events/training_handlers.py (1)
802-803: Unused `progress` parameter.
The `progress` parameter is declared but never used in the function body. Either remove it or use it for progress callbacks similar to other functions in this file.
♻️ Option 1: Remove unused parameter
```diff
     lokr_output_dir: str,
     training_state: Dict,
-    progress=None,
 ):
```
acestep/training/trainer.py (1)
1084-1088: Redundant `int()` cast.
`len()` already returns an `int`, making the explicit `int()` cast unnecessary.
♻️ Suggested fix
```diff
         self.run_metadata = {
             "tensor_dir": tensor_dir,
-            "num_samples": int(len(data_module.train_dataset)),
+            "num_samples": len(data_module.train_dataset),
             "training_config": self.training_config.to_dict(),
         }
```
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@acestep/gradio_ui/interfaces/training.py`:
- Around line 598-611: The UI defaults for LoKr are inconsistent with the PR's
safe defaults; update the gr.Number and gr.Slider initial values so
lokr_learning_rate uses 0.003 (instead of 1e-3) and lokr_train_epochs uses 50
(instead of 500); locate the gr.Number instance named lokr_learning_rate and the
gr.Slider instance named lokr_train_epochs in the training UI and change their
value properties accordingly.
🧹 Nitpick comments (1)
acestep/gradio_ui/interfaces/training.py (1)
525-722: Consider routing LoKr labels through i18n helpers.
LoKr UI strings are hard-coded while the rest of the tab uses `t(...)`.
```python
            with gr.Row():
                lokr_learning_rate = gr.Number(
                    label="Learning Rate",
                    value=1e-3,
                    info="LoKr commonly uses a higher LR than LoRA. Tune per dataset.",
                )

                lokr_train_epochs = gr.Slider(
                    minimum=1,
                    maximum=4000,
                    step=1,
                    value=500,
                    label="Max Epochs",
                )
```
Align LoKr defaults with stated safe defaults.
PR objectives mention LR=0.003 and epochs=50, but the UI uses 1e‑3 and 500.
💡 Proposed defaults update
```diff
-            lokr_learning_rate = gr.Number(
-                label="Learning Rate",
-                value=1e-3,
+            lokr_learning_rate = gr.Number(
+                label="Learning Rate",
+                value=3e-3,
                 info="LoKr commonly uses a higher LR than LoRA. Tune per dataset.",
             )
@@
-            lokr_train_epochs = gr.Slider(
-                minimum=1,
-                maximum=4000,
-                step=1,
-                value=500,
-                label="Max Epochs",
-            )
+            lokr_train_epochs = gr.Slider(
+                minimum=1,
+                maximum=4000,
+                step=1,
+                value=50,
+                label="Max Epochs",
+            )
```
🤖 Prompt for AI Agents
In `@acestep/gradio_ui/interfaces/training.py` around lines 598 - 611, The UI
defaults for LoKr are inconsistent with the PR's safe defaults; update the
gr.Number and gr.Slider initial values so lokr_learning_rate uses 0.003 (instead
of 1e-3) and lokr_train_epochs uses 50 (instead of 500); locate the gr.Number
instance named lokr_learning_rate and the gr.Slider instance named
lokr_train_epochs in the training UI and change their value properties
accordingly.
Summary
This PR adds LoKr support alongside existing LoRA support, including:
- LoKr weight export (lokr_weights.safetensors)
- New lycoris-lora dependency
- A dedicated Train LoKr path with safer defaults (lr=0.003, epochs=50)
Key Changes
- acestep/training/configs.py
- acestep/training/lokr_utils.py
- acestep/training/trainer.py
- acestep/training/__init__.py
- acestep/handler.py
- Training UI tabs (Dataset Builder, Train LoRA, Train LoKr) + LoKr defaults: acestep/gradio_ui/interfaces/training.py
- acestep/gradio_ui/events/__init__.py
- acestep/gradio_ui/events/training_handlers.py
- requirements.txt
- pyproject.toml

A hypothetical sketch of the new LoKr config surface is shown below.
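As referenced above, a hypothetical sketch of the LoKr config surface added in `acestep/training/configs.py`; the field names mirror the values wired through the UI and `lokr_utils.py`, while the defaults shown are illustrative assumptions rather than the actual values in the PR.

```python
from dataclasses import dataclass

@dataclass
class LoKRConfig:
    # Field names taken from the UI wiring and lokr_utils.py; defaults are assumed.
    linear_dim: int = 64
    linear_alpha: float = 32.0
    factor: int = -1
    decompose_both: bool = False
    use_tucker: bool = False
    use_scalar: bool = False
    weight_decompose: bool = False
    full_matrix: bool = False
    bypass_mode: bool = False
    rs_lora: bool = False
    unbalanced_factorization: bool = False
```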
Validation
- py_compile passed for touched modules.
- LoRA output: lora_output/.../final/adapter/adapter_model.safetensors
- LoKr output: lokr_output/.../final/lokr_weights.safetensors
Notes / Known Limitations
Credit to Qing Long/sdbds for initial work on his repo around this.