WIP: Feat: Add ONNX Sub Functions Export Feature #613
base: main
Conversation
- Auto-detect decoder layers for export_modules_as_functions based on model type
- Add CustomOpTransform to dynamically register and include custom ops (CustomRMSNorm, CtxGather, CtxScatter)
- Fix invalid INT32_MAX indices in ONNX runtime by replacing with 0
- Support ONNX functions export via QEFF_USE_ONNX_FUNCTIONS env var
- Handle rope_scaling None values gracefully for Gemma3

Signed-off-by: vbaddi <[email protected]>
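The list above ties two things together: discovering which decoder-layer classes should become ONNX functions, and gating the behaviour behind QEFF_USE_ONNX_FUNCTIONS. Below is a minimal sketch of how that could look; the helper names and the class-name-suffix heuristic are illustrative assumptions, not the PR's actual implementation.

```python
import os
import torch.nn as nn

def get_decoder_layer_classes(model: nn.Module) -> set:
    # Assumed heuristic: treat any submodule whose class name ends with
    # "DecoderLayer" (e.g. LlamaDecoderLayer) as a candidate ONNX function.
    return {
        type(module)
        for module in model.modules()
        if type(module).__name__.endswith("DecoderLayer")
    }

def resolve_export_modules_as_functions(model: nn.Module):
    # torch.onnx.export accepts a set of module classes (or False) for
    # export_modules_as_functions; only opt in when the env flag is set.
    if os.environ.get("QEFF_USE_ONNX_FUNCTIONS", "false").lower() == "true":
        return get_decoder_layer_classes(model)
    return False
```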
| """ | ||
| transformed = False | ||
| onnx_slim_transform = True # kwargs.get("enable_onnx_slim_transform", False) | ||
| temp_onnx_path = kwargs.get("temp_onnx_path", None) |
Can we make this a mandatory argument? Also, onnx_base_dir is unused here.
:param temp_onnx_path: Path to save the slimmed ONNX model.
"""
transformed = False
onnx_slim_transform = True  # kwargs.get("enable_onnx_slim_transform", False)
If OnnxSlimTransform is called, do we need to hard-code onnx_slim_transform = True and then check it again on line 130? The expectation should be that the transform is simply applied, right?
We can remove it from here. There is a flag called "enable_onnx_slim_transform" that lets users decide whether to enable ONNX Slim. We can add a condition in modeling_auto so that this transform is included in the _onnx_transform list only when the flag is enabled.
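A small sketch of the suggestion above, assuming enable_onnx_slim_transform arrives via kwargs and the transform classes are already importable; the exact wiring in modeling_auto may differ.

```python
# Include OnnxSlimTransform only when the user explicitly opts in.
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
if kwargs.get("enable_onnx_slim_transform", False):
    # Slim the graph before tensors are split out into external files.
    _onnx_transforms.insert(1, OnnxSlimTransform)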
I tested this change with GPTOSS; it fails in the OnnxSlim transform. Discussed with VB that this doesn't help us much.
Let's not add an extra package dependency if it has limited use.
Let's remove onnxslim.
Can you provide the error log? I think there should be a 5% performance gain with onnxslim.
@inisis thanks. You mean a ~5% gain in perf with onnxslim?
Not sure about GPT OSS, but for Qwen 2.5 VL we observed that onnxslim removes identical nodes, which leads to the creation of dummy nodes.
@abhishek-singh591 Hi, what are those dummy nodes? They should not be created by onnxslim. The removal of identical nodes is generally known as CSE (common subexpression elimination); it reduces extra computation, and I think it's very useful.
The nodes left behind after removing the identical nodes can be considered dummy/orphan nodes. These are nodes that were originally connected to the outputs of the identical nodes but, after CSE, no longer have valid connections. Ideally, CSE should rewire the inputs and outputs properly so that no orphan nodes remain, right?
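To make the rewiring point concrete, here is a minimal sketch of CSE-style deduplication over an ONNX GraphProto that redirects consumers of a duplicate node's outputs to the kept node before dropping it, so no orphan nodes are left behind. This is an illustration of the expected behaviour, not onnxslim's implementation.

```python
import onnx

def dedup_identical_nodes(graph: onnx.GraphProto) -> None:
    """Remove structurally identical nodes, rewiring consumers so the graph stays valid."""
    seen = {}
    for node in list(graph.node):
        key = (node.op_type, tuple(node.input), tuple(str(a) for a in node.attribute))
        if key not in seen:
            seen[key] = node
            continue
        kept = seen[key]
        # Redirect every consumer of the duplicate's outputs to the kept node's outputs.
        remap = dict(zip(node.output, kept.output))
        for consumer in graph.node:
            for i, name in enumerate(consumer.input):
                if name in remap:
                    consumer.input[i] = remap[name]
        # Also fix up graph outputs that referenced the duplicate node.
        for out in graph.output:
            if out.name in remap:
                out.name = remap[out.name]
        graph.node.remove(node)
```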
Yes, onnxslim will remove those dummy nodes automatically.
Yes, onnxslim will remove those dummy nodes automatically.
Actually, it's not doing that, and we also don't want to delete those nodes. Before removing identical nodes through CSE, it should connect the input of the identity node directly to its output, ensuring the graph remains valid.
Really? Can you provide me an example? Many thanks.
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))

- if hasattr(config, "rope_scaling") and "factor" in config.rope_scaling:
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None and "factor" in config.rope_scaling:
Is this change part of ONNX Sub Functions?
No, but it is a correction to the modeling representation.
example_inputs["past_key_values"][i].append(torch.zeros(pkv_cache[0][0].shape, dtype=torch.float32))
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
- output_names.append(f"past_{kv}.{i}_RetainedState")
+ output_names.append(f"past_{kv}.{i}_InternalRetainedState")
Why are we renaming it? If we rename _RetainedState to _InternalRetainedState, wouldn't the changes also need to be made in text_generation_inference and the other places where we skip these buffers? Even if we are not enabling the subfunction feature, this would impact regular execution.
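To illustrate the concern: downstream tooling typically identifies retained-state buffers by their name suffix, so the rename silently breaks that matching. A minimal sketch, assuming suffix-based matching (the actual skip-buffer logic in text_generation_inference may differ):

```python
output_names = [
    "logits",
    "past_key.0_InternalRetainedState",
    "past_value.0_InternalRetainedState",
]

# The existing suffix check no longer matches the renamed outputs.
skip_buffers = [name for name in output_names if name.endswith("_RetainedState")]
print(skip_buffers)  # [] -- renamed retained-state outputs are no longer skipped
```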
ONNX_EXPORT_EXAMPLE_FBS = 4
ONNX_EXPORT_EXAMPLE_NLK = 2  # Number of Logits to Keep
- ONNX_EXPORT_OPSET = 13
+ ONNX_EXPORT_OPSET = 17
Some tests on opset 17 are still ongoing. @quic-hemagnih, are we good to merge the opset 17 changes?
Sure, but export_modules_as_functions has a hard constraint that the opset must be >= 15.
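For reference, the constraint comes from the TorchScript-based exporter: torch.onnx.export only honors export_modules_as_functions with opset_version >= 15. A minimal sketch; model, example_inputs, names, and decoder_layer_classes below are placeholders:

```python
import torch

# decoder_layer_classes is the set of nn.Module subclasses to export as ONNX
# functions (e.g. {LlamaDecoderLayer}); model/example_inputs are placeholders.
torch.onnx.export(
    model,
    (example_inputs,),
    "model.onnx",
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
    opset_version=17,  # must be >= 15 when exporting modules as functions
    export_modules_as_functions=decoder_layer_classes,
)
```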
# Apply patches
# TODO: Find a better way to do this, this is temp. fix.
apply_torch_patches()
If we are not enabling subfunctions, do we need to do the monkey patching?
If we are not enabling subfunctions then the monkey patching is not required, but applying it won't harm execution; we have verified generation with the monkey patches applied but subfunctions disabled. That said, we can put this behind a condition as well.
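A sketch of the suggested condition, reusing the env-var gate from this PR (the check could equally live in a small helper):

```python
import os

# Only monkey-patch torch.onnx internals when the sub-functions export path is requested.
if os.environ.get("QEFF_USE_ONNX_FUNCTIONS", "false").lower() == "true":
    apply_torch_patches()
```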
- _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+ _onnx_transforms = [
+     FP16ClipTransform,
+     CustomOpTransform,
Do we need to apply the CustomOpTransform again after export?
It might be more intuitive and flexible to have a dedicated flag in the export configuration, something like …

Hmm, I guess we have already discussed this, passing it as part of …
if hasattr(onnx_utils, "_get_module_attributes"):
    onnx_utils._get_module_attributes = _get_module_attributes

print("Applied torch ONNX export patches for export_modules_as_functions compatibility")
Use the logger here instead of print.
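For example, something along these lines; the logger import path is an assumption about this repo's utilities:

```python
from QEfficient.utils.logging_utils import logger  # assumed import path

logger.info("Applied torch ONNX export patches for export_modules_as_functions compatibility")
```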
This PR introduces support for exporting ONNX modules as functions, enabling more efficient model compilation and execution on hardware.
Changes
- QEFF_USE_ONNX_FUNCTIONS to control ONNX function export behavior

Enable ONNX Functions Export
Set the environment variable before running inference:
export QEFF_USE_ONNX_FUNCTIONS=true

Export and Execute with ONNX Functions
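A usage sketch, assuming the standard QEfficient flow is otherwise unchanged and only the environment variable toggles the sub-functions path; the model card and compile settings below are placeholders:

```python
import os
from transformers import AutoTokenizer
from QEfficient import QEFFAutoModelForCausalLM

# Opt in to ONNX sub-functions export before export/compile.
os.environ["QEFF_USE_ONNX_FUNCTIONS"] = "true"

model_card = "meta-llama/Llama-3.2-1B"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_card)

model = QEFFAutoModelForCausalLM.from_pretrained(model_card)
model.export()               # decoder layers exported as ONNX functions
model.compile(num_cores=16)  # placeholder compile settings
model.generate(tokenizer=tokenizer, prompts=["Hello, how are you?"])
```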
Backward Compatibility
This feature is opt-in and requires the environment variable to be set explicitly. Existing workflows remain unaffected when the flag is disabled.