
Conversation

@hertschuh (Collaborator) commented on Oct 1, 2025

`overwrite_with_gradient`

`overwrite_with_gradient` would be ineffective on JAX in real-world conditions, i.e. within `model.fit`.

This is because in the training loop, `stateless_apply` is passed `trainable_variables` as arrays containing the values of the trainable variables, not the variables themselves. Instead, we have to inspect the variables tracked by the inner optimizer.
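
A minimal sketch of the idea, with a hypothetical helper name (`_overwrite_flags` is illustrative, not the actual PR code): the flag has to be read from the `Variable` objects the wrapped optimizer tracks, because the values coming into `stateless_apply` are plain arrays.

```python
# Hypothetical illustration of the fix described above. The values passed into
# stateless_apply by the JAX training loop are plain arrays and carry no
# `overwrite_with_gradient` attribute, so the flag is read from the Variables
# tracked by the inner optimizer instead. The exact attribute used in the PR
# may differ from `_trainable_variables`.
def _overwrite_flags(loss_scale_optimizer):
    return [
        getattr(v, "overwrite_with_gradient", False)
        for v in loss_scale_optimizer.inner_optimizer._trainable_variables
    ]
```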

`apply` with gradients only

`apply(grads)` without the `trainable_variables` argument passed in would not apply anything.

This is because the code uses `self._trainable_variables`, but this was an empty list for `LossScaleOptimizer`. This was fixed by adding a `super().build(...)` call.
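
A sketch of the reported fix, assuming the usual Keras 3 optimizer API; the class is abbreviated and is not the verbatim PR diff:

```python
import keras


class LossScaleOptimizerSketch(keras.optimizers.Optimizer):
    """Abbreviated illustration; only the build() fix is shown."""

    def __init__(self, inner_optimizer, **kwargs):
        super().__init__(learning_rate=0.0, **kwargs)
        self.inner_optimizer = inner_optimizer

    def build(self, var_list):
        self.inner_optimizer.build(var_list)
        # The missing call described above: it populates
        # self._trainable_variables, so apply(grads) without an explicit
        # trainable_variables argument no longer iterates over an empty list.
        super().build(var_list)
```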

unsupported optimizer arguments

`LossScaleOptimizer.__init__` now also fails when other arguments from the base optimizer are passed in, since they are not actually supported. They are also no longer returned by `get_config`.
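
An illustrative sketch of the stricter constructor, assuming the public `LossScaleOptimizer(inner_optimizer, initial_scale, dynamic_growth_steps)` signature; the exact validation and error message in the PR may differ:

```python
# Sketch only: reject base-optimizer keyword arguments instead of silently
# accepting them. get_config likewise stops reporting those arguments.
def __init__(self, inner_optimizer, initial_scale=2.0**15,
             dynamic_growth_steps=2000, **kwargs):
    if kwargs:
        raise ValueError(
            "LossScaleOptimizer does not support arguments of the base "
            f"optimizer; received: {sorted(kwargs)}"
        )
    ...
```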

gemini-code-assist bot (Contributor) commented:
Summary of Changes

Hello @hertschuh, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses critical bugs within the LossScaleOptimizer concerning its handling of variables, particularly in scenarios involving JAX and when trainable_variables are not explicitly supplied to the apply method. The changes ensure that the optimizer correctly identifies and manipulates the variables managed by its inner optimizer, thereby enhancing the robustness and accuracy of mixed-precision training workflows.

Highlights

  • LossScaleOptimizer Variable Handling Fix: Corrected LossScaleOptimizer to properly reference trainable_variables from its inner_optimizer when applying gradients. This resolves issues where overwrite_with_gradient was ineffective on JAX and apply(grads) failed without explicit trainable_variables.
  • stateless_apply Argument Clarification: Updated the stateless_apply method's docstring in BaseOptimizer to explicitly state that the optimizer_variables and trainable_variables arguments should be native tensors (values) rather than keras.Variable objects, improving clarity and correctness for stateless operations (see the usage sketch after this list).
  • Enhanced Test Coverage: Introduced a new test case (test_apply_with_no_vars) to validate the LossScaleOptimizer.apply method when trainable_variables are not explicitly provided. Additionally, existing tests were updated to correctly pass variable values to stateless_apply, aligning with the clarified argument types.
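
A usage sketch of the clarified contract, using an assumed SGD optimizer and toy shapes; only the fact that `stateless_apply` takes values rather than `keras.Variable` objects comes from this pull request:

```python
import keras
from keras import ops

optimizer = keras.optimizers.SGD(learning_rate=0.1)
trainable_variables = [keras.Variable([1.0, 2.0, 3.0])]
optimizer.build(trainable_variables)

grads = [ops.ones((3,))]
# Pass native tensor values, not the keras.Variable objects themselves.
new_trainable_values, new_optimizer_values = optimizer.stateless_apply(
    [v.value for v in optimizer.variables],
    grads,
    [v.value for v in trainable_variables],
)
```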
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist (bot, Contributor) left a comment


Code Review

This pull request introduces important bug fixes for variable handling within the LossScaleOptimizer, particularly for JAX and stateless execution scenarios. The core logic changes correctly reference the inner optimizer's variables, resolving issues where overwrite_with_gradient was ineffective and apply(grads) would fail. The accompanying test updates and the new test case are valuable additions that improve test correctness and coverage. My review focuses on a couple of points: improving variable naming in tests to avoid shadowing Python built-ins, and highlighting the use of private attributes, which goes against the provided style guide's principle of encapsulation. Overall, the functional changes are solid.

Comment on lines +38 to +43
vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
optimizer.build(vars)
optimizer.apply(grads)
self.assertAllClose(
    vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
)

Severity: medium

The variable name vars shadows the built-in Python function vars(). It's good practice to avoid this to prevent confusion and potential bugs. Consider renaming it to something more descriptive like model_variables or trainable_variables to improve clarity.¹

Suggested change
-vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
-optimizer.build(vars)
-optimizer.apply(grads)
-self.assertAllClose(
-    vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
-)
+model_variables = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
+optimizer.build(model_variables)
+optimizer.apply(grads)
+self.assertAllClose(
+    model_variables, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
+)

Style Guide References

Footnotes

  1. Avoid overly generic names and follow Python naming conventions. Shadowing built-in functions is discouraged.

@codecov-commenter commented Oct 1, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.60%. Comparing base (5ae5503) to head (3c41697).
⚠️ Report is 1 commit behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #21706   +/-   ##
=======================================
  Coverage   82.59%   82.60%           
=======================================
  Files         572      572           
  Lines       58327    58328    +1     
  Branches     9131     9134    +3     
=======================================
+ Hits        48177    48181    +4     
+ Misses       7818     7817    -1     
+ Partials     2332     2330    -2     
Flag             | Coverage Δ
keras            | 82.40% <100.00%> (+<0.01%) ⬆️
keras-jax        | 63.31% <100.00%> (+0.01%) ⬆️
keras-numpy      | 57.66% <100.00%> (+0.01%) ⬆️
keras-openvino   | 34.31% <0.00%> (-0.01%) ⬇️
keras-tensorflow | 64.06% <100.00%> (+0.01%) ⬆️
keras-torch      | 63.64% <100.00%> (+0.01%) ⬆️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

@hertschuh (Collaborator, Author) commented:
@cantonios can you review?

 optimizer.build(vars)
 vars, _ = optimizer.stateless_apply(
-    optimizer.variables, grads, vars
+    [v.value for v in optimizer.variables],
Contributor:
I guess this is because of that JAX deprecation issue?

Contributor:
How prevalent is this pattern across keras? If we can no longer rely on the __jax_array__ built-in to auto convert, maybe there should be a helper function like optimizer.variable_values that unwraps the variables?
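
For illustration only, a sketch of the helper this comment proposes; it does not exist in Keras, and the name optimizer.variable_values is the reviewer's suggestion:

```python
# Hypothetical convenience suggested above: unwrap tracked Variables into the
# backing tensor values expected by stateless_apply.
def variable_values(optimizer):
    return [v.value for v in optimizer.variables]
```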

Collaborator Author:
I uncovered this as part of the __jax_array__ deprecation. See #21702

However, I discovered in the process that we're testing stateless_apply the wrong way. stateless_apply takes tensors, not variables (unlike what the names suggest). So this change to the unit tests is correct, and is independent of the __jax_array__ deprecation.

This pattern of unwrapping variables was only added in this test file (see #21702). This pattern does exist to some extent, but in the JAX trainer, not the optimizers.
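
A rough sketch of the pattern being referred to, not the actual JAXTrainer code (function name and state layout are assumptions): the JAX train step carries variable values in its state and hands those values to stateless_apply, so the optimizer never sees Variable objects there.

```python
# Illustrative only: a minimal stateless update step in the style of the JAX
# trainer. `state` holds plain tensor values, not keras.Variable objects.
def train_step(state, grads, optimizer):
    trainable_values, optimizer_values = state
    new_trainable_values, new_optimizer_values = optimizer.stateless_apply(
        optimizer_values, grads, trainable_values
    )
    return new_trainable_values, new_optimizer_values
```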

Contributor:
Thanks, this all looks better.

@hertschuh force-pushed the loss_scale_optimizer branch from cd630c6 to 3c41697 on October 1, 2025 at 23:26
@google-ml-butler (bot) added the kokoro:force-run and ready to pull (Ready to be merged into the codebase) labels on Oct 2, 2025
@hertschuh merged commit f279e93 into keras-team:master on Oct 2, 2025
12 checks passed
@hertschuh deleted the loss_scale_optimizer branch on October 2, 2025 at 19:02
Labels: ready to pull (Ready to be merged into the codebase), size:M

6 participants