Bug fixes with variable handling in LossScaleOptimizer #21706
Conversation
Code Review

This pull request introduces important bug fixes for variable handling within the `LossScaleOptimizer`, particularly for JAX and stateless execution scenarios. The core logic changes correctly reference the inner optimizer's variables, resolving issues where `overwrite_with_gradient` was ineffective and `apply(grads)` would fail. The accompanying test updates and the new test case are valuable additions that improve test correctness and coverage. My review focuses on a couple of points: improving variable naming in tests to avoid shadowing Python built-ins, and highlighting the use of private attributes, which goes against the provided style guide's principle of encapsulation. Overall, the functional changes are solid.
```python
vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
optimizer.build(vars)
optimizer.apply(grads)
self.assertAllClose(
    vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
)
```
The variable name `vars` shadows the built-in Python function `vars()`. It's a good practice to avoid this to prevent confusion and potential bugs. Consider renaming it to something more descriptive like `model_variables` or `trainable_variables` to improve clarity.[1]
Suggested change:

```diff
-vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
-optimizer.build(vars)
-optimizer.apply(grads)
-self.assertAllClose(
-    vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
-)
+model_variables = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
+optimizer.build(model_variables)
+optimizer.apply(grads)
+self.assertAllClose(
+    model_variables, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
+)
```
Style Guide References

Footnotes

1. Avoid overly generic names and follow Python naming conventions. Shadowing built-in functions is discouraged.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff            @@
##           master   #21706    +/-   ##
=========================================
  Coverage    82.59%   82.60%
=========================================
  Files          572      572
  Lines        58327    58328      +1
  Branches      9131     9134      +3
=========================================
+ Hits         48177    48181      +4
+ Misses        7818     7817      -1
+ Partials      2332     2330      -2
```
@cantonios can you review?
```diff
 optimizer.build(vars)
 vars, _ = optimizer.stateless_apply(
-    optimizer.variables, grads, vars
+    [v.value for v in optimizer.variables],
```
I guess this is because of that JAX deprecation issue?
How prevalent is this pattern across keras? If we can no longer rely on the `__jax_array__` built-in to auto-convert, maybe there should be a helper function like `optimizer.variable_values` that unwraps the variables?
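For illustration only, such a helper might look roughly like the sketch below; `variable_values` is hypothetical and does not currently exist in Keras, and it simply unwraps each `Variable` into its backing tensor via `.value`:

```python
def variable_values(variables):
    """Unwrap Keras Variables into their backing tensors (hypothetical helper)."""
    return [v.value for v in variables]

# A stateless_apply call site could then read:
#   optimizer.stateless_apply(
#       variable_values(optimizer.variables),
#       grads,
#       variable_values(trainable_variables),
#   )
# instead of passing Variable objects and relying on __jax_array__ auto-conversion.
```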
I uncovered this as part of the `__jax_array__` deprecation. See #21702.

However, I discovered in the process that we're testing `stateless_apply` the wrong way. `stateless_apply` takes tensors, not variables (unlike what the names suggest). So this change to the unit tests is correct, and is independent of the `__jax_array__` deprecation.

This pattern of unwrapping variables was only added in this test file (see #21702). This pattern does exist to some extent, but in the JAX trainer, not the optimizers.
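Concretely, the corrected calling convention looks roughly like this (an illustrative sketch using the public Keras 3 API; the values and shapes are made up, not taken from the diff):

```python
import keras

optimizer = keras.optimizers.LossScaleOptimizer(
    keras.optimizers.SGD(learning_rate=1.0)
)
trainable_variables = [keras.Variable([1.0, 2.0, 3.0, 4.0])]
optimizer.build(trainable_variables)
grads = [keras.ops.ones((4,))]

# stateless_apply expects plain tensors (the *values* of the variables),
# not the Variable objects themselves.
new_trainable_values, new_optimizer_values = optimizer.stateless_apply(
    [v.value for v in optimizer.variables],      # optimizer state as tensors
    grads,
    [v.value for v in trainable_variables],      # trainable values as tensors
)
```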
Thanks, this all looks better.
Force-pushed from cd630c6 to 3c41697 (Compare)
overwrite_with_gradient

`overwrite_with_gradient` would be ineffective on JAX in real-world conditions, i.e. within `model.fit`. This is because in the training loop, `stateless_apply` is passed `trainable_variables` as arrays containing the values of the trainable variables, not the variables themselves. Instead, we have to inspect the variables.

apply with gradients only

`apply(grads)` without the `trainable_variables` argument passed in would not apply anything. This is because the code uses `self._trainable_variables`, but this was an empty array for `LossScaleOptimizer`. This was fixed by adding `super().build(...)`.

unsupported optimizer arguments

Also fail when other arguments from the base optimizer are passed to `LossScaleOptimizer.__init__`, since they are not actually supported. They are also no longer returned by `get_config`.
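As a hedged illustration of the last point, rejecting base-optimizer kwargs would look something like this from the caller's side (the exact exception type is an assumption here, not taken from the diff):

```python
import keras

inner = keras.optimizers.SGD(learning_rate=0.1)

# clipnorm is a base-optimizer argument; per this PR it is not supported by
# the wrapper, so passing it to LossScaleOptimizer.__init__ should now fail.
try:
    keras.optimizers.LossScaleOptimizer(inner, clipnorm=1.0)
except (TypeError, ValueError) as e:
    print(f"rejected as expected: {e}")
```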