@riversedge (Contributor) commented Feb 10, 2026

Summary

This PR adds LoKr support alongside existing LoRA support, including:

  • LoKr adapter injection/training/checkpointing utilities (LyCORIS-based)
  • LoKr runtime adapter loading in the handler (lokr_weights.safetensors)
  • A dedicated LoKr training tab in Gradio (separate from LoRA)
  • Training-event wiring for LoKr start/stop/load
  • Dependency updates for lycoris-lora
  • UX/default updates: consistent "LoKr" naming and safer defaults (lr=0.003, epochs=50)

Key Changes

  • Added LoKr config and utilities:
    • acestep/training/configs.py
    • acestep/training/lokr_utils.py
  • Added LoKr training module/trainer path:
    • acestep/training/trainer.py
  • Exposed training exports:
    • acestep/training/__init__.py
  • Added LoKr runtime adapter loading/toggling/scaling support:
    • acestep/handler.py
  • UI split into tabs (Dataset Builder, Train LoRA, Train LoKr) + LoKr defaults:
    • acestep/gradio_ui/interfaces/training.py
  • Hooked LoKr UI events:
    • acestep/gradio_ui/events/__init__.py
    • acestep/gradio_ui/events/training_handlers.py
  • Added dependency:
    • requirements.txt
    • pyproject.toml
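
For reviewers unfamiliar with LyCORIS, here is a minimal sketch of what the injection step in acestep/training/lokr_utils.py boils down to. It is illustrative only: the kwargs mirror the create_lycoris calls quoted later in this thread, and the dim/alpha values are placeholder assumptions, not this PR's defaults.

    import torch.nn as nn
    from lycoris import create_lycoris

    def inject_lokr_sketch(decoder: nn.Module, multiplier: float = 1.0):
        # Build a LyCORIS network that wraps the decoder's layers with LoKr factors
        net = create_lycoris(
            decoder,
            multiplier,
            linear_dim=64,    # placeholder; see LoKRConfig for the real default
            linear_alpha=64,  # placeholder
            algo="lokr",
        )
        net.apply_to()  # attach LoKr modules to the decoder
        # Training then optimizes net.parameters() while base weights stay frozen.
        return net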

Validation

  • Static validation:
    • py_compile passed for touched modules.
  • Smoke validation (CPU):
    • Real model initialization passes.
    • Real handler LoRA train (1 sample) completes.
    • Real handler LoKr train (1 sample) completes.
    • Handler-level LoRA/LoKr with 6-sample preprocessed dataset path completes.
  • Artifacts observed:
    • LoRA final adapter: lora_output/.../final/adapter/adapter_model.safetensors
    • LoKr final weights: lokr_output/.../final/lokr_weights.safetensors
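
As a rough illustration of how the two artifact formats can be told apart at load time (the actual detection lives in acestep/handler.py and may differ):

    from pathlib import Path

    def detect_adapter_type(adapter_dir: str) -> str:
        """Illustrative filename-based adapter detection."""
        root = Path(adapter_dir)
        if (root / "lokr_weights.safetensors").exists():
            return "lokr"  # LyCORIS LoKr weights
        if (root / "adapter_model.safetensors").exists():
            return "lora"  # PEFT LoRA adapter
        raise FileNotFoundError(f"No known adapter weights in {adapter_dir}")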

Notes / Known Limitations

  • On MPS, mixed-precision runs can skip early optimizer steps due to non-finite gradients; training typically recovers and proceeds.

Credit to Qing Long (sdbds) for the initial work on this in his repo.

Summary by CodeRabbit

  • New Features

    • Full Train LoKR tab: dataset load, LoKR configuration, start/stop training, progress, logs, loss plot, and export.
    • LoKR training and export integrated into the training workflow.
    • Support for loading and using multiple adapter formats (LoRA and LyCORIS LoKR).
  • Improvements

    • Refreshed Adapter Training UI copy with clearer start/complete/status messaging.
    • Better real-time training visualization, adapter enable/scale handling, and adapter unload/restore behavior.

@coderabbitai bot commented Feb 10, 2026

📝 Walkthrough

Adds LyCORIS-based LoKR adapter support across UI, event wiring, training backend, configs, utilities, handler adapter management, and dependency updates to enable LoKR training, checkpointing, saving, loading, and export alongside existing LoRA flows.

Changes

  • UI & Event Wiring (acestep/gradio_ui/interfaces/training.py, acestep/gradio_ui/events/__init__.py):
    Introduces a "Train LoKr" tab and Gradio components; wires dataset load, start/stop training, progress/log/plot updates, export controls, and refresh handlers for LoKR.
  • Training Handlers (acestep/gradio_ui/events/training_handlers.py):
    Adds a start_lokr_training generator, lokr_training_wrapper, dataset load/export wiring, and stop handling mirroring the LoRA flow while accepting LoKR-specific hyperparameters and outputs.
  • Trainer & Modules (acestep/training/trainer.py):
    Adds the LoKR training path: LoKRTrainer, PreprocessedLoKRModule, Fabric/basic training loops, checkpointing, dtype/grad helpers, wrapper unwrapping, stop control, and save/export integration.
  • LoKR Utilities (acestep/training/lokr_utils.py):
    New LyCORIS integration utilities: availability check, inject_lokr_into_dit, save/load LoKR weights, and training checkpoint saving with metadata and import guards.
  • Configs (acestep/training/configs.py):
    Adds a LoKRConfig dataclass with LyCORIS-compatible fields and a to_dict() mapping.
  • Adapter Loading / Runtime Handler (acestep/handler.py, acestep/training/lora_utils.py):
    Extends adapter loading to detect PEFT LoRA vs LyCORIS LoKR, adds _adapter_type and _lycoris_net state, updates enable/scale flows to propagate to LyCORIS nets, and unwraps decoder wrappers with input-grad compatibility fixes.
  • Public API Exports (acestep/training/__init__.py):
    Exports LoKR public symbols (LoKRConfig, LoKRTrainer, PreprocessedLoKRModule, inject/save/load/check functions) alongside existing LoRA exports.
  • Deps & Manifests (pyproject.toml, requirements.txt):
    Adds the lycoris-lora dependency to project requirements.
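
For context, a hedged sketch of what LoKRConfig plausibly looks like, inferred from the create_lycoris kwargs shown in this review (field defaults here are guesses, not the PR's values):

    from dataclasses import asdict, dataclass

    @dataclass
    class LoKRConfigSketch:
        """Illustrative only; fields mirror the create_lycoris kwargs seen in this thread."""
        linear_dim: int = 64
        linear_alpha: float = 64.0
        factor: int = -1
        decompose_both: bool = False
        use_tucker: bool = False
        use_scalar: bool = False
        full_matrix: bool = False
        bypass_mode: bool = False
        rs_lora: bool = False
        unbalanced_factorization: bool = False
        weight_decompose: bool = False

        def to_dict(self) -> dict:
            return asdict(self)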

Sequence Diagram

sequenceDiagram
    participant UI as UI (Gradio)
    participant Events as EventHandler
    participant Trainer as LoKRTrainer
    participant Model as DiT Model
    participant LyCORIS as LyCORIS Net
    participant Disk as Disk/Storage

    UI->>Events: start_lokr_training(tensor_dir, lokr_config, params)
    Events->>Events: validate inputs, update training_state
    Events->>Trainer: init(dit_handler, LoKRConfig, TrainingConfig)
    Trainer->>Model: unwrap decoder wrappers
    Trainer->>LyCORIS: inject_lokr_into_dit(model, lokr_config)
    LyCORIS->>Model: attach LoKR adapters, freeze base params
    Trainer->>Trainer: setup data module & memory features

    loop epochs
        Trainer->>Disk: load preprocessed tensors
        loop batches
            Trainer->>Model: forward (with LoKR)
            Model->>LyCORIS: apply LoKR adaptations
            Trainer->>Trainer: backward, accumulate, clip
            Trainer->>Events: yield progress, loss, logs
        end
        Trainer->>Disk: save_lokr_training_checkpoint()
    end

    Trainer->>Disk: save_lokr_weights(output_dir) 
    Trainer->>Events: emit completion status
    Events->>UI: update progress/log/export status

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • ChuxiJ

Poem

🐰 I hopped through tensors, stitch by stitch,

LyCORIS hummed and warmed each niche.
Checkpoints saved and weights took flight,
Adapters whispered through the night.
Hooray — a rabbit cheers: "Train bright!"

🚥 Pre-merge checks | ✅ 3 passed
  • Description Check: ✅ Passed. Check skipped because CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title 'New Feature: LoKr Training' clearly and concisely summarizes the main change, adding LoKr adapter training support alongside LoRA, and accurately reflects the primary objective of this substantial pull request.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 82.35%, above the required threshold of 80.00%.


@coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
acestep/gradio_ui/events/__init__.py (1)

1196-1336: ⚠️ Potential issue | 🟡 Minor

Remove redundant logger imports inside both wrapper functions.

The module already imports logger at the top level (line 7); the inner imports in both training_wrapper (line 1204) and lokr_training_wrapper (line 1262) redefine it unnecessarily and trigger F811 violations.

🛠️ Suggested fixes
    # Start training from preprocessed tensors
    def training_wrapper(tensor_dir, r, a, d, lr, ep, bs, ga, se, sh, sd, od, rc, ts):
-       from loguru import logger
        if not isinstance(ts, dict):
    # Start LoKR training from preprocessed tensors
    def lokr_training_wrapper(
        tensor_dir,
        ...
    ):
-       from loguru import logger
        if not isinstance(ts, dict):
🤖 Fix all issues with AI agents
In `@acestep/training/trainer.py`:
- Around line 1321-1406: The _train_basic function lacks checks for non-finite
gradients before stepping; add the same guard used in _train_with_fabric: after
torch.nn.utils.clip_grad_norm_(trainable_params,
self.training_config.max_grad_norm) inspect gradients for non-finite values
(e.g., any torch.isnan or torch.isinf across p.grad for p in trainable_params)
and if non-finite, call optimizer.zero_grad(set_to_none=True), skip
optimizer.step() and scheduler.step(), emit/yield a log message about skipping
the step, and avoid updating global_step/accumulated counters; apply this change
in both the inner accumulation flush (where accumulation_step >=
gradient_accumulation_steps) and the final remainder block (if accumulation_step
> 0), referencing _train_basic, trainable_params, optimizer, scheduler,
accumulated_loss, accumulation_step, and global_step.
🧹 Nitpick comments (12)
requirements.txt (1)

34-39: Pin lycoris-lora to a tested version in both requirements.txt and pyproject.toml.

Unpinned training dependencies can drift and break reproducibility. Use lycoris-lora>=3.4.0 (current stable) or a tighter constraint once you confirm a working version, keeping both configuration files in sync.
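
For example (version bounds are illustrative, pending confirmation of a tested release):

    # requirements.txt
    lycoris-lora>=3.4.0

    # pyproject.toml
    dependencies = [
        "lycoris-lora>=3.4.0",
    ]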

acestep/training/__init__.py (2)

1-6: Update module docstring to reflect LoKR support.

The module docstring only mentions LoRA training, but this module now also exposes LoKR training functionality. Consider updating for accuracy.

📝 Suggested docstring update
 """
 ACE-Step Training Module
 
-This module provides LoRA training functionality for ACE-Step models,
-including dataset building, audio labeling, and training utilities.
+This module provides LoRA and LoKr adapter training functionality for ACE-Step
+models, including dataset building, audio labeling, and training utilities.
 """

54-63: LoKR utilities grouped under LoRA comment section.

The LoKR utility exports (lines 60-63) are placed under the # LoRA Utils comment. Consider adding a separate comment section for clarity.

📝 Suggested organization
     # LoRA Utils
     "inject_lora_into_dit",
     "save_lora_weights",
     "load_lora_weights",
     "merge_lora_weights",
     "check_peft_available",
+    # LoKR Utils
     "inject_lokr_into_dit",
     "save_lokr_weights",
     "load_lokr_weights",
     "check_lycoris_available",
acestep/gradio_ui/interfaces/training.py (2)

653-660: LoKr tab missing resume checkpoint support.

The LoRA tab includes a resume checkpoint field (lines 465-469), but LoKr training has no equivalent. If this is intentional for the initial release, consider adding a TODO or noting it as a future enhancement.

📝 Suggested addition for feature parity
                 with gr.Row():
                     lokr_output_dir = gr.Textbox(
                         label="Output Directory",
                         value="./lokr_output",
                         placeholder="./lokr_output",
                         info="Where LoKr checkpoints and final weights will be written.",
                     )
+
+                with gr.Row():
+                    lokr_resume_checkpoint_dir = gr.Textbox(
+                        label="Resume Checkpoint (optional)",
+                        placeholder="./lokr_output/checkpoints/epoch_200",
+                        info="Directory of a saved LoKr checkpoint to resume from",
+                    )

661-698: LoKr tab missing export functionality.

The LoRA tab has an export section (lines 510-523) allowing users to export trained weights. The LoKr tab lacks this feature. If this is intentional for the initial release, consider documenting this limitation or adding it for parity.

acestep/gradio_ui/events/training_handlers.py (4)

801-802: Unused progress parameter.

The progress parameter is declared but never used in start_lokr_training. Either remove it or wire it up for progress callbacks similar to other functions.

📝 Option 1: Remove unused parameter
     lokr_output_dir: str,
     training_state: Dict,
-    progress=None,
 ):

829-833: Remove unused noqa directive.

The # noqa: F401 directive is unnecessary: the import exists only to verify that lightning.fabric is available, and F401 isn't enabled in the linter configuration.

📝 Suggested fix
     try:
-        from lightning.fabric import Fabric  # noqa: F401
+        from lightning.fabric import Fabric
     except ImportError as e:

817-827: LoKr missing low VRAM tier warnings.

The LoRA training path (lines 556-572) includes warnings for CPU-only and low VRAM tier scenarios. The LoKr path lacks these warnings, which could lead to users being unaware of suboptimal training conditions.

📝 Add VRAM tier warnings for LoKr
     if getattr(dit_handler, "quantization", None) is not None:
+        gpu_config = get_global_gpu_config()
+        if gpu_config.gpu_memory_gb <= 0:
+            yield (
+                "WARNING: CPU-only training detected. Using best-effort training path "
+                "(non-quantized DiT). Performance will be sub-optimal.",
+                "",
+                None,
+                training_state,
+            )
+        elif gpu_config.tier in {"tier1", "tier2", "tier3", "tier4"}:
+            yield (
+                f"WARNING: Low VRAM tier detected ({gpu_config.gpu_memory_gb:.1f} GB, {gpu_config.tier}). "
+                "Using best-effort training path (non-quantized DiT). Performance may be sub-optimal.",
+                "",
+                None,
+                training_state,
+            )
+
         yield "Switching model to training preset (disable quantization)...", "", None, training_state

782-981: Significant code duplication between LoRA and LoKr handlers.

The start_lokr_training function (~200 lines) is largely a copy of start_training with minor differences (config class, trainer class, messages). Consider extracting shared logic into a helper function to improve maintainability.

This is acceptable for the initial implementation but should be considered for future refactoring.

acestep/training/trainer.py (3)

168-214: Silent exception swallowing in _configure_training_memory_features.

Multiple try-except-pass blocks silently swallow exceptions (lines 184-185, 190-191, 202-203, 211-212). While this is intentional for graceful degradation when features aren't available, consider logging at debug level to aid troubleshooting.

📝 Suggested logging improvement
         if hasattr(mod, "gradient_checkpointing_enable"):
             try:
                 mod.gradient_checkpointing_enable()
                 checkpointing_enabled = True
             except Exception:
-                pass
+                logger.debug(f"gradient_checkpointing_enable failed on {type(mod).__name__}")

Apply similar pattern to other except blocks.


914-921: Unused dtype parameter in PreprocessedLoKRModule.__init__.

The dtype parameter is declared but unused since compute dtype is derived from device type via _select_compute_dtype. Either remove the parameter or use it to allow explicit dtype override.

📝 Option 1: Remove unused parameter
     def __init__(
         self,
         model: nn.Module,
         lokr_config: LoKRConfig,
         training_config: TrainingConfig,
         device: torch.device,
-        dtype: torch.dtype,
     ):

Note: This would require updating the caller in LoKRTrainer.train_from_preprocessed (lines 1052-1058).


911-994: Significant code duplication between LoRA and LoKR implementations.

PreprocessedLoKRModule is nearly identical to PreprocessedLoRAModule, and LoKRTrainer is nearly identical to LoRATrainer. Consider extracting a base class or shared helper functions to reduce duplication and ease future maintenance.

This is acceptable for the initial implementation but should be considered for future refactoring to prevent the classes from diverging in subtle ways.

Also applies to: 997-1435
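
One hypothetical shape for such a refactor (names here are invented, not the PR's API):

    from abc import ABC, abstractmethod

    class AdapterTrainerBase(ABC):
        """Shared training flow; subclasses supply only adapter-specific steps."""

        @abstractmethod
        def inject_adapter(self, model):
            """Attach adapter modules and freeze base parameters."""

        @abstractmethod
        def save_weights(self, output_dir: str):
            """Persist adapter-only weights for this adapter type."""

        def run(self, model, output_dir: str):
            self.inject_adapter(model)
            # ...shared optimizer/scheduler/epoch loop elided...
            self.save_weights(output_dir)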

Comment on lines 1321 to 1406
def _train_basic(
    self,
    data_module: PreprocessedDataModule,
    training_state: Optional[Dict],
) -> Generator[Tuple[int, float, str], None, None]:
    yield 0, 0.0, "🚀 Starting basic training loop..."
    os.makedirs(self.training_config.output_dir, exist_ok=True)

    train_loader = data_module.train_dataloader()
    trainable_params = [p for p in self.module.model.parameters() if p.requires_grad]
    if not trainable_params:
        yield 0, 0.0, "❌ No trainable parameters found!"
        return

    optimizer = AdamW(
        trainable_params,
        lr=self.training_config.learning_rate,
        weight_decay=self.training_config.weight_decay,
    )
    steps_per_epoch = max(1, math.ceil(len(train_loader) / self.training_config.gradient_accumulation_steps))
    total_steps = steps_per_epoch * self.training_config.max_epochs
    warmup_steps = min(self.training_config.warmup_steps, max(1, total_steps // 10))

    warmup_scheduler = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps)
    main_scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=max(1, total_steps - warmup_steps),
        T_mult=1,
        eta_min=self.training_config.learning_rate * 0.01,
    )
    scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, main_scheduler], milestones=[warmup_steps])

    global_step = 0
    accumulation_step = 0
    accumulated_loss = 0.0
    optimizer.zero_grad(set_to_none=True)
    self.module.model.decoder.train()

    for epoch in range(self.training_config.max_epochs):
        epoch_loss = 0.0
        num_updates = 0
        epoch_start_time = time.time()

        for batch in train_loader:
            if training_state and training_state.get("should_stop", False):
                yield global_step, accumulated_loss / max(accumulation_step, 1), "⏹️ Training stopped"
                return

            loss = self.module.training_step(batch)
            loss = loss / self.training_config.gradient_accumulation_steps
            loss.backward()
            accumulated_loss += loss.item()
            accumulation_step += 1

            if accumulation_step >= self.training_config.gradient_accumulation_steps:
                torch.nn.utils.clip_grad_norm_(trainable_params, self.training_config.max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1

                avg_loss = accumulated_loss / accumulation_step
                if global_step % self.training_config.log_every_n_steps == 0:
                    yield global_step, avg_loss, f"Epoch {epoch+1}, Step {global_step}, Loss: {avg_loss:.4f}"

                epoch_loss += avg_loss
                num_updates += 1
                accumulated_loss = 0.0
                accumulation_step = 0

        if accumulation_step > 0:
            torch.nn.utils.clip_grad_norm_(trainable_params, self.training_config.max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

            avg_loss = accumulated_loss / accumulation_step
            if global_step % self.training_config.log_every_n_steps == 0:
                yield global_step, avg_loss, f"Epoch {epoch+1}, Step {global_step}, Loss: {avg_loss:.4f}"

            epoch_loss += avg_loss
            num_updates += 1
            accumulated_loss = 0.0
            accumulation_step = 0


⚠️ Potential issue | 🟡 Minor

LoKRTrainer._train_basic missing non-finite gradient handling.

The _train_with_fabric method (lines 1216-1226, 1255-1263) includes non-finite gradient detection and skipping, but _train_basic (lines 1375-1376, 1391-1392) only uses clip_grad_norm_ without the guard. This could lead to NaN propagation when training without Fabric.

🐛 Suggested fix to add gradient guards
                 if accumulation_step >= self.training_config.gradient_accumulation_steps:
+                    nonfinite_grads, grad_tensors = _count_nonfinite_grads(trainable_params)
+                    if nonfinite_grads > 0:
+                        optimizer.zero_grad(set_to_none=True)
+                        yield global_step, float("nan"), (
+                            f"⚠️ Non-finite gradients ({nonfinite_grads}/{grad_tensors}); "
+                            "skipping optimizer step"
+                        )
+                        accumulated_loss = 0.0
+                        accumulation_step = 0
+                        continue
+
                     torch.nn.utils.clip_grad_norm_(trainable_params, self.training_config.max_grad_norm)

Apply similar pattern to the epoch remainder block (lines 1391-1405).
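
For reference, a self-contained version of the guard helper named in the diff (the real _count_nonfinite_grads in trainer.py may have a different signature):

    import torch

    def _count_nonfinite_grads(params):
        """Return (num_nonfinite, num_grad_tensors) across all gradients."""
        nonfinite, total = 0, 0
        for p in params:
            if p.grad is None:
                continue
            total += 1
            if not torch.isfinite(p.grad).all():
                nonfinite += 1
        return nonfinite, total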


@coderabbitai bot left a comment

Actionable comments posted: 3

🤖 Fix all issues with AI agents
In `@acestep/gradio_ui/events/__init__.py`:
- Around line 1243-1289: In lokr_training_wrapper remove the local "from loguru
import logger" import that shadows the module-level logger; use the module-level
logger (the one imported at top of the file) when calling logger.exception in
the except block and keep the rest of the function (train_h.start_lokr_training
loop, dit_handler usage, and ts handling) unchanged.

In `@acestep/gradio_ui/events/training_handlers.py`:
- Line 830: Remove the unnecessary linter suppression on the import of Fabric —
locate the line importing Fabric (from lightning.fabric import Fabric) in
training_handlers.py and delete the trailing "# noqa: F401" so the import line
is clean (leave the import itself intact if used elsewhere; only remove the
unused noqa directive).
- Around line 946-957: The LoKr training branch updates UI state but doesn't
write to the terminal; replicate the LoRA loop's terminal logging by
constructing a log message (e.g., combine status_text, time_info and/or log_text
into a single log_msg) and call logger.info(log_msg) before yielding
display_status, log_text, loss_data, training_state; ensure you use the same
logger identifier (logger) and that log_msg mirrors the content used in the LoRA
handler so terminal logs are consistent with UI updates.
🧹 Nitpick comments (6)
acestep/gradio_ui/events/__init__.py (1)

1291-1318: LoKr training lacks resume checkpoint support.

Unlike the LoRA training wiring (lines 1216-1240) which includes resume_checkpoint_dir, the LoKr training button wiring doesn't support resuming from checkpoints. This may be intentional for the initial release, but worth noting as a potential follow-up enhancement.

acestep/gradio_ui/events/training_handlers.py (2)

843-878: Consider extracting device configuration logic to reduce duplication.

The device detection and dataloader configuration logic (lines 843-878) is nearly identical to the LoRA training function (lines 608-644). Extracting this to a shared helper would improve maintainability.

♻️ Example helper function
def _get_device_training_config(dit_handler) -> dict:
    """Get device-specific training configuration for DataLoader and mixed precision."""
    device_attr = getattr(dit_handler, "device", "")
    if hasattr(device_attr, "type"):
        device_type = str(device_attr.type).lower()
    else:
        device_type = str(device_attr).split(":", 1)[0].lower()

    if device_type == "cuda":
        return {
            "num_workers": 4,
            "pin_memory": True,
            "prefetch_factor": 2,
            "persistent_workers": True,
            "pin_memory_device": "cuda",
            "mixed_precision": "bf16",
        }
    elif device_type == "xpu":
        return {
            "num_workers": 4,
            "pin_memory": True,
            "prefetch_factor": 2,
            "persistent_workers": True,
            "pin_memory_device": None,
            "mixed_precision": "bf16",
        }
    elif device_type == "mps":
        return {
            "num_workers": 0,
            "pin_memory": False,
            "prefetch_factor": 2,
            "persistent_workers": False,
            "pin_memory_device": None,
            "mixed_precision": "fp16",
        }
    else:
        return {
            "num_workers": 0,
            "pin_memory": False,
            "prefetch_factor": 2,
            "persistent_workers": False,
            "pin_memory_device": None,
            "mixed_precision": "fp32",
        }
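
Call sites in both training paths could then read, e.g., cfg = _get_device_training_config(dit_handler) and pass cfg["num_workers"], cfg["pin_memory"], and cfg["persistent_workers"] straight into DataLoader construction, keeping the two flows in sync.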

817-827: LoKr training missing VRAM tier warnings present in LoRA training.

The LoRA training path (lines 555-582) includes helpful warnings for CPU-only and low VRAM tier scenarios. The LoKr training path lacks these warnings, which could leave users without important context about expected performance.

acestep/gradio_ui/interfaces/training.py (3)

525-533: LoKr tab missing internationalization (i18n) support.

The LoKr tab uses hardcoded English strings (e.g., "📊 Preprocessed Tensors", "Dataset Info") while the LoRA tab uses the t() function for translations. This inconsistency means the LoKr tab won't be localized.

Consider adding i18n keys for the LoKr tab labels and using t("training.lokr_*") pattern to match the LoRA tab's approach. This can be addressed in a follow-up PR if localization is a priority.
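
A hypothetical example of the pattern (the key name and component are invented for illustration):

    gr.Markdown(t("training.lokr_tensors_header"))  # replaces the hard-coded "📊 Preprocessed Tensors"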


605-611: LoKr epochs slider doesn't respect DEBUG_TRAINING flag.

The LoRA tab uses epoch_min, epoch_step, and epoch_default variables (lines 29-31) which adjust based on the DEBUG_TRAINING flag. The LoKr epochs slider has hardcoded values (min=1, step=1, value=50), creating an inconsistency in debug behavior.

♻️ Proposed fix for consistency
                     lokr_train_epochs = gr.Slider(
-                        minimum=1,
+                        minimum=epoch_min,
                         maximum=4000,
-                        step=1,
-                        value=50,
+                        step=epoch_step,
+                        value=50 if debug_training_enabled else 50,  # LoKr default stays at 50
                         label="Max Epochs",
                     )

Note: You may want a separate lokr_epoch_default if the LoKr default should differ from LoRA in non-debug mode.


629-636: Save frequency may conflict with low default epoch count.

With lokr_train_epochs defaulting to 50 and lokr_save_every_n_epochs having minimum=50, users can only save checkpoints at the very end of training by default. This differs from LoRA where higher epoch counts allow multiple checkpoint saves.

Consider either:

  1. Lowering the minimum for lokr_save_every_n_epochs (e.g., to 10)
  2. Adding a note in the UI info text about this behavior

Comment on lines +1243 to +1289
def lokr_training_wrapper(
    tensor_dir,
    linear_dim,
    linear_alpha,
    factor,
    decompose_both,
    use_tucker,
    use_scalar,
    weight_decompose,
    lr,
    ep,
    bs,
    ga,
    se,
    sh,
    sd,
    od,
    ts,
):
    from loguru import logger
    if not isinstance(ts, dict):
        ts = {"is_training": False, "should_stop": False}
    try:
        for progress, log_msg, plot, state in train_h.start_lokr_training(
            tensor_dir,
            dit_handler,
            linear_dim,
            linear_alpha,
            factor,
            decompose_both,
            use_tucker,
            use_scalar,
            weight_decompose,
            lr,
            ep,
            bs,
            ga,
            se,
            sh,
            sd,
            od,
            ts,
        ):
            yield progress, log_msg, plot, state
    except Exception as e:
        logger.exception("LoKR training wrapper error")
        yield f"❌ Error: {str(e)}", str(e), None, ts

⚠️ Potential issue | 🟡 Minor

Remove redundant logger import inside function.

The logger import at line 1262 shadows the module-level import from line 7. This is unnecessary and flagged by static analysis (F811).

🔧 Proposed fix
     def lokr_training_wrapper(
         tensor_dir,
         linear_dim,
         linear_alpha,
         factor,
         decompose_both,
         use_tucker,
         use_scalar,
         weight_decompose,
         lr,
         ep,
         bs,
         ga,
         se,
         sh,
         sd,
         od,
         ts,
     ):
-        from loguru import logger
         if not isinstance(ts, dict):
             ts = {"is_training": False, "should_stop": False}
         try:
🧰 Tools
🪛 Ruff (0.14.14)
  • [error] 1262: Redefinition of unused logger from line 7; remove the definition (F811)
  • [warning] 1287: Do not catch blind exception: Exception (BLE001)
  • [warning] 1289: Use explicit conversion flag (RUF010)


    return

try:
    from lightning.fabric import Fabric  # noqa: F401

⚠️ Potential issue | 🟡 Minor

Remove unused noqa directive.

The # noqa: F401 comment is unnecessary since F401 isn't enabled in the linter configuration.

🔧 Proposed fix
     try:
-        from lightning.fabric import Fabric  # noqa: F401
+        from lightning.fabric import Fabric
     except ImportError as e:
🧰 Tools
🪛 Ruff (0.14.14)
  • [warning] 830: Unused noqa directive (non-enabled: F401); remove it (RUF100)


Comment on lines +946 to +957
display_status = f"{status_text}\n{time_info}"
log_lines.append(status_text)
if len(log_lines) > 15:
    log_lines = log_lines[-15:]
log_text = "\n".join(log_lines)

if step > 0 and loss is not None and loss == loss:  # loss == loss filters out NaN
    step_list.append(step)
    loss_list.append(float(loss))
    loss_data = pd.DataFrame({"step": step_list, "loss": loss_list})

yield display_status, log_text, loss_data, training_state

⚠️ Potential issue | 🟡 Minor

LoKr training loop missing terminal logging.

Unlike the LoRA training loop (line 724) which logs to the terminal via logger.info(log_msg), the LoKr training loop only updates UI components. Adding terminal logging would help with debugging and monitoring long training runs.

🔧 Proposed fix
             display_status = f"{status_text}\n{time_info}"
+            log_msg = f"[{_format_duration(elapsed_seconds)}] Step {step}: {status_text}"
+            logger.info(log_msg)
             log_lines.append(status_text)
             if len(log_lines) > 15:

@coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@acestep/gradio_ui/events/training_handlers.py`:
- Around line 1087-1090: The except block contains a no-op self-assignment
`latest_epoch = latest_epoch`; replace it with explicit handling: either remove
the except block entirely to preserve the previous value on parse failure, or
(preferably) set `latest_epoch = None` to explicitly mark that parsing failed.
Update the try/except around the `latest_epoch =
int(latest_checkpoint.split("_")[1])` expression accordingly so `latest_epoch`
is never silently left in an ambiguous state.
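
A minimal sketch of the preferred option (names taken from the prompt above):

    try:
        latest_epoch = int(latest_checkpoint.split("_")[1])
    except (IndexError, ValueError):
        latest_epoch = None  # explicitly mark that parsing failed
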
🧹 Nitpick comments (4)
acestep/training/lokr_utils.py (2)

64-99: Redundant create_lycoris call when weight_decompose=True.

When weight_decompose is enabled, create_lycoris is called twice: first without dora_wd (lines 64-78), then again with dora_wd=True (lines 82-97). The first network is discarded if the second succeeds. Consider restructuring to avoid the redundant creation:

♻️ Suggested refactor
-    lycoris_net = create_lycoris(
-        decoder,
-        multiplier,
-        linear_dim=lokr_config.linear_dim,
-        linear_alpha=lokr_config.linear_alpha,
-        algo="lokr",
-        factor=lokr_config.factor,
-        decompose_both=lokr_config.decompose_both,
-        use_tucker=lokr_config.use_tucker,
-        use_scalar=lokr_config.use_scalar,
-        full_matrix=lokr_config.full_matrix,
-        bypass_mode=lokr_config.bypass_mode,
-        rs_lora=lokr_config.rs_lora,
-        unbalanced_factorization=lokr_config.unbalanced_factorization,
-    )
-
-    if lokr_config.weight_decompose:
-        try:
-            lycoris_net = create_lycoris(
-                decoder,
-                multiplier,
-                linear_dim=lokr_config.linear_dim,
-                linear_alpha=lokr_config.linear_alpha,
-                algo="lokr",
-                factor=lokr_config.factor,
-                decompose_both=lokr_config.decompose_both,
-                use_tucker=lokr_config.use_tucker,
-                use_scalar=lokr_config.use_scalar,
-                full_matrix=lokr_config.full_matrix,
-                bypass_mode=lokr_config.bypass_mode,
-                rs_lora=lokr_config.rs_lora,
-                unbalanced_factorization=lokr_config.unbalanced_factorization,
-                dora_wd=True,
-            )
-        except Exception as exc:
-            logger.warning(f"DoRA mode not supported in current LyCORIS build: {exc}")
+    create_kwargs = dict(
+        linear_dim=lokr_config.linear_dim,
+        linear_alpha=lokr_config.linear_alpha,
+        algo="lokr",
+        factor=lokr_config.factor,
+        decompose_both=lokr_config.decompose_both,
+        use_tucker=lokr_config.use_tucker,
+        use_scalar=lokr_config.use_scalar,
+        full_matrix=lokr_config.full_matrix,
+        bypass_mode=lokr_config.bypass_mode,
+        rs_lora=lokr_config.rs_lora,
+        unbalanced_factorization=lokr_config.unbalanced_factorization,
+    )
+
+    if lokr_config.weight_decompose:
+        try:
+            lycoris_net = create_lycoris(decoder, multiplier, **create_kwargs, dora_wd=True)
+        except Exception as exc:
+            logger.warning(f"DoRA mode not supported in current LyCORIS build: {exc}")
+            lycoris_net = create_lycoris(decoder, multiplier, **create_kwargs)
+    else:
+        lycoris_net = create_lycoris(decoder, multiplier, **create_kwargs)

190-196: Redundant assignment on line 195.

metadata = metadata or None has no effect here. The variable metadata is initialized as {} (line 190), which is truthy, so metadata or None evaluates to metadata itself. If the intent was to pass None to save_lokr_weights when metadata is empty, use an explicit check:

♻️ Suggested fix
-    metadata = metadata or None
-    save_lokr_weights(lycoris_net, output_dir, metadata=metadata)
+    save_lokr_weights(lycoris_net, output_dir, metadata=metadata if metadata else None)
acestep/gradio_ui/events/training_handlers.py (1)

802-803: Unused progress parameter.

The progress parameter is declared but never used in the function body. Either remove it or use it for progress callbacks similar to other functions in this file.

♻️ Option 1: Remove unused parameter
     lokr_output_dir: str,
     training_state: Dict,
-    progress=None,
 ):
acestep/training/trainer.py (1)

1084-1088: Redundant int() cast.

len() already returns an int, making the explicit int() cast unnecessary.

♻️ Suggested fix
         self.run_metadata = {
             "tensor_dir": tensor_dir,
-            "num_samples": int(len(data_module.train_dataset)),
+            "num_samples": len(data_module.train_dataset),
             "training_config": self.training_config.to_dict(),
         }

@coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@acestep/gradio_ui/interfaces/training.py`:
- Around line 598-611: The UI defaults for LoKr are inconsistent with the PR's
safe defaults; update the gr.Number and gr.Slider initial values so
lokr_learning_rate uses 0.003 (instead of 1e-3) and lokr_train_epochs uses 50
(instead of 500); locate the gr.Number instance named lokr_learning_rate and the
gr.Slider instance named lokr_train_epochs in the training UI and change their
value properties accordingly.
🧹 Nitpick comments (1)
acestep/gradio_ui/interfaces/training.py (1)

525-722: Consider routing LoKr labels through i18n helpers.
LoKr UI strings are hard-coded while the rest of the tab uses t(...).

Comment on lines +598 to +611
with gr.Row():
    lokr_learning_rate = gr.Number(
        label="Learning Rate",
        value=1e-3,
        info="LoKr commonly uses a higher LR than LoRA. Tune per dataset.",
    )

    lokr_train_epochs = gr.Slider(
        minimum=1,
        maximum=4000,
        step=1,
        value=500,
        label="Max Epochs",
    )

⚠️ Potential issue | 🟡 Minor

Align LoKr defaults with stated safe defaults.
PR objectives mention LR=0.003 and epochs=50, but the UI uses 1e‑3 and 500.

💡 Proposed defaults update
-                    lokr_learning_rate = gr.Number(
-                        label="Learning Rate",
-                        value=1e-3,
+                    lokr_learning_rate = gr.Number(
+                        label="Learning Rate",
+                        value=3e-3,
                         info="LoKr commonly uses a higher LR than LoRA. Tune per dataset.",
                     )
@@
-                    lokr_train_epochs = gr.Slider(
-                        minimum=1,
-                        maximum=4000,
-                        step=1,
-                        value=500,
-                        label="Max Epochs",
-                    )
+                    lokr_train_epochs = gr.Slider(
+                        minimum=1,
+                        maximum=4000,
+                        step=1,
+                        value=50,
+                        label="Max Epochs",
+                    )
