14 changes: 14 additions & 0 deletions keras/src/optimizers/base_optimizer.py
@@ -631,6 +631,20 @@ def _backend_increment_gradient_accumulators(self, grads, acc_grads):
g_acc.assign(n_g_acc)

def stateless_apply(self, optimizer_variables, grads, trainable_variables):
"""Stateless version of `apply` that returns modified variables.

Args:
optimizer_variables: list of tensors containing the current values
for the optimizer variables. These are native tensors and not
`keras.Variable`s.
grads: list of gradients to apply.
trainable_variables: list of tensors containing the current values
for the model variables. These are native tensors and not
`keras.Variable`s.

Returns: A tuple containing two lists of tensors: the updated
`trainable_variables` and the updated `optimizer_variables`.
"""
self._check_super_called()

if not self.built:
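For reference, a minimal sketch of how this stateless path is driven end to end; the `SGD` optimizer and the concrete values below are illustrative, not taken from this diff:

```python
import keras
from keras import ops

# Illustrative setup: any built optimizer and variables work the same way.
optimizer = keras.optimizers.SGD(learning_rate=0.1)
trainable_variables = [keras.Variable([1.0, 2.0, 3.0])]
optimizer.build(trainable_variables)

grads = [ops.ones((3,))]
# `stateless_apply` consumes and returns plain tensor values, not
# `keras.Variable` objects, so the variables are unwrapped via `.value`.
new_values, new_optimizer_values = optimizer.stateless_apply(
    [v.value for v in optimizer.variables],
    grads,
    [v.value for v in trainable_variables],
)
# The caller is responsible for writing the returned values back.
for variable, value in zip(trainable_variables, new_values):
    variable.assign(value)
for variable, value in zip(optimizer.variables, new_optimizer_values):
    variable.assign(value)
```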
59 changes: 46 additions & 13 deletions keras/src/optimizers/loss_scale_optimizer.py
@@ -48,6 +48,7 @@ def __init__(
inner_optimizer,
initial_scale=2.0**15,
dynamic_growth_steps=2000,
name=None,
**kwargs,
):
if not kwargs.pop("dynamic", True):
@@ -56,7 +57,42 @@ def __init__(
"Instead, simply set `loss_scale_factor` directly on the "
"`inner_optimizer`."
)
super().__init__(learning_rate=0.0, **kwargs)

# Backwards compatibility code for deserialization.
# LossScaleOptimizer used to return all these parameters in `get_config`
# via `super().get_config()` even though they are all non-functional. We
# no longer let users set them, but we have to allow the default values
# to be passed during deserialization to support older models.
base_optimizer_defaults = {
"weight_decay": None,
"clipnorm": None,
"global_clipnorm": None,
"clipvalue": None,
"use_ema": False,
"ema_momentum": 0.99,
"ema_overwrite_frequency": None,
"loss_scale_factor": None,
"gradient_accumulation_steps": None,
}
for arg_name, default_value in base_optimizer_defaults.items():
if arg_name not in kwargs:
continue
arg_value = kwargs.pop(arg_name)
if (
default_value is None and arg_value is not None
) or arg_value != default_value:
raise ValueError(
f"LossScaleOptimizer does not support `{arg_name}`. "
f"Instead, set `{arg_name}` on the `inner_optimizer`."
)

if kwargs:
raise ValueError(
"LossScaleOptimizer does not support arguments: "
f"`{'`, `'.join(kwargs.keys())}`."
)

super().__init__(learning_rate=0.0, name=name)
self.inner_optimizer = inner_optimizer
self.initial_scale = initial_scale
self.dynamic_growth_steps = dynamic_growth_steps
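A sketch of the resulting behavior, assuming the public `keras.optimizers` imports: legacy default values are silently dropped, while anything non-default is redirected to the inner optimizer.

```python
from keras.optimizers import SGD, LossScaleOptimizer

# Old-style defaults coming from a legacy config are popped and ignored.
LossScaleOptimizer(SGD(), use_ema=False, ema_momentum=0.99)

# Non-default values raise, pointing the user at the inner optimizer.
try:
    LossScaleOptimizer(SGD(), clipnorm=1.0)
except ValueError as err:
    print(err)
    # LossScaleOptimizer does not support `clipnorm`. Instead, set
    # `clipnorm` on the `inner_optimizer`.
```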
@@ -81,7 +117,7 @@ def build(self, var_list):
name="dynamic_scale",
)
self.inner_optimizer.build(var_list)
self.built = True
super().build(var_list)

@property
def variables(self):
@@ -136,7 +172,7 @@ def increment():
g
if g is None or self._overwrite_variable_with_gradient(v)
else ops.divide(g, scale)
for g, v in zip(grads, trainable_variables)
for g, v in zip(grads, self._trainable_variables)
]
(
new_trainable_variables,
@@ -284,19 +320,16 @@ def finalize_variable_values(self, var_list):
self.inner_optimizer.finalize_variable_values(var_list)

def get_config(self):
config = super().get_config()
# Do not use super().get_config() as only "name" is supported.
inner_optimizer_config = serialization_lib.serialize_keras_object(
self.inner_optimizer
)
config.update(
{
"inner_optimizer": inner_optimizer_config,
"initial_scale": self.initial_scale,
"dynamic_growth_steps": self.dynamic_growth_steps,
}
)
del config["learning_rate"]
return config
return {
"name": self.name,
"inner_optimizer": inner_optimizer_config,
"initial_scale": self.initial_scale,
"dynamic_growth_steps": self.dynamic_growth_steps,
}

@classmethod
def from_config(cls, config, custom_objects=None):
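With this change the serialized config shrinks to the four supported keys; roughly (inner optimizer serialization elided), as sketched below:

```python
from keras.optimizers import SGD, LossScaleOptimizer

optimizer = LossScaleOptimizer(SGD(learning_rate=0.5), name="lso")
config = optimizer.get_config()
# {
#     "name": "lso",
#     "inner_optimizer": {...},  # serialized SGD
#     "initial_scale": 32768.0,  # the 2.0**15 default
#     "dynamic_growth_steps": 2000,
# }
restored = LossScaleOptimizer.from_config(config)
```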
139 changes: 126 additions & 13 deletions keras/src/optimizers/loss_scale_optimizer_test.py
@@ -29,6 +29,19 @@ def test_config(self):
optimizer = LossScaleOptimizer(inner_optimizer)
self.run_class_serialization_test(optimizer)

def test_apply_with_no_vars(self):
self._skip_test_for_stateless(False)

inner_optimizer = SGD(learning_rate=0.5)
optimizer = LossScaleOptimizer(inner_optimizer)
grads = [ops.array([1.0, 6.0, 7.0, 2.0]) * optimizer.initial_scale]
vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
optimizer.build(vars)
optimizer.apply(grads)
self.assertAllClose(
vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
)
Comment on lines +38 to +43

Contributor:

Severity: medium

The variable name `vars` shadows the built-in Python function `vars()`. It's good practice to avoid this to prevent confusion and potential bugs. Consider renaming it to something more descriptive like `model_variables` or `trainable_variables` to improve clarity.¹

Suggested change
vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
optimizer.build(vars)
optimizer.apply(grads)
self.assertAllClose(
vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
)
model_variables = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
optimizer.build(model_variables)
optimizer.apply(grads)
self.assertAllClose(
model_variables, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
)

Style Guide References

Footnotes

  1. Avoid overly generic names and follow Python naming conventions. Shadowing built-in functions is discouraged.


@parameterized.named_parameters(("stateless", True), ("stateful", False))
def test_finite_step(self, stateless):
self._skip_test_for_stateless(stateless)
@@ -40,7 +53,9 @@ def test_finite_step(self, stateless):
if stateless:
optimizer.build(vars)
vars, _ = optimizer.stateless_apply(
optimizer.variables, grads, vars
[v.value for v in optimizer.variables],
Contributor:

I guess this is because of that JAX deprecation issue?

Contributor:

How prevalent is this pattern across Keras? If we can no longer rely on the `__jax_array__` built-in to auto-convert, maybe there should be a helper function like `optimizer.variable_values` that unwraps the variables?

Collaborator (author):

I uncovered this as part of the `__jax_array__` deprecation. See #21702.

However, I discovered in the process that we're testing `stateless_apply` the wrong way. `stateless_apply` takes tensors, not variables (unlike what the argument names suggest). So this change to the unit tests is correct, and is independent of the `__jax_array__` deprecation.

This pattern of unwrapping variables was only added in this test file (see #21702). The pattern does exist to some extent, but in the JAX trainer, not in the optimizers.

Contributor:

Thanks, this all looks better.

grads,
[v.value for v in vars],
)
else:
optimizer.apply(grads, vars)
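On the `variable_values` helper floated in the thread above — a hypothetical sketch of what it could look like as a free function; it is not part of this PR and not an existing Keras API:

```python
from keras.optimizers import Optimizer


def variable_values(optimizer: Optimizer) -> list:
    """Hypothetical helper (not in Keras): unwrap an optimizer's variables
    into the plain tensor values that `stateless_apply` expects."""
    return [v.value for v in optimizer.variables]


# It would let the tests above read:
#   optimizer.stateless_apply(variable_values(optimizer), grads, ...)
```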
@@ -60,7 +75,9 @@ def test_finite_step_with_inner_loss_scale(self, stateless):
if stateless:
optimizer.build(vars)
vars, _ = optimizer.stateless_apply(
optimizer.variables, grads, vars
[v.value for v in optimizer.variables],
grads,
[v.value for v in vars],
)
else:
optimizer.apply(grads, vars)
@@ -79,7 +96,9 @@ def test_infinite_step(self, stateless):
if stateless:
optimizer.build(vars)
vars, _ = optimizer.stateless_apply(
optimizer.variables, grads, vars
[v.value for v in optimizer.variables],
grads,
[v.value for v in vars],
)
else:
optimizer.apply(grads, vars)
@@ -98,7 +117,9 @@ def test_finite_step_with_overwrite(self, stateless):
if stateless:
optimizer.build(vars)
vars, _ = optimizer.stateless_apply(
optimizer.variables, grads, vars
[v.value for v in optimizer.variables],
grads,
[v.value for v in vars],
)
else:
optimizer.apply(grads, vars)
@@ -112,12 +133,14 @@ def test_downscaling(self, stateless):
optimizer = LossScaleOptimizer(inner_optimizer, initial_scale=400.0)
vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
optimizer.build(vars)
opt_vars = optimizer.variables
opt_var_values = [v.value for v in optimizer.variables]
grads = [ops.array([np.inf, np.inf, np.inf, np.inf])]
for _ in range(4):
if stateless:
_, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
for ref_v, v in zip(optimizer.variables, opt_vars):
_, opt_var_values = optimizer.stateless_apply(
opt_var_values, grads, [v.value for v in vars]
)
for ref_v, v in zip(optimizer.variables, opt_var_values):
ref_v.assign(v)
else:
optimizer.apply(grads, vars)
@@ -135,12 +158,14 @@ def test_upscaling(self, stateless):
)
vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
optimizer.build(vars)
opt_vars = optimizer.variables
opt_var_values = [v.value for v in optimizer.variables]
grads = [ops.array([1.0, 6.0, 7.0, 2.0])]
for _ in range(8):
if stateless:
_, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
for ref_v, v in zip(optimizer.variables, opt_vars):
_, opt_var_values = optimizer.stateless_apply(
opt_var_values, grads, [v.value for v in vars]
)
for ref_v, v in zip(optimizer.variables, opt_var_values):
ref_v.assign(v)
else:
optimizer.apply(grads, vars)
@@ -154,16 +179,104 @@ def test_iterations_update(self, stateless):
optimizer = LossScaleOptimizer(inner_optimizer)
vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
optimizer.build(vars)
opt_vars = optimizer.variables
opt_var_values = [v.value for v in optimizer.variables]
grads = [ops.array([1.0, 6.0, 7.0, 2.0])]

self.assertEqual(optimizer.iterations.value, 0)

for i in range(3):
if stateless:
_, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
for ref_v, v in zip(optimizer.variables, opt_vars):
_, opt_var_values = optimizer.stateless_apply(
opt_var_values, grads, [v.value for v in vars]
)
for ref_v, v in zip(optimizer.variables, opt_var_values):
ref_v.assign(v)
else:
optimizer.apply(grads, vars)
self.assertEqual(optimizer.iterations.value, i + 1)

def test_serialization(self):
inner_optimizer = SGD(learning_rate=0.5)
optimizer = LossScaleOptimizer(
inner_optimizer,
initial_scale=3.0,
dynamic_growth_steps=2,
name="test_opt",
)
config = optimizer.get_config()
self.assertLen(config, 4)
self.assertEqual(config["name"], "test_opt")
self.assertEqual(config["initial_scale"], 3.0)
self.assertEqual(config["dynamic_growth_steps"], 2)
self.assertIn("inner_optimizer", config)
LossScaleOptimizer.from_config(config)

def test_init_dynamic_arg(self):
inner_optimizer = SGD(learning_rate=0.5)

# dynamic=True is supported
LossScaleOptimizer(inner_optimizer, dynamic=True)

# dynamic=False is not supported
with self.assertRaisesRegex(ValueError, "set `loss_scale_factor`"):
LossScaleOptimizer(inner_optimizer, dynamic=False)

def test_init_unsupported_arg(self):
inner_optimizer = SGD(learning_rate=0.5)
with self.assertRaisesRegex(ValueError, "arguments: `foo`, `bar`"):
LossScaleOptimizer(inner_optimizer, foo=True, bar=3)

@parameterized.named_parameters(
("weight_decay", "weight_decay", 0.5),
("clipnorm", "clipnorm", 0.5),
("global_clipnorm", "global_clipnorm", 0.5),
("clipvalue", "clipvalue", 0.5),
("use_ema", "use_ema", True),
("ema_momentum", "ema_momentum", 0.5),
("ema_overwrite_frequency", "ema_overwrite_frequency", 2),
("loss_scale_factor", "loss_scale_factor", 0.5),
("gradient_accumulation_steps", "gradient_accumulation_steps", 2),
)
def test_init_base_optimizer_unsupported_args(self, arg_name, arg_value):
inner_optimizer = SGD(learning_rate=0.5)
with self.assertRaisesRegex(ValueError, "on the `inner_optimizer`"):
LossScaleOptimizer(inner_optimizer, **{arg_name: arg_value})

def test_deserialization_backwards_compatibility(self):
# Test deserializing with a config that has all the unsupported
# arguments from the base optimizer (which are no longer serialized)
config = {
"name": "loss_scale_optimizer",
"weight_decay": None,
"clipnorm": None,
"global_clipnorm": None,
"clipvalue": None,
"use_ema": False,
"ema_momentum": 0.99,
"ema_overwrite_frequency": None,
"loss_scale_factor": None,
"gradient_accumulation_steps": None,
"inner_optimizer": {
"module": "keras.optimizers",
"class_name": "SGD",
"config": {
"name": "SGD",
"learning_rate": 0.5,
"weight_decay": None,
"clipnorm": None,
"global_clipnorm": None,
"clipvalue": None,
"use_ema": False,
"ema_momentum": 0.99,
"ema_overwrite_frequency": None,
"loss_scale_factor": None,
"gradient_accumulation_steps": None,
"momentum": 0.0,
"nesterov": False,
},
"registered_name": None,
},
"initial_scale": 2.0,
"dynamic_growth_steps": 2,
}
LossScaleOptimizer.from_config(config)