-
Notifications
You must be signed in to change notification settings - Fork 88
fix(transformers/ut): fix compute_diffs Division by zero check #1351
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Summary of ChangesHello @JIJIARONGjijiarong, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a potential runtime error in the Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request correctly addresses a potential division-by-zero error in the compute_diffs
test utility function. The original logic for adding a small epsilon to the denominator was flawed, and this change rectifies it by checking if the norm of the PyTorch output tensor p
is zero. To further enhance numerical stability, I've suggested using np.isclose
for the floating-point comparison, which is a more robust practice than direct equality checks.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
|
||
# adaption for tensor with all zeros element | ||
eps = 1e-9 if np.all(m.astype(np.float32) == 0) and np.all(p.astype(np.float32) == 0) else 0 | ||
eps = 1e-9 if np.isclose(np.linalg.norm(p), 0, atol=1e-9) else 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
两个问题哈,这里为啥用isclose,另外逻辑上是否等价于 eps = (x+1e-9) if x < 1e-9 else x,这是基于什么考虑的?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
本来打算修改成 直接判断 np.linalg.norm(p) == 0, 来防止 除0。按照原来的写法 出现 p 不为0 值非常小 但norm p 结果为0,eps 因通过 np.all(m.astype(np.float32) == 0) and np.all(p.astype(np.float32) == 0) 条件而没有赋值导致 除0 的情况。
isclose 采取了 gemini 的建议:
For floating-point comparisons, using
np.isclose
is generally more robust than a direct comparison with== 0
. This change will handle cases where the norm is extremely small but not exactly zero due to precision limitations, preventing potential numerical instability. Settingatol
explicitly to1e-9
makes the tolerance consistent with theeps
value used.
修改后逻辑没问题。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
本来打算修改成 直接判断 np.linalg.norm(p) == 0, 来防止 除0。按照原来的写法 出现 p 不为0 值非常小 但norm p 结果为0,eps 因通过 np.all(m.astype(np.float32) == 0) and np.all(p.astype(np.float32) == 0) 条件而没有赋值导致 除0 的情况。
isclose 采取了 gemini 的建议:
For floating-point comparisons, usingnp.isclose
is generally more robust than a direct comparison with== 0
. This change will handle cases where the norm is extremely small but not exactly zero due to precision limitations, preventing potential numerical instability. Settingatol
explicitly to1e-9
makes the tolerance consistent with theeps
value used.修改后逻辑没问题。
这里是两个问题哈 一个是为什么用isclose接口 另一个为什么用这个逻辑
核心点是这个isclose额外引入了刚才说的那个“容忍度” 在一些情况下可能不一定合理
我的建议是可以尝试 eps = 1e-9 if norm(p) == 0. else 0. 或者 eps = 1e-9,(ps: 要么全加 要么只在极端环境加)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果要考虑不同dtype的溢出的场景,可以参考gemini这里的建议,但是eps可能也需要根据具体的precision确定
pr的描述也可以修改为最后实际选择的方式哈 |
Fix
fix compute_diffs Division by zero check
eps = 1e-9 if np.all(m.astype(np.float32) == 0) and np.all(p.astype(np.float32) == 0) else 0
eps = 1e-9 if np.linalg.norm(p) == 0 else 0