refact: Reorganized Advanced Settings UI & Added latent shift and rescale #452
Conversation
- Enhanced the caching mechanism for VAE audio encoding, optimizing the reuse of encoded latents.
- Updated logging for better tracking of caching efficiency and latent reuse.
- Files changed: acestep/handler.py: refined audio encoding logic to boost caching performance.

- Improved the caching mechanism for VAE audio encoding, increasing efficiency in the reuse of encoded latents.
- Enhanced logging to provide clearer insights into caching performance and latent reuse.
- Files changed: acestep/handler.py: optimized audio encoding logic for better caching performance.
…e generation UI
- Add latent_shift and latent_rescale parameters to event handlers and batch generation
- Reorganize optional parameters section with sub-headings (Music Properties, Generation Settings)
- Add advanced section labels (DiT Diffusion, LM Generation, Audio Output, Automation & Batch)
- Add MLX DiT i18n labels for Apple Silicon support
- Update i18n files (en, zh, ja, he) with new UI labels
- Move latent shift/rescale controls within generation interface layout
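The first two commit notes above describe caching of VAE audio-encoding results. A minimal sketch of that kind of cache is shown below; the key scheme, structure, and function names are assumptions for illustration, not the actual acestep/handler.py implementation.

```python
import hashlib

import torch

# Hypothetical cache of VAE-encoded audio latents, keyed by file content hash.
# This illustrates the caching idea from the commits above, not the real handler.py code.
_audio_latent_cache: dict[str, torch.Tensor] = {}

def encode_audio_cached(audio_path: str, encode_fn) -> torch.Tensor:
    """Return cached latents for audio_path, running the VAE encoder only on a miss."""
    with open(audio_path, "rb") as f:
        key = hashlib.sha256(f.read()).hexdigest()
    if key in _audio_latent_cache:
        # Cache hit: reuse previously encoded latents and skip the VAE encode.
        return _audio_latent_cache[key]
    latents = encode_fn(audio_path)  # cache miss: run the (expensive) VAE encode
    _audio_latent_cache[key] = latents
    return latents
```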
Caution: Review failed. The pull request is closed.

📝 Walkthrough

Adds two latent post-processing parameters (latent_shift, latent_rescale) threaded from the Gradio UI through event handlers and batch flows into the generation pipeline, applied to DiT latents before VAE decode; also introduces UI reorganizations, i18n keys, and emoji-driven training UI text updates.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant UI as Gradio UI
    participant Events as Event Handlers
    participant Gen as Generation Service
    participant VAE
    User->>UI: set latent_shift, latent_rescale + other params
    UI->>Events: submit generation request (params)
    Events->>Gen: call generate_music(params including latent_shift/rescale)
    Gen->>Gen: generate DiT latents
    Gen->>Gen: apply latent transform (latents * latent_rescale + latent_shift)
    Gen->>VAE: decode transformed latents
    VAE-->>Gen: images/audio
    Gen-->>Events: results + captured params
    Events-->>UI: update results/history
```
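To make the first step of this diagram concrete, the UI side might expose the two knobs roughly as below. This is a hedged sketch only: the accordion name comes from the section labels listed in the commit message, while the slider ranges, defaults, and labels are assumptions rather than the actual generation.py code.

```python
import gradio as gr

with gr.Blocks() as demo:
    # Hypothetical "DiT Diffusion" advanced section holding the new controls;
    # ranges and defaults are illustrative guesses (defaults leave latents unchanged).
    with gr.Accordion("DiT Diffusion", open=False):
        latent_shift = gr.Slider(
            minimum=-1.0, maximum=1.0, value=0.0, step=0.01,
            label="Latent Shift",
            info="Added to DiT latents before VAE decode (0 = no change)",
        )
        latent_rescale = gr.Slider(
            minimum=0.5, maximum=1.5, value=1.0, step=0.01,
            label="Latent Rescale",
            info="Multiplies DiT latents before VAE decode (1 = no change)",
        )
```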
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Pull request overview
Adds latent post-processing controls (shift/rescale) to help reduce clipping before VAE decode, and reorganizes Gradio advanced settings while improving UI/training status text and i18n coverage.
Changes:
- Introduce `latent_shift`/`latent_rescale` parameters end-to-end (params → handler → UI → batch restore); a rough sketch of the params-level change follows this list.
- Apply latent shift/rescale to `pred_latents` before VAE decode, with optional debug logging.
- Reorganize Gradio advanced settings sections and expand i18n strings; fix garbled status icons in the training UI.
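As a rough sketch of the params-level change referenced in the first bullet: the field names come from the PR description, but the defaults and exact dataclass shape are assumptions, not the actual acestep/inference.py code.

```python
from dataclasses import dataclass

@dataclass
class GenerationParams:
    # ...existing generation fields elided...
    # New latent post-processing knobs, applied as latents * rescale + shift
    # before VAE decode; the defaults leave latents untouched.
    latent_shift: float = 0.0
    latent_rescale: float = 1.0
```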
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| acestep/inference.py | Adds latent_shift/latent_rescale to GenerationParams and forwards them into DiT generation. |
| acestep/handler.py | Extends handler generate_music API and applies latent post-processing prior to VAE decode. |
| acestep/gradio_ui/interfaces/generation.py | Reorganizes advanced UI sections; adds latent shift/rescale controls and returns them in UI state. |
| acestep/gradio_ui/i18n/en.json | Adds many new UI strings (incl. training section + latent controls). |
| acestep/gradio_ui/i18n/zh.json | Adds section headers, training strings, and latent control strings; fixes JSON comma. |
| acestep/gradio_ui/i18n/ja.json | Adds section headers and latent control strings; fixes missing comma. |
| acestep/gradio_ui/i18n/he.json | Adds section headers for reorganized UI. |
| acestep/gradio_ui/events/training_handlers.py | Replaces garbled “�” characters with intended status icons in messages. |
| acestep/gradio_ui/events/results_handlers.py | Threads latent shift/rescale through generation, param capture/restore, and batch background generation defaults. |
| `acestep/gradio_ui/events/__init__.py` | Wires new UI controls into generation wrapper inputs and param capture. |
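For the `events/__init__.py` row above, the wiring change presumably amounts to appending the two new components to the generation handler's input list. The sketch below is a self-contained stand-in with invented component and function names; it is not the repository's actual wiring code.

```python
import gradio as gr

def generate_music_wrapper(duration, latent_shift, latent_rescale):
    # Stand-in for the real generation wrapper; it would forward these values
    # into the generation parameters and on to generate_music.
    return f"duration={duration}s, shift={latent_shift}, rescale={latent_rescale}"

with gr.Blocks() as demo:
    duration = gr.Slider(10, 240, value=60, label="Duration (s)")
    latent_shift = gr.Slider(-1.0, 1.0, value=0.0, label="Latent Shift")
    latent_rescale = gr.Slider(0.5, 1.5, value=1.0, label="Latent Rescale")
    status = gr.Textbox(label="Status")
    generate_btn = gr.Button("Generate")
    generate_btn.click(
        fn=generate_music_wrapper,
        # The two new controls are simply threaded into the existing inputs list.
        inputs=[duration, latent_shift, latent_rescale],
        outputs=[status],
    )
```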
| logger.debug(f"[generate_music] Latent BEFORE shift/rescale: min={pred_latents.min():.4f}, max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, std={pred_latents.std():.4f}") | ||
| pred_latents = pred_latents * latent_rescale + latent_shift | ||
| if self.debug_stats: | ||
| logger.debug(f"[generate_music] Latent AFTER shift/rescale: min={pred_latents.min():.4f}, max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, std={pred_latents.std():.4f}") |
Copilot AI · Feb 11, 2026
Formatting PyTorch scalar tensors with :.4f will raise TypeError (Tensor doesn’t support that format spec). Convert reduction results to Python numbers (e.g., .item()) before formatting.
| logger.debug(f"[generate_music] Latent BEFORE shift/rescale: min={pred_latents.min():.4f}, max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, std={pred_latents.std():.4f}") | |
| pred_latents = pred_latents * latent_rescale + latent_shift | |
| if self.debug_stats: | |
| logger.debug(f"[generate_music] Latent AFTER shift/rescale: min={pred_latents.min():.4f}, max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, std={pred_latents.std():.4f}") | |
| logger.debug( | |
| f"[generate_music] Latent BEFORE shift/rescale: " | |
| f"min={pred_latents.min().item():.4f}, " | |
| f"max={pred_latents.max().item():.4f}, " | |
| f"mean={pred_latents.mean().item():.4f}, " | |
| f"std={pred_latents.std().item():.4f}" | |
| ) | |
| pred_latents = pred_latents * latent_rescale + latent_shift | |
| if self.debug_stats: | |
| logger.debug( | |
| f"[generate_music] Latent AFTER shift/rescale: " | |
| f"min={pred_latents.min().item():.4f}, " | |
| f"max={pred_latents.max().item():.4f}, " | |
| f"mean={pred_latents.mean().item():.4f}, " | |
| f"std={pred_latents.std().item():.4f}" | |
| ) |
| logger.debug(f"[generate_music] Latent BEFORE shift/rescale: min={pred_latents.min():.4f}, max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, std={pred_latents.std():.4f}") | ||
| pred_latents = pred_latents * latent_rescale + latent_shift | ||
| if self.debug_stats: | ||
| logger.debug(f"[generate_music] Latent AFTER shift/rescale: min={pred_latents.min():.4f}, max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, std={pred_latents.std():.4f}") |
Copilot AI · Feb 11, 2026
Even under debug_stats, calling min()/max()/mean()/std() separately triggers multiple full-tensor reductions (and on CUDA, likely multiple synchronizations). Compute these stats once per log line (e.g., via aminmax + mean + std) and reuse the results.
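A minimal sketch of what that suggestion could look like, assuming pred_latents is a torch.Tensor; the helper name and logger setup are placeholders rather than the actual handler.py code.

```python
import logging

import torch

logger = logging.getLogger(__name__)  # stand-in for handler.py's own logger

def log_latent_stats(pred_latents: torch.Tensor, tag: str) -> None:
    """Log min/max/mean/std using one reduction per statistic pair."""
    lat_min, lat_max = pred_latents.aminmax()           # single pass for min and max
    lat_std, lat_mean = torch.std_mean(pred_latents)    # single kernel for std and mean
    logger.debug(
        f"[generate_music] Latent {tag} shift/rescale: "
        f"min={lat_min.item():.4f}, max={lat_max.item():.4f}, "
        f"mean={lat_mean.item():.4f}, std={lat_std.item():.4f}"
    )
```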
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
acestep/gradio_ui/i18n/zh.json (1)
165-166: ⚠️ Potential issue | 🟡 Minor: Japanese text in Chinese translation file.

Lines 165-166 contain Japanese text (`CoT メタデータ`, `LMを使用してCoTメタデータを生成...`) instead of Chinese translations. This appears to be a copy-paste error from the Japanese translation file.

🌐 Suggested Chinese translation

```diff
- "cot_metas_label": "CoT メタデータ",
- "cot_metas_info": "LMを使用してCoTメタデータを生成(チェックを外すとLM CoT生成をスキップ)",
+ "cot_metas_label": "CoT 元数据",
+ "cot_metas_info": "使用LM生成CoT元数据(取消勾选以跳过LM CoT生成)",
```
🤖 Fix all issues with AI agents
In `@acestep/handler.py`:
- Around line 3665-3673: After applying user-controlled latent post-processing
in generate_music (the pred_latents = pred_latents * latent_rescale +
latent_shift block), validate latent_shift and latent_rescale for finiteness and
reasonable magnitude, and immediately re-check pred_latents for non-finite
values; if any NaN/Inf are found, log an error via logger (include debug_stats
details if enabled) and either clamp/replace those values (e.g., with zeros or
prior safe statistics) or raise/return an explicit error to prevent passing
corrupted latents to the VAE decode—update the latent post-processing section
around pred_latents and the debug logging to perform these checks and handle
failures safely.
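The prompt above leaves open whether to clamp/replace bad values or to raise; the suggested fix later in this review raises. As a hedged illustration of the clamp/replace alternative, something like the following could be used (the function name, logger setup, and replacement bounds are assumptions, not the project's code):

```python
import logging

import torch

logger = logging.getLogger(__name__)  # stand-in for handler.py's own logger

def sanitize_latents(pred_latents: torch.Tensor,
                     latent_shift: float, latent_rescale: float) -> torch.Tensor:
    """Replace non-finite latent values after shift/rescale instead of raising."""
    if not torch.isfinite(pred_latents).all():
        bad = int((~torch.isfinite(pred_latents)).sum())
        logger.error(
            "[generate_music] Latent post-processing produced %d non-finite values "
            "(shift=%s, rescale=%s); replacing them before VAE decode",
            bad, latent_shift, latent_rescale,
        )
        # Map NaN to 0 and clamp +/-Inf to large finite bounds so decode can proceed.
        pred_latents = torch.nan_to_num(pred_latents, nan=0.0, posinf=1e4, neginf=-1e4)
    return pred_latents
```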
🧹 Nitpick comments (3)
acestep/gradio_ui/events/training_handlers.py (3)
327-327: Consider using tuple unpacking instead of concatenation.

The tuple concatenation with `+` works but can be simplified using iterable unpacking for better readability.

♻️ Suggested refactor using tuple unpacking

```diff
 if not dataset_path or not dataset_path.strip():
     updates = (gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
-    return ("❌ Please enter a dataset path", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
+    return ("❌ Please enter a dataset path", [], _safe_slider(0, value=0, visible=False), builder_state, *empty_preview, *updates)

 if not os.path.exists(dataset_path):
     updates = (gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
-    return (f"❌ Dataset not found: {dataset_path}", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
+    return (f"❌ Dataset not found: {dataset_path}", [], _safe_slider(0, value=0, visible=False), builder_state, *empty_preview, *updates)
```

Also applies to: 334-334
768-768: NaN check idiom is correct but could be more explicit.

The `loss == loss` check leverages the IEEE 754 property that NaN ≠ NaN, which is clever but may confuse readers. Consider using `math.isnan()` for clarity.

♻️ Suggested refactor for explicit NaN check

Add the import at the top of the file:

```python
import math
```

Then update the check:

```diff
- if step > 0 and loss is not None and loss == loss:  # Check for NaN
+ if step > 0 and loss is not None and not math.isnan(loss):
```
800-800: Use explicit conversion flag for f-string.

Using `{e!s}` is more idiomatic than `{str(e)}` in f-strings.

♻️ Suggested refactor

```diff
- yield f"❌ Error: {str(e)}", str(e), _training_loss_figure({}, [], []), training_state
+ yield f"❌ Error: {e!s}", f"{e!s}", _training_loss_figure({}, [], []), training_state
```

```diff
- return f"❌ Export failed: {str(e)}"
+ return f"❌ Export failed: {e!s}"
```

Also applies to: 862-862
```python
# Apply latent shift and rescale before VAE decode (for anti-clipping control)
if latent_shift != 0.0 or latent_rescale != 1.0:
    logger.info(f"[generate_music] Applying latent post-processing: shift={latent_shift}, rescale={latent_rescale}")
    if self.debug_stats:
        logger.debug(f"[generate_music] Latent BEFORE shift/rescale: min={pred_latents.min():.4f}, max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, std={pred_latents.std():.4f}")
    pred_latents = pred_latents * latent_rescale + latent_shift
    if self.debug_stats:
        logger.debug(f"[generate_music] Latent AFTER shift/rescale: min={pred_latents.min():.4f}, max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, std={pred_latents.std():.4f}")
```
Re-check for non-finite latents after shift/rescale.
The NaN/Inf guard runs before applying user-controlled post-processing. A non-finite or extreme latent_shift/latent_rescale can introduce NaNs/Inf and crash or corrupt VAE decode. Add validation and a post-transform finite check.
🔧 Suggested fix
```diff
 # Apply latent shift and rescale before VAE decode (for anti-clipping control)
-if latent_shift != 0.0 or latent_rescale != 1.0:
+if latent_shift is None:
+    latent_shift = 0.0
+if latent_rescale is None:
+    latent_rescale = 1.0
+if latent_shift != 0.0 or latent_rescale != 1.0:
+    if not math.isfinite(latent_shift) or not math.isfinite(latent_rescale):
+        raise ValueError("latent_shift/latent_rescale must be finite numbers")
     logger.info(f"[generate_music] Applying latent post-processing: shift={latent_shift}, rescale={latent_rescale}")
     if self.debug_stats:
         logger.debug(f"[generate_music] Latent BEFORE shift/rescale: min={pred_latents.min():.4f}, max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, std={pred_latents.std():.4f}")
     pred_latents = pred_latents * latent_rescale + latent_shift
     if self.debug_stats:
         logger.debug(f"[generate_music] Latent AFTER shift/rescale: min={pred_latents.min():.4f}, max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, std={pred_latents.std():.4f}")
+    if torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any():
+        raise RuntimeError("Latent post-processing produced NaN/Inf values")
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
# Apply latent shift and rescale before VAE decode (for anti-clipping control)
if latent_shift is None:
    latent_shift = 0.0
if latent_rescale is None:
    latent_rescale = 1.0
if latent_shift != 0.0 or latent_rescale != 1.0:
    if not math.isfinite(latent_shift) or not math.isfinite(latent_rescale):
        raise ValueError("latent_shift/latent_rescale must be finite numbers")
    logger.info(f"[generate_music] Applying latent post-processing: shift={latent_shift}, rescale={latent_rescale}")
    if self.debug_stats:
        logger.debug(f"[generate_music] Latent BEFORE shift/rescale: min={pred_latents.min():.4f}, max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, std={pred_latents.std():.4f}")
    pred_latents = pred_latents * latent_rescale + latent_shift
    if self.debug_stats:
        logger.debug(f"[generate_music] Latent AFTER shift/rescale: min={pred_latents.min():.4f}, max={pred_latents.max():.4f}, mean={pred_latents.mean():.4f}, std={pred_latents.std():.4f}")
    if torch.isnan(pred_latents).any() or torch.isinf(pred_latents).any():
        raise RuntimeError("Latent post-processing produced NaN/Inf values")
```
🤖 Prompt for AI Agents
In `@acestep/handler.py` around lines 3665 - 3673, After applying user-controlled
latent post-processing in generate_music (the pred_latents = pred_latents *
latent_rescale + latent_shift block), validate latent_shift and latent_rescale
for finiteness and reasonable magnitude, and immediately re-check pred_latents
for non-finite values; if any NaN/Inf are found, log an error via logger
(include debug_stats details if enabled) and either clamp/replace those values
(e.g., with zeros or prior safe statistics) or raise/return an explicit error to
prevent passing corrupted latents to the VAE decode—update the latent
post-processing section around pred_latents and the debug logging to perform
these checks and handle failures safely.
Summary by CodeRabbit
New Features
Improvements