diff --git a/keras/src/optimizers/base_optimizer.py b/keras/src/optimizers/base_optimizer.py
index c3ecdd2baab..ad94a85e9e6 100644
--- a/keras/src/optimizers/base_optimizer.py
+++ b/keras/src/optimizers/base_optimizer.py
@@ -631,6 +631,20 @@ def _backend_increment_gradient_accumulators(self, grads, acc_grads):
             g_acc.assign(n_g_acc)
 
     def stateless_apply(self, optimizer_variables, grads, trainable_variables):
+        """Stateless version of `apply` that returns modified variables.
+
+        Args:
+            optimizer_variables: list of tensors containing the current values
+                for the optimizer variables. These are native tensors and not
+                `keras.Variable`s.
+            grads: list of gradients to apply.
+            trainable_variables: list of tensors containing the current values
+                for the model variables. These are native tensors and not
+                `keras.Variable`s.
+
+        Returns: A tuple containing two lists of tensors, the updated
+            `trainable_variables` and the updated `optimizer_variables`.
+        """
         self._check_super_called()
 
         if not self.built:
diff --git a/keras/src/optimizers/loss_scale_optimizer.py b/keras/src/optimizers/loss_scale_optimizer.py
index 1b9945c4157..babeafffcd7 100644
--- a/keras/src/optimizers/loss_scale_optimizer.py
+++ b/keras/src/optimizers/loss_scale_optimizer.py
@@ -48,6 +48,7 @@ def __init__(
         inner_optimizer,
         initial_scale=2.0**15,
         dynamic_growth_steps=2000,
+        name=None,
         **kwargs,
     ):
         if not kwargs.pop("dynamic", True):
@@ -56,7 +57,42 @@ def __init__(
                 "Instead, simply set `loss_scale_factor` directly on the "
                 "`inner_optimizer`."
             )
-        super().__init__(learning_rate=0.0, **kwargs)
+
+        # Backwards compatibility code for deserialization.
+        # LossScaleOptimizer used to include all these parameters in
+        # `get_config` via `super().get_config()` even though they are
+        # non-functional. We no longer let users set them, but we still allow
+        # their default values during deserialization to support older models.
+        base_optimizer_defaults = {
+            "weight_decay": None,
+            "clipnorm": None,
+            "global_clipnorm": None,
+            "clipvalue": None,
+            "use_ema": False,
+            "ema_momentum": 0.99,
+            "ema_overwrite_frequency": None,
+            "loss_scale_factor": None,
+            "gradient_accumulation_steps": None,
+        }
+        for arg_name, default_value in base_optimizer_defaults.items():
+            if arg_name not in kwargs:
+                continue
+            arg_value = kwargs.pop(arg_name)
+            if (
+                default_value is None and arg_value is not None
+            ) or arg_value != default_value:
+                raise ValueError(
+                    f"LossScaleOptimizer does not support `{arg_name}`. "
+                    f"Instead, set `{arg_name}` on the `inner_optimizer`."
+                )
+
+        if kwargs:
+            raise ValueError(
+                "LossScaleOptimizer does not support arguments: "
+                f"`{'`, `'.join(kwargs.keys())}`."
+            )
+
+        super().__init__(learning_rate=0.0, name=name)
         self.inner_optimizer = inner_optimizer
         self.initial_scale = initial_scale
         self.dynamic_growth_steps = dynamic_growth_steps
@@ -81,7 +117,7 @@ def build(self, var_list):
             name="dynamic_scale",
         )
         self.inner_optimizer.build(var_list)
-        self.built = True
+        super().build(var_list)
 
     @property
     def variables(self):
@@ -136,7 +172,7 @@ def increment():
             g
             if g is None or self._overwrite_variable_with_gradient(v)
             else ops.divide(g, scale)
-            for g, v in zip(grads, trainable_variables)
+            for g, v in zip(grads, self._trainable_variables)
         ]
         (
             new_trainable_variables,
@@ -284,19 +320,16 @@ def finalize_variable_values(self, var_list):
         self.inner_optimizer.finalize_variable_values(var_list)
 
     def get_config(self):
-        config = super().get_config()
+        # Do not use super().get_config() as only "name" is supported.
         inner_optimizer_config = serialization_lib.serialize_keras_object(
             self.inner_optimizer
         )
-        config.update(
-            {
-                "inner_optimizer": inner_optimizer_config,
-                "initial_scale": self.initial_scale,
-                "dynamic_growth_steps": self.dynamic_growth_steps,
-            }
-        )
-        del config["learning_rate"]
-        return config
+        return {
+            "name": self.name,
+            "inner_optimizer": inner_optimizer_config,
+            "initial_scale": self.initial_scale,
+            "dynamic_growth_steps": self.dynamic_growth_steps,
+        }
 
     @classmethod
     def from_config(cls, config, custom_objects=None):
diff --git a/keras/src/optimizers/loss_scale_optimizer_test.py b/keras/src/optimizers/loss_scale_optimizer_test.py
index c053d96787f..d707ad765f3 100644
--- a/keras/src/optimizers/loss_scale_optimizer_test.py
+++ b/keras/src/optimizers/loss_scale_optimizer_test.py
@@ -29,6 +29,19 @@ def test_config(self):
         optimizer = LossScaleOptimizer(inner_optimizer)
         self.run_class_serialization_test(optimizer)
 
+    def test_apply_with_no_vars(self):
+        self._skip_test_for_stateless(False)
+
+        inner_optimizer = SGD(learning_rate=0.5)
+        optimizer = LossScaleOptimizer(inner_optimizer)
+        grads = [ops.array([1.0, 6.0, 7.0, 2.0]) * optimizer.initial_scale]
+        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
+        optimizer.build(vars)
+        optimizer.apply(grads)
+        self.assertAllClose(
+            vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4
+        )
+
     @parameterized.named_parameters(("stateless", True), ("stateful", False))
     def test_finite_step(self, stateless):
         self._skip_test_for_stateless(stateless)
@@ -40,7 +53,9 @@ def test_finite_step(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
             )
         else:
             optimizer.apply(grads, vars)
@@ -60,7 +75,9 @@ def test_finite_step_with_inner_loss_scale(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
             )
         else:
             optimizer.apply(grads, vars)
@@ -79,7 +96,9 @@ def test_infinite_step(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
             )
         else:
             optimizer.apply(grads, vars)
@@ -98,7 +117,9 @@ def test_finite_step_with_overwrite(self, stateless):
         if stateless:
             optimizer.build(vars)
             vars, _ = optimizer.stateless_apply(
-                optimizer.variables, grads, vars
+                [v.value for v in optimizer.variables],
+                grads,
+                [v.value for v in vars],
             )
         else:
             optimizer.apply(grads, vars)
@@ -112,12 +133,14 @@ def test_downscaling(self, stateless):
         optimizer = LossScaleOptimizer(inner_optimizer, initial_scale=400.0)
         vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
         optimizer.build(vars)
-        opt_vars = optimizer.variables
+        opt_var_values = [v.value for v in optimizer.variables]
         grads = [ops.array([np.inf, np.inf, np.inf, np.inf])]
         for _ in range(4):
             if stateless:
-                _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
-                for ref_v, v in zip(optimizer.variables, opt_vars):
+                _, opt_var_values = optimizer.stateless_apply(
+                    opt_var_values, grads, [v.value for v in vars]
+                )
+                for ref_v, v in zip(optimizer.variables, opt_var_values):
                     ref_v.assign(v)
             else:
                 optimizer.apply(grads, vars)
@@ -135,12 +158,14 @@ def test_upscaling(self, stateless):
         )
         vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
         optimizer.build(vars)
-        opt_vars = optimizer.variables
+        opt_var_values = [v.value for v in optimizer.variables]
         grads = [ops.array([1.0, 6.0, 7.0, 2.0])]
         for _ in range(8):
             if stateless:
-                _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
-                for ref_v, v in zip(optimizer.variables, opt_vars):
+                _, opt_var_values = optimizer.stateless_apply(
+                    opt_var_values, grads, [v.value for v in vars]
+                )
+                for ref_v, v in zip(optimizer.variables, opt_var_values):
                     ref_v.assign(v)
             else:
                 optimizer.apply(grads, vars)
@@ -154,16 +179,104 @@ def test_iterations_update(self, stateless):
         optimizer = LossScaleOptimizer(inner_optimizer)
         vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]
         optimizer.build(vars)
-        opt_vars = optimizer.variables
+        opt_var_values = [v.value for v in optimizer.variables]
         grads = [ops.array([1.0, 6.0, 7.0, 2.0])]
         self.assertEqual(optimizer.iterations.value, 0)
         for i in range(3):
             if stateless:
-                _, opt_vars = optimizer.stateless_apply(opt_vars, grads, vars)
-                for ref_v, v in zip(optimizer.variables, opt_vars):
+                _, opt_var_values = optimizer.stateless_apply(
+                    opt_var_values, grads, [v.value for v in vars]
+                )
+                for ref_v, v in zip(optimizer.variables, opt_var_values):
                     ref_v.assign(v)
             else:
                 optimizer.apply(grads, vars)
             self.assertEqual(optimizer.iterations.value, i + 1)
+
+    def test_serialization(self):
+        inner_optimizer = SGD(learning_rate=0.5)
+        optimizer = LossScaleOptimizer(
+            inner_optimizer,
+            initial_scale=3.0,
+            dynamic_growth_steps=2,
+            name="test_opt",
+        )
+        config = optimizer.get_config()
+        self.assertLen(config, 4)
+        self.assertEqual(config["name"], "test_opt")
+        self.assertEqual(config["initial_scale"], 3.0)
+        self.assertEqual(config["dynamic_growth_steps"], 2)
+        self.assertIn("inner_optimizer", config)
+        LossScaleOptimizer.from_config(config)
+
+    def test_init_dynamic_arg(self):
+        inner_optimizer = SGD(learning_rate=0.5)
+
+        # dynamic=True is supported
+        LossScaleOptimizer(inner_optimizer, dynamic=True)
+
+        # dynamic=False is not supported
+        with self.assertRaisesRegex(ValueError, "set `loss_scale_factor`"):
+            LossScaleOptimizer(inner_optimizer, dynamic=False)
+
+    def test_init_unsupported_arg(self):
+        inner_optimizer = SGD(learning_rate=0.5)
+        with self.assertRaisesRegex(ValueError, "arguments: `foo`, `bar`"):
+            LossScaleOptimizer(inner_optimizer, foo=True, bar=3)
+
+    @parameterized.named_parameters(
+        ("weight_decay", "weight_decay", 0.5),
+        ("clipnorm", "clipnorm", 0.5),
+        ("global_clipnorm", "global_clipnorm", 0.5),
+        ("clipvalue", "clipvalue", 0.5),
+        ("use_ema", "use_ema", True),
+        ("ema_momentum", "ema_momentum", 0.5),
+        ("ema_overwrite_frequency", "ema_overwrite_frequency", 2),
+        ("loss_scale_factor", "loss_scale_factor", 0.5),
+        ("gradient_accumulation_steps", "gradient_accumulation_steps", 2),
+    )
+    def test_init_base_optimizer_unsupported_args(self, arg_name, arg_value):
+        inner_optimizer = SGD(learning_rate=0.5)
+        with self.assertRaisesRegex(ValueError, "on the `inner_optimizer`"):
+            LossScaleOptimizer(inner_optimizer, **{arg_name: arg_value})
+
+    def test_deserialization_backwards_compatibility(self):
+        # Test deserializing with a config that has all the unsupported
+        # arguments from the base optimizer (which are no longer serialized)
+        config = {
+            "name": "loss_scale_optimizer",
+            "weight_decay": None,
+            "clipnorm": None,
+            "global_clipnorm": None,
+            "clipvalue": None,
+            "use_ema": False,
+            "ema_momentum": 0.99,
+            "ema_overwrite_frequency": None,
+            "loss_scale_factor": None,
+            "gradient_accumulation_steps": None,
+            "inner_optimizer": {
+                "module": "keras.optimizers",
+                "class_name": "SGD",
+                "config": {
+                    "name": "SGD",
+                    "learning_rate": 0.5,
+                    "weight_decay": None,
+                    "clipnorm": None,
+                    "global_clipnorm": None,
+                    "clipvalue": None,
+                    "use_ema": False,
+                    "ema_momentum": 0.99,
+                    "ema_overwrite_frequency": None,
+                    "loss_scale_factor": None,
+                    "gradient_accumulation_steps": None,
+                    "momentum": 0.0,
+                    "nesterov": False,
+                },
+                "registered_name": None,
+            },
+            "initial_scale": 2.0,
+            "dynamic_growth_steps": 2,
+        }
+        LossScaleOptimizer.from_config(config)
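Below is a short, hypothetical usage sketch (not part of the patch) illustrating the behavior this diff introduces: `LossScaleOptimizer` now accepts only `inner_optimizer`, `initial_scale`, `dynamic_growth_steps`, and `name`; base-optimizer options such as `clipnorm` must be set on the inner optimizer; `stateless_apply` consumes and returns plain tensor values rather than `keras.Variable` objects; and `get_config` serializes only the supported arguments. Variable names (`inner`, `trainable_vars`) are illustrative.

# Hypothetical sketch, not part of the patch above.
import keras
from keras import ops

# Clipping and similar base-optimizer options now belong on the inner
# optimizer rather than on the LossScaleOptimizer wrapper.
inner = keras.optimizers.SGD(learning_rate=0.5, clipnorm=1.0)
optimizer = keras.optimizers.LossScaleOptimizer(inner, name="lso")

trainable_vars = [keras.Variable([1.0, 2.0, 3.0, 4.0])]
optimizer.build(trainable_vars)
grads = [ops.array([1.0, 6.0, 7.0, 2.0])]

# stateless_apply takes and returns plain tensor values, so pass the
# `.value` of each variable rather than the `keras.Variable` itself.
new_vars, new_opt_vars = optimizer.stateless_apply(
    [v.value for v in optimizer.variables],
    grads,
    [v.value for v in trainable_vars],
)

# get_config now serializes only the supported constructor arguments;
# older configs carrying base-optimizer keys are still accepted by
# from_config via the backwards-compatibility shim in __init__.
config = optimizer.get_config()
assert sorted(config) == [
    "dynamic_growth_steps",
    "initial_scale",
    "inner_optimizer",
    "name",
]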