-
Notifications
You must be signed in to change notification settings - Fork 617
[feature]npugraph_ex #4499
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[feature]npugraph_ex #4499
Conversation
Signed-off-by: chencangtao <[email protected]>
Signed-off-by: chencangtao <[email protected]>
# Conflicts: # vllm_ascend/worker/model_runner_v1.py
Signed-off-by: chencangtao <[email protected]>
|
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according Contributing and Testing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces support for a new NPU backend optimization path, enabled by the enable_npugraph_ex_optimize configuration. It patches vllm.compilation.compiler_interface.EagerAdaptor to use torchair for graph compilation, adds corresponding tests, and includes necessary adjustments in the model runner and rotary embedding operations. The changes appear to correctly implement the new compilation path. My main feedback is regarding code duplication in vllm_ascend/worker/model_runner_v1.py, where a workaround for handling model output is repeated. Refactoring this into a helper function would improve maintainability.
| # Sometimes, after the model is compiled through the AOT backend, | ||
| # the model output may become a list containing only one Tensor object. | ||
| if isinstance(hidden_states, list) and \ | ||
| len(hidden_states) == 1 and \ | ||
| isinstance(hidden_states[0], torch.Tensor): | ||
| hidden_states = hidden_states[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic to unwrap a singleton list from hidden_states is also present in profile_run on lines 3102-3107. Duplicating this workaround for AOT backend output inconsistencies increases maintenance overhead and risk of future bugs if one location is updated and the other is missed. To improve maintainability, this logic should be extracted into a private helper method.
For example:
def _unwrap_singleton_list(self, output: Any) -> Any:
"""Unwraps the output if it is a list containing a single tensor."""
if isinstance(output, list) and len(output) == 1 and isinstance(output[0], torch.Tensor):
return output[0]
return outputYou could then replace the duplicated blocks with a call to self._unwrap_singleton_list(hidden_states).
Signed-off-by: chencangtao <[email protected]>
Signed-off-by: chencangtao <[email protected]>
Signed-off-by: chencangtao <[email protected]>
Signed-off-by: chencangtao <[email protected]>
Signed-off-by: chencangtao <[email protected]>
Signed-off-by: chencangtao <[email protected]>
Signed-off-by: chencangtao <[email protected]>
| config = torchair.CompilerConfig() | ||
| config.debug.run_eagerly = True | ||
| config.experimental_config.aclgraph._aclnn_static_shape_kernel = True | ||
| config.debug.aclgraph.disable_mempool_reuse_in_same_fx = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete this
| get_flashcomm2_oproj_tp_size_and_validate_config | ||
| self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config( | ||
| self, vllm_config) | ||
| self.enable_npugraph_ex_optimize = additional_config.get( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Verification is added. This function can be enabled only in fullgraph and full_decode_only scenarios.
|
This pull request has conflicts, please resolve those before we can evaluate the pull request. |
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?