For nz unset in bf16&fp16 #4495
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This PR aims to disable the NZ format for floating-point (bf16/fp16) cases, implemented by modifying the is_enable_nz function and updating all of its call sites. This direction is correct.
However, the change to is_enable_nz in vllm_ascend/utils.py introduces a serious problem: the function now mutates the global variable _ENABLE_NZ, making its behavior depend on call history and producing nondeterministic results. For example, once the function is called with a float16 or bfloat16 dtype, the global _ENABLE_NZ flag is permanently set to False, which affects all subsequent calls for other dtypes (such as int8).
I have provided a suggested fix for this incorrect mutation of global state; please see the specific review comment.
```python
    if dtype in [torch.float16, torch.bfloat16]:
        _ENABLE_NZ = False
    return _ENABLE_NZ
```
The way this function modifies the global variable _ENABLE_NZ is seriously problematic. When it is called with a torch.float16 or torch.bfloat16 dtype, it sets the global _ENABLE_NZ to False. Because _ENABLE_NZ is a global that is initialized only once (when it is None), this modification is persistent.
As a result, all subsequent calls to is_enable_nz return False even for other dtypes such as torch.int8, which is likely not the intended behavior and introduces a hidden, call-order-dependent bug.
To fix this, we should avoid mutating the global _ENABLE_NZ inside the function. The correct approach is to decide the return value based on dtype locally, without changing global state.
Suggested change:

```python
def is_enable_nz(dtype: Optional[torch.dtype] = torch.int8,
                 vllm_config: Optional[VllmConfig] = None) -> bool:
    global _ENABLE_NZ
    if _ENABLE_NZ is None:
        if not vllm_config:
            raise ValueError(
                "vllm_config must be provided when _ENABLE_NZ is None")
        _ENABLE_NZ = envs_ascend.VLLM_ASCEND_ENABLE_NZ and vllm_config.model_config.hf_config.model_type != "qwen3_next"
    if dtype in [torch.float16, torch.bfloat16]:
        return False
    return _ENABLE_NZ
```

```diff
     if dtype in [torch.float16, torch.bfloat16]:
-        _ENABLE_NZ = False
+        return False
     return _ENABLE_NZ
```
Signed-off-by: 刘哲续 <[email protected]>
What this PR does / why we need it?
Disable NZ for the float-weight case. This is only a quick fix for the dev branch.
For the main branch, we'll consider more cases to make it more general.
Does this PR introduce any user-facing change?
How was this patch tested?
qwen2.5 32B
