
Conversation

@vbaddi (Contributor) commented Nov 9, 2025

This PR introduces support for exporting model submodules as ONNX functions, enabling more efficient model compilation and execution on hardware.

Changes

  • Added new environment variable QEFF_USE_ONNX_FUNCTIONS to control ONNX function export behavior
  • Integrated ONNX function export capability into the inference pipeline

Enable ONNX Functions Export

Set the environment variable before running inference:

export QEFF_USE_ONNX_FUNCTIONS=true

Export and Execute with ONNX Functions

python -m QEfficient.cloud.infer \
  --model-name gpt2 \
  --num-cores 16 \
  --device-group "[0]" \
  --prompt "My name is" \
  --num-layers 2

Backward Compatibility

This feature is opt-in and requires explicitly setting the environment variable. Existing workflows remain unaffected when the flag is not set.

- Auto-detect decoder layers for export_modules_as_functions based on model type
- Add CustomOpTransform to dynamically register and include custom ops (CustomRMSNorm, CtxGather, CtxScatter)
- Fix invalid INT32_MAX indices in ONNX runtime by replacing with 0
- Support ONNX functions export via QEFF_USE_ONNX_FUNCTIONS env var (see the sketch after this list)
- Handle rope_scaling None values gracefully for Gemma3
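
As a rough illustration of what the flag toggles under the hood, here is a sketch against the public torch.onnx.export API, not the exact code in this PR; model, example_args, and decoder_layer_classes are assumed to be prepared by the export pipeline:

import os
import torch

# Sketch: when the flag is set, the decoder layer classes (auto-detected per
# model type) are passed to torch.onnx.export so each layer becomes an ONNX
# local function. `model`, `example_args`, and `decoder_layer_classes` are
# assumed to be provided by the surrounding export code.
use_functions = os.environ.get("QEFF_USE_ONNX_FUNCTIONS", "false").lower() == "true"

torch.onnx.export(
    model,
    example_args,
    "model.onnx",
    opset_version=17,  # export_modules_as_functions requires opset >= 15
    export_modules_as_functions=decoder_layer_classes if use_functions else False,
)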

Signed-off-by: vbaddi <[email protected]>
Signed-off-by: Vinayak Baddi <[email protected]>
@vbaddi vbaddi marked this pull request as draft November 9, 2025 11:45
@vbaddi vbaddi changed the title from "Feat: Add ONNX Sub Functions Export Feature" to "WIP: Feat: Add ONNX Sub Functions Export Feature" Nov 9, 2025
"""
transformed = False
onnx_slim_transform = True # kwargs.get("enable_onnx_slim_transform", False)
temp_onnx_path = kwargs.get("temp_onnx_path", None)
Contributor:

Can we make it a mandatory argument? Also, onnx_base_dir is unused here.

:param temp_onnx_path: Path to save the slimmed ONNX model.
"""
transformed = False
onnx_slim_transform = True # kwargs.get("enable_onnx_slim_transform", False)
Contributor:

If OnnxSlimTransform is called, do you need another onnx_slim_transform = True flag and the check on line 130? The expectation should be to apply the OnnxSlimTransform, right?

Contributor:

We can remove it from here. There is a flag called "enable_onnx_slim_transform" that lets users decide whether to enable ONNX Slim. We can add a condition in modeling_auto so that this transform is included in the _onnx_transform list only when the flag is enabled.
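
A sketch of the placement being suggested here, assuming an enable_onnx_slim_transform kwarg reaches modeling_auto; the import paths and helper name follow this discussion and may differ from the final code:

# Assumed import paths; the transform classes come from this repo / this PR.
from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform, OnnxSlimTransform

def _build_onnx_transforms(enable_onnx_slim_transform: bool = False):
    # Instead of hard-coding onnx_slim_transform = True inside the transform,
    # append OnnxSlimTransform only when the caller opts in.
    transforms = [FP16ClipTransform, SplitTensorsTransform]
    if enable_onnx_slim_transform:
        transforms.append(OnnxSlimTransform)
    return transforms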

Contributor:

I tested this change with GPT-OSS and it fails in the OnnxSlim transform. Discussed with VB that this doesn't help us much. Let's not add an extra package dependency if it has limited use; let's remove onnxslim.

@inisis:

Can you provide the error log? I think there should be a ~5% performance gain with onnxslim.

Contributor Author:

@inisis thanks. You mean a ~5% gain in performance with onnxslim?

@inisis:

> Not sure about GPT-OSS, but for Qwen 2.5 VL we observed that onnxslim removes identical nodes, which leads to the creation of dummy nodes.

@abhishek-singh591 Hi, what are those dummy nodes? They should not be created by onnxslim. Removing identical nodes is generally known as CSE; it reduces extra computation, and I think it's very useful.

Contributor:

The nodes left behind after removing identical nodes can be considered dummy/orphan nodes. These are nodes that were originally connected as outputs of the identical nodes but, after CSE, no longer have valid connections. Ideally, CSE should rewire the inputs and outputs properly so that no orphan nodes remain, right?

@inisis:

Yes, onnxslim will remove those dummy nodes automatically.

Contributor:

> Yes, onnxslim will remove those dummy nodes automatically.

Actually, it's not doing that, and we also don't want to delete those nodes. Before removing identical nodes through CSE, it should connect the input of the identity node directly to its output, ensuring the graph remains valid.

@inisis:

> Actually, it's not doing that, and we also don't want to delete those nodes. Before removing identical nodes through CSE, it should connect the input of the identity node directly to its output, ensuring the graph remains valid.

Really? Can you provide an example? Many thanks.

inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))

-if hasattr(config, "rope_scaling") and "factor" in config.rope_scaling:
+if hasattr(config, "rope_scaling") and config.rope_scaling is not None and "factor" in config.rope_scaling:
Contributor:

Is this change part of ONNX Sub Functions?

Contributor Author:

No, but it corrects the modeling representation.

example_inputs["past_key_values"][i].append(torch.zeros(pkv_cache[0][0].shape, dtype=torch.float32))
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
output_names.append(f"past_{kv}.{i}_RetainedState")
output_names.append(f"past_{kv}.{i}_InternalRetainedState")
Contributor:

Why are we renaming it? If we rename _RetainedState to _InternalRetainedState, wouldn't the changes also need to be added to text_generation_inference and the other places where we skip these buffers? Even if we are not enabling subfunctions, this would impact regular execution.

ONNX_EXPORT_EXAMPLE_FBS = 4
ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep
-ONNX_EXPORT_OPSET = 13
+ONNX_EXPORT_OPSET = 17
Contributor:

Some tests on opset 17 are still ongoing. @quic-hemagnih, are we good to merge the opset 17 changes?

Contributor:

Sure, but export_modules_as_functions has a hard constraint that the opset must be >= 15.
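
As a quick sanity check after export, the opset and the exported local functions can be inspected with the onnx package; the file name here is illustrative:

import onnx

# Sketch: a model exported with export_modules_as_functions carries its
# submodules as ONNX local functions and must declare an opset >= 15.
m = onnx.load("model.onnx")  # path is illustrative
print([(op.domain, op.version) for op in m.opset_import])
print([f.name for f in m.functions])  # one entry per exported module function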


# Apply patches
# TODO: Find a better way to do this, this is temp. fix.
apply_torch_patches()
Contributor:

If we are not enabling subfunctions, do we need to do the monkey patching?

Contributor:

If we are not enabling subfunctions, then the monkey patching is not required, but applying it doesn't harm execution; we have checked generation without subfunctions while the patches are in place. That said, we can put a condition around it too.
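
A minimal sketch of the condition the reviewers mention, assuming the existing QEFF_USE_ONNX_FUNCTIONS flag and the apply_torch_patches() helper added in this PR; the exact placement is illustrative:

import os

# Sketch: apply the torch export patches only when the opt-in flag is set,
# leaving the default export path untouched.
if os.environ.get("QEFF_USE_ONNX_FUNCTIONS", "false").lower() in ("1", "true"):
    apply_torch_patches()  # helper introduced by this PR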

- _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
+ _onnx_transforms = [
+     FP16ClipTransform,
+     CustomOpTransform,
Contributor:

Do we need to apply the CustomOpTransform again after export?

@vbaddi vbaddi added the enhancement (New feature or request) label Nov 10, 2025
@ochougul (Contributor):

It might be more intuitive and flexible to have a dedicated flag in the export configuration, something like use_subfunctions or export_submodules_as_functions. Using an environment variable makes it harder to switch between the two approaches dynamically, especially during development or testing. A flag would offer clearer intent and better usability. We can add this flag to the export API in the auto classes.

@vbaddi (Contributor Author) commented Nov 11, 2025

> It might be more intuitive and flexible to have a dedicated flag in the export configuration, something like use_subfunctions or export_submodules_as_functions. [...]

Hmm, I guess we have already discussed this: passing it as part of .export() doesn't make sense, since there are cache-module changes required. We can do it as part of .from_pretrained().
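
A hypothetical sketch of how such a flag could be surfaced on the from_pretrained path instead of .export(); the use_onnx_functions kwarg is purely illustrative and not an agreed API:

# Hypothetical API sketch; the PR as written relies on the QEFF_USE_ONNX_FUNCTIONS env var.
from QEfficient import QEFFAutoModelForCausalLM

# Consuming the flag at load time would let the cache-module changes needed
# for function export be applied before .export() is ever called.
model = QEFFAutoModelForCausalLM.from_pretrained("gpt2", use_onnx_functions=True)  # illustrative kwarg
onnx_path = model.export()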

if hasattr(onnx_utils, "_get_module_attributes"):
    onnx_utils._get_module_attributes = _get_module_attributes

print("Applied torch ONNX export patches for export_modules_as_functions compatibility")
Contributor:

Use the logger here instead of print.
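
For example, a sketch of the swap; the exact logger import path in this repo is assumed:

# Assumed import path for the project logger; adjust to the repo's actual module.
from QEfficient.utils.logging_utils import logger

logger.info("Applied torch ONNX export patches for export_modules_as_functions compatibility")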

