Add SD3.5-medium quantization support in ModelOpt Diffusers example #444
base: main
Conversation
Signed-off-by: vipandya <[email protected]>
Walkthrough
Added support for the "sd3.5-medium" model across the ONNX export and quantization flows: updated dynamic axes, input/output shape selection, pipeline creation, and registry entries; also added timing measurement for the overall quantization duration.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant CLI as CLI / main
    participant Manager as PipelineManager
    participant Quant as Quantization
    participant Export as ONNX Export
    CLI->>Manager: create_pipeline(model_id)
    Note over Manager: checks model type\n(including sd3-medium, sd3.5-medium)
    alt SD3-family (sd3-medium / sd3.5-medium)
        Manager-->>CLI: pipeline (SD3 pipeline)
    else Other models
        Manager-->>CLI: pipeline (other)
    end
    CLI->>Quant: run_quantization(pipeline, config)
    Note right of CLI: start_time recorded
    Quant->>Quant: calibration & quantize (uses_transformer check includes sd3.5-medium)
    Quant->>Export: modelopt_export_sd(..., quant_config.quantize_mha)
    Export->>Export: set dynamic axes / io shapes for sd3.5-medium
    Export-->>Quant: export result
    Quant-->>CLI: finished
    Note right of CLI: log elapsed time
```
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/diffusers/quantization/quantize.py (1)
673-679: Possible AttributeError when checking Conv quantizers. Conv modules may lack input_quantizer/weight_quantizer; direct attribute access can crash. Guard with getattr.
Apply this diff:

```diff
 def _has_conv_layers(self, model: torch.nn.Module) -> bool:
@@
-    for module in model.modules():
-        if isinstance(module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)) and (
-            module.input_quantizer.is_enabled or module.weight_quantizer.is_enabled
-        ):
-            return True
+    for module in model.modules():
+        if isinstance(module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)):
+            iq = getattr(module, "input_quantizer", None)
+            wq = getattr(module, "weight_quantizer", None)
+            if getattr(iq, "is_enabled", False) or getattr(wq, "is_enabled", False):
+                return True
     return False
```
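The guarded version can be exercised without torch or ModelOpt installed; the sketch below uses plain stand-in classes (Conv2d and Quantizer here are hypothetical stand-ins, not the real torch/ModelOpt types) to show how the getattr chain avoids the AttributeError:

```python
class Conv2d:
    """Stand-in for torch.nn.Conv2d; real Conv modules may or may not
    carry ModelOpt quantizer attributes."""

class Quantizer:
    """Stand-in for a quantizer object exposing an is_enabled flag."""
    def __init__(self, enabled: bool):
        self.is_enabled = enabled

def has_enabled_conv_quantizers(modules) -> bool:
    for module in modules:
        if isinstance(module, Conv2d):
            # getattr with a None default never raises, even on plain Convs.
            iq = getattr(module, "input_quantizer", None)
            wq = getattr(module, "weight_quantizer", None)
            if getattr(iq, "is_enabled", False) or getattr(wq, "is_enabled", False):
                return True
    return False

plain = Conv2d()                      # no quantizer attributes at all
quantized = Conv2d()
quantized.input_quantizer = Quantizer(enabled=True)

print(has_enabled_conv_quantizers([plain]))             # → False
print(has_enabled_conv_quantizers([plain, quantized]))  # → True
```

The unguarded original would raise AttributeError on `plain`, since it has no input_quantizer at all.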
🧹 Nitpick comments (6)
examples/diffusers/quantization/quantize.py (5)
19-19: Minor nit: simplify import. Use plain `import time` (drop `as time`) for clarity.

```diff
-import time as time
+import time
```
872-873: Improve timing precision and log formatting (optional). Prefer perf_counter and a concise message.

```diff
-    s = time.time()
+    start = time.perf_counter()
@@
-    logger.info(f"Quantization process completed successfully! Time taken = {time.time() - s} seconds")
+    elapsed = time.perf_counter() - start
+    logger.info(f"Quantization completed successfully in {elapsed:.2f}s")
```

Also applies to: 951-952
836-837: CLI help text matches behavior? Help says "Quantizing MHA into FP8," but the code passes quantize_mha to checks for FP4 as well. Consider clarifying.

```diff
-    quant_group.add_argument(
-        "--quantize-mha", action="store_true", help="Quantizing MHA into FP8 if its True"
-    )
+    quant_group.add_argument(
+        "--quantize-mha",
+        action="store_true",
+        help="Quantize MHA when supported (FP8; FP4 path uses FP8 MHA gating)."
+    )
```
758-778: Docs: add an sd3.5 example in the epilog (optional). Add a quick example alongside sd3-medium.

```diff
     # FP8 quantization with ONNX export
-    %(prog)s --model sd3-medium --format fp8 --onnx-dir ./onnx_models/
+    %(prog)s --model sd3-medium --format fp8 --onnx-dir ./onnx_models/
+    # FP8 quantization with ONNX export (SD3.5 Medium)
+    %(prog)s --model sd3.5-medium --format fp8 --onnx-dir ./onnx_models/
```
127-135: Optional: keep registry and CLI options in sync automatically. Consider deriving the CLI choices from the MODEL_REGISTRY keys to prevent drift.
Also applies to: 784-786
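One way to do that — sketched here with a hypothetical two-entry stand-in for the example's full MODEL_REGISTRY — is to pass the registry keys directly as argparse choices:

```python
import argparse

# Hypothetical subset of MODEL_REGISTRY; the real mapping lives in quantize.py.
MODEL_REGISTRY = {
    "sd3-medium": "stabilityai/stable-diffusion-3-medium-diffusers",
    "sd3.5-medium": "stabilityai/stable-diffusion-3.5-medium",
}

parser = argparse.ArgumentParser()
# choices derived from the registry keys: adding a model to the dict
# automatically makes it a valid --model value, with no second list to update.
parser.add_argument("--model", choices=sorted(MODEL_REGISTRY), required=True)

args = parser.parse_args(["--model", "sd3.5-medium"])
print(MODEL_REGISTRY[args.model])  # → stabilityai/stable-diffusion-3.5-medium
```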
examples/diffusers/quantization/onnx_utils/export.py (1)
320-338: Reduce duplication with an SD3 family constant (optional). Define a set like SD3_FAMILY = {"sd3-medium", "sd3.5-medium"} and reuse it in conditionals.

```diff
+SD3_FAMILY = {"sd3-medium", "sd3.5-medium"}
@@
-    elif model_id in ["sd3-medium", "sd3.5-medium"]:
+    elif model_id in SD3_FAMILY:
@@
     elif model_name == "sd3-medium":
         input_names = ["hidden_states", "encoder_hidden_states", "pooled_projections", "timestep"]
         output_names = ["sample"]
     elif model_name == "sd3.5-medium":
         input_names = ["hidden_states", "encoder_hidden_states", "pooled_projections", "timestep"]
         output_names = ["out_hidden_states"]
```

Also applies to: 416-447
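A minimal sketch of the suggested refactor (get_output_names is a hypothetical helper for illustration; the real code branches inline in the export function):

```python
SD3_FAMILY = {"sd3-medium", "sd3.5-medium"}

def get_output_names(model_name: str) -> list[str]:
    """Return ONNX output names for SD3-family models."""
    # One set-membership check replaces repeated two-element list comparisons.
    if model_name not in SD3_FAMILY:
        raise ValueError(f"not an SD3-family model: {model_name!r}")
    # Per the review above, sd3.5-medium renames the single output tensor.
    return ["out_hidden_states"] if model_name == "sd3.5-medium" else ["sample"]

print(get_output_names("sd3-medium"))    # → ['sample']
print(get_output_names("sd3.5-medium"))  # → ['out_hidden_states']
```

Centralizing the set means a future SD3-family variant needs only one new entry instead of edits to every conditional.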
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/diffusers/quantization/onnx_utils/export.py (6 hunks)
- examples/diffusers/quantization/quantize.py (9 hunks)
🔇 Additional comments (9)
examples/diffusers/quantization/quantize.py (4)
131-132: Registry ID for SD3.5 Medium looks correct (verified). Mapping to "stabilityai/stable-diffusion-3.5-medium" is valid and compatible with StableDiffusion3Pipeline. (huggingface.co)
334-336: Pipeline routing for SD3.5 via StableDiffusion3Pipeline is correct. Creation paths mirror SD3 Medium and align with HF usage for sd3.5-medium. (huggingface.co)
Also applies to: 365-369
119-120: Filter function and uses_transformer updates are consistent. No issues; SD3.5 is included alongside SD3 Medium.
Also applies to: 237-238
904-910: Sanity script to verify SD3.5 wiring across the repo. Run to confirm all branches handle "sd3.5-medium".
examples/diffusers/quantization/onnx_utils/export.py (5)
76-82: sd3.5 dynamic axes added correctly. The output name uses "out_hidden_states", which matches the export branch below.
Please confirm your exported ONNX graph’s output is named "out_hidden_states" before renaming to "out.0" downstream (if applicable).
300-302: update_dynamic_axes handles the sd3.5 rename to out.0. Consistent with using "out_hidden_states" during export; good.
325-328: Grouping sd3 and sd3.5 in dummy input generation is correct. Both use SD3Transformer2DModel and identical input tensors.
355-357: get_io_shapes updated for sd3.5. I/O minShapes reference hidden_states as expected; OK.
Also applies to: 367-368
422-425: Export IO names for sd3.5 align with the dynamic axes. Inputs mirror sd3; the output is renamed to "out_hidden_states"; matches the HF MMDiT-X changes.
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #444   +/-   ##
=======================================
  Coverage   73.37%   73.37%
=======================================
  Files         180      180
  Lines       17937    17937
=======================================
+ Hits        13161    13162       +1
+ Misses       4776     4775       -1
```

☔ View full report in Codecov by Sentry.
Signed-off-by: vipandya <[email protected]>
Actionable comments posted: 0
🧹 Nitpick comments (1)
examples/diffusers/quantization/quantize.py (1)
19-19: Remove redundant import alias. The import statement `import time as time` is redundant; simply use `import time`.
Apply this diff:

```diff
-import time as time
+import time
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/diffusers/quantization/onnx_utils/export.py (5 hunks)
- examples/diffusers/quantization/quantize.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/diffusers/quantization/quantize.py (1)
- examples/diffusers/quantization/utils.py (1): filter_func_default (37-42)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (11)
examples/diffusers/quantization/onnx_utils/export.py (5)
76-82: LGTM! Consistent dynamic axes configuration for sd3.5-medium. The dynamic axes configuration correctly differentiates sd3.5-medium from sd3-medium by using "out_hidden_states" as the output key, which aligns with the subsequent usage in update_dynamic_axes (line 301) and modelopt_export_sd (line 422).
300-301: LGTM! Correct output mapping for sd3.5-medium. The mapping of "out.0" to "out_hidden_states" is consistent with the dynamic axes definition at lines 76-82 and the export configuration at lines 420-422.
325-328: LGTM! Appropriate reuse of SD3 dummy input generation. Treating sd3.5-medium the same as sd3-medium for dummy input generation is appropriate since they share similar model architectures and input requirements.
355-356: LGTM! Consistent I/O shape handling. The output name selection ("out_hidden_states") and shape configuration correctly align with the dynamic axes definition and treat sd3.5-medium appropriately alongside sd3-medium.
Also applies to: 364-365
420-422: LGTM! Complete and consistent export configuration. The input/output names for sd3.5-medium are correctly specified and align with the dynamic axes definition (lines 76-82) and the update_dynamic_axes logic (line 301), ensuring end-to-end consistency.
examples/diffusers/quantization/quantize.py (6)
63-63: LGTM! SD35_MEDIUM model type added correctly. The new model type follows the existing naming convention and integrates properly with the model type enum.
119-119: LGTM! Appropriate filter function mapping. Using filter_func_default for SD35_MEDIUM is consistent with SD3_MEDIUM and appropriate for the model architecture.
237-237: LGTM! Correct transformer classification. SD35_MEDIUM correctly uses a transformer backbone like SD3_MEDIUM, so including it in the uses_transformer property is appropriate.
334-335: LGTM! Consistent pipeline creation for SD35_MEDIUM. Both create_pipeline_from (static method) and create_pipeline (instance method) correctly treat SD35_MEDIUM the same as SD3_MEDIUM by using StableDiffusion3Pipeline, ensuring consistency across the codebase.
Also applies to: 365-368
872-872: LGTM! Timing instrumentation and parameter fix. The timing measurement is correctly implemented, and line 949 fixes a subtle bug by using the instance attribute quant_config.quantize_mha instead of the class attribute QuantizationConfig.quantize_mha, ensuring the actual configuration value is passed to the export function.
Also applies to: 949-953
131-131: No issues found. The model ID is correct and publicly accessible. The model "stabilityai/stable-diffusion-3.5-medium" is a Stability AI MMDiT-X text-to-image generative model available on Hugging Face under the Stability Community License, which permits free use for research, non-commercial, and commercial use by organizations with less than $1M in annual revenue.
What does this PR do?
Type of change: Diffusers' Example Update
Overview:
Add the SD3.5-medium quantization config to the quantization and export files of the diffusers example
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit