diff --git a/keras/api/_tf_keras/keras/layers/__init__.py b/keras/api/_tf_keras/keras/layers/__init__.py index 930a5cc69e58..2610790e7b77 100644 --- a/keras/api/_tf_keras/keras/layers/__init__.py +++ b/keras/api/_tf_keras/keras/layers/__init__.py @@ -47,6 +47,7 @@ from keras.src.layers.convolutional.separable_conv2d import ( SeparableConv2D as SeparableConvolution2D, ) +from keras.src.layers.core.composite_layer import CompositeLayer from keras.src.layers.core.dense import Dense from keras.src.layers.core.einsum_dense import EinsumDense from keras.src.layers.core.embedding import Embedding diff --git a/keras/api/layers/__init__.py b/keras/api/layers/__init__.py index 6dfcbed7b5b9..a7b46b388035 100644 --- a/keras/api/layers/__init__.py +++ b/keras/api/layers/__init__.py @@ -47,6 +47,7 @@ from keras.src.layers.convolutional.separable_conv2d import ( SeparableConv2D as SeparableConvolution2D, ) +from keras.src.layers.core.composite_layer import CompositeLayer from keras.src.layers.core.dense import Dense from keras.src.layers.core.einsum_dense import EinsumDense from keras.src.layers.core.embedding import Embedding diff --git a/keras/src/layers/core/composite_layer.py b/keras/src/layers/core/composite_layer.py new file mode 100644 index 000000000000..02858d4f8af1 --- /dev/null +++ b/keras/src/layers/core/composite_layer.py @@ -0,0 +1,345 @@ +import inspect +import typing + +from keras.src import tree +from keras.src.api_export import keras_export +from keras.src.layers.core.input_layer import Input +from keras.src.layers.layer import Layer +from keras.src.models.functional import compute_input_spec +from keras.src.models.functional import function_from_config +from keras.src.models.functional import run_through_graph_with_training_and_mask +from keras.src.models.functional import serialize_functional_config +from keras.src.ops.function import Function +from keras.src.utils import tracking + + +@keras_export(["keras.layers.CompositeLayer"]) +class CompositeLayer(Layer): + """Layer that encapsulates a subgraph of layers into a single layer + in a Keras functional way. This means that the subgraph of layers + remains programmatically accessible: functional models containing + `CompositeLayer`s can be plotted with `keras.utils.plot_model` + or programmatically edited with `keras.models.clone_model` (via + its `call_function` argument). + + `CompositeLayer` can be created in two ways: + + 1. From a list of layers: + + ```python + # Composite layer from a list of layers + composite = layers.CompositeLayer([ + layers.Dense(64, activation='relu'), + layers.Dense(32) + ]) + ``` + + 2. Using a function that defines a graph of layers: + This allows more complex computation graphs. + The first argument of the function will become + the inputs of the composite layer. + + ```python + def layer_fn(x): + x = layers.Dense(64, activation='relu')(x) + outputs = layers.Dense(32)(x) + return outputs + + # Create the composite layer using the function + composite = layers.CompositeLayer(layer_fn) + ``` + + Additional arguments of `layer_fn` can be used for configuration, + but only the first argument represents the layer's runtime + inputs. Use a dict or list as the first argument if your layer + requires multiple inputs. + + ```python + # Additional args for configuration. + # Multiple inputs passed as a list or dict to the 'inputs' argument.
+ def layer_fn(inputs, dense_size=64): + x0 = inputs[0] # inputs is a list + x1 = inputs[1] + y0 = layers.Dense(dense_size, activation='relu')(x0) + y1 = layers.Dense(dense_size, activation='relu')(x1) + return y0 + y1 + + composite = layers.CompositeLayer(layer_fn) + ``` + + Reusable composite layers can be packaged + in a subclass of `CompositeLayer`: + + ```python + # A reusable composite layer + class MyCompositeLayer(CompositeLayer): + + @staticmethod + def my_layer_fn(inputs): + x = layers.Dense(5)(inputs) + return layers.Dense(4)(x) + + def __init__(self, **kwargs): + super().__init__(MyCompositeLayer.my_layer_fn, **kwargs) + ``` + + Args: + layers: Either + - a callable function that defines a computation graph + - or a list of Layer objects to compose sequentially. + name: Optional name for the layer. + """ + + def __new__(cls, *args, **kwargs): + return typing.cast(cls, super().__new__(cls)) + + def __init__(self, layers, name=None, **kwargs): + super().__init__(name=name, **kwargs) + + # Init from either a function that defines the + # layer graph or a sequence of layers. + # Internally, a CompositeLayer can also + # be initialized from a Keras Function. + if not isinstance(layers, Function): + if not ( + (isinstance(layers, (list, tuple)) and len(layers) > 0) + or (callable(layers)) + ): + raise ValueError( + f"CompositeLayer requires a layers parameter that is " + f"either a function that defines the layer's computation " + f"graph or a non-empty list of layers. Got: {layers}" + ) + # error out on wrong layer_fn signature + if callable(layers): + layer_fn = layers + layer_fn_params = list(inspect.signature(layer_fn).parameters) + if len(layer_fn_params) < 1: + raise ValueError( + f"The function used to initialize a CompositeLayer " + f"must take the layer's inputs as its first argument. " + f"Additional arguments may be used for configuration. " + f"If multiple inputs are required at runtime, use a " + f"list or a dictionary. " + f"Got: {layer_fn_params} for {layer_fn}" + ) + + # Constructing from a Keras Function is useful + # internally when deserializing or cloning the layer. + if isinstance(layers, Function): + self._build_from_function(function=layers) + self._arg_layers = None + # defer building until the first call to build() + else: + self._arg_layers = layers + self._function = None + self.built = False + + # Allow calling the layer on raw Python data (e.g list of numbers) + # to be similar to what Functional does. + self._convert_input_args = True + # BUG: this is NOT useful and extra positional args are NOT allowed + # but _convert_input_args=True won't work without this flag. + self._allow_non_tensor_positional_args = False + + # Note: CompositeLayer does not have the following attributes: + # _inputs_struct, _outputs_struct, _inputs, _outputs as in + # Functional model since those are private attributes of Function. + + @property + def inputs(self): + return self._function._inputs + + @property + def outputs(self): + return self._function._outputs + + # Override Operation.input (as in Functional) + @property + def input(self): + return self._function._inputs_struct + + # Override Operation.output (as in Functional) + @property + def output(self): + return self._function._outputs_struct + + # Only call this from __init__ or build() + # otherwise, must handle state locking/unlocking. 
+ def _build_from_function(self, function): + self._function = function + # tracking: compute list of layers from the new function + self._layers = self.layers + self.built = True + + def build(self, input_shape): + # if __init__ from Function, build() should do nothing + assert not isinstance(self._arg_layers, Function) + + def spec_to_input(spec): + # InputSpec shapes have batch size as first dim + return Input( + batch_shape=spec.shape, + dtype=spec.dtype, + name=spec.name, + optional=spec.optional, + ) + + # create appropriate inputs + if hasattr(self, "_manual_input_spec"): + # code path for a manual input spec which may contain + # optional inputs (set with InputSpec(optional=True)) + inputs = tree.map_structure(spec_to_input, self.input_spec) + else: + # In this code path, there are no optional inputs. + inputs = tree.map_shape_structure( + # Force the batch size to None wherever possible + lambda shape: Input(batch_shape=(None,) + tuple(shape[1:]) + if len(shape) >= 1 else shape, + dtype=self.input_dtype), + input_shape, + ) + + # if "layers" is a callable, call it to create the layer graph + if callable(self._arg_layers): + layer_fn = self._arg_layers + outputs = layer_fn(inputs) + self._build_from_function(Function(inputs, outputs, name=self.name)) + # if "layers" is a list or tuple, create the layer graph sequentially + elif ( + isinstance(self._arg_layers, (list, tuple)) + and len(self._arg_layers) > 0 + ): + layers_list = self._arg_layers + x = inputs + for layer in layers_list: + x = layer(x) + self._build_from_function(Function(inputs, x, name=self.name)) + + # remove input param references now that _function is built + self._arg_layers = None + + @property + def layers(self): + """Returns the list of layers contained in this composite layer.""" + # Collect all Layer objects from operations + layers = [] + if self._function: + for operation in self._function.operations: + if isinstance(operation, Layer): + layers.append(operation) + return layers + + @layers.setter + def layers(self, _): + raise AttributeError( + "`CompositeLayer.layers` attribute is reserved and should not " + "be used. Please use another name." + ) + + def call(self, inputs, training=None, mask=None): + # Apply the function with training mode + return run_through_graph_with_training_and_mask( + self._function, inputs, training=training, mask=mask + ) + + def compute_output_shape(self, input_shape): + return self._function.compute_output_shape(input_shape) + + def compute_output_spec(self, inputs, training=None, mask=None): + return self._function.compute_output_spec(inputs) + + def get_config(self): + if not self.built: + raise ValueError( + "This CompositeLayer has not been built yet. " + "You need to call `build()` or call the layer on an input."
) + config = super().get_config() + functional_config = serialize_functional_config(self, self._function) + config.update(functional_config) + return config + + @classmethod + def from_config(cls, config, custom_objects=None): + # Extract CompositeLayer specific config + layer_config = {} + for key in ["trainable", "dtype"]: + layer_config[key] = config.pop(key, None) + for key in ["name"]: + layer_config[key] = config.get(key, None) # keep name for Function + + # Recreate the Keras Function + function = function_from_config(Function, config, custom_objects) + # Create instance from Function + instance = cls.__new__(cls) + CompositeLayer.__init__(instance, function, **layer_config) + return instance + + def get_layer(self, name=None, index=None): + """Retrieves a layer based on either its name (unique) or index. + + If `name` and `index` are both provided, `index` will take precedence. + Indices are based on order of horizontal graph traversal (bottom-up). + + Args: + name: String, name of layer. + index: Integer, index of layer. + + Returns: + A layer instance. + """ + if index is not None and name is not None: + raise ValueError( + "Provide only a layer name or a layer index. Received: " + f"index={index}, name={name}." + ) + if index is not None: + if len(self.layers) <= index: + raise ValueError( + f"Was asked to retrieve layer at index {index}" + f" but this layer only has {len(self.layers)}" + " layers." + ) + else: + return self.layers[index] + + if name is not None: + for layer in self.layers: + if layer.name == name: + return layer + raise ValueError( + f"No such layer: {name}. Existing layers are: " + f"{list(layer.name for layer in self.layers)}." + ) + raise ValueError( + "Provide either a layer name or layer index at `get_layer`." + ) + + @property + def input_spec(self): + if hasattr(self, "_manual_input_spec"): + return self._manual_input_spec + elif self._function: + return compute_input_spec( + self._function._inputs_struct, self._function._inputs + ) + else: + return None + + @input_spec.setter + def input_spec(self, value): + self._manual_input_spec = value + + # Workaround for the fact that during input_spec assignment, + # Layer.__setattr__ is called first, which wraps the assigned + # value in a TrackableList, which then interferes with + # _function._assert_input_compatibility. This cannot be fixed + # by adding DotNotTrackScope in input_spec.setter because + # Layer.__setattr__ is called before the input_spec setter.
+ def __setattr__(self, name, value): + if name == "input_spec": + with tracking.DotNotTrackScope(): + return super().__setattr__(name, value) + return super().__setattr__(name, value) + + diff --git a/keras/src/layers/core/composite_layer_test.py b/keras/src/layers/core/composite_layer_test.py new file mode 100644 index 000000000000..36eda9fddfc1 --- /dev/null +++ b/keras/src/layers/core/composite_layer_test.py @@ -0,0 +1,1376 @@ +import os +import warnings + +import numpy as np +import pytest +from absl.testing import parameterized + +from keras.models import clone_model +from keras.src import applications +from keras.src import backend +from keras.src import initializers +from keras.src import layers +from keras.src import ops +from keras.src import saving +from keras.src import testing +from keras.src import tree +from keras.src.layers.core.composite_layer import CompositeLayer +from keras.src.layers.core.dense import Dense +from keras.src.layers.core.input_layer import Input +from keras.src.layers.input_spec import InputSpec +from keras.src.models.model import Model +from keras.src.models.sequential import Sequential + + +class CompositeLayerTest(testing.TestCase): + def test_basic_flow(self): + def my_layer_fn(inputs): + x = layers.Dense(5, name="dense1")(inputs) + return layers.Dense(4, name="dense2")(x) + + layer = CompositeLayer(my_layer_fn, name="basic") + + self.assertEqual(layer.name, "basic") + self.assertIsInstance(layer, CompositeLayer) + self.assertIsInstance(layer, layers.Layer) + self.assertFalse(layer.built) # Should be lazily built + + # Eager call - should trigger build + in_val = np.random.random((2, 3)) + out_val = layer(in_val) + self.assertEqual(out_val.shape, (2, 4)) + self.assertTrue(layer.built) # Should be built now + + # weights + self.assertEqual(len(layer.weights), 4) + self.assertEqual(layer.weights[0].path, "basic/dense1/kernel") + self.assertEqual(layer.weights[0].shape, (3, 5)) + self.assertEqual(layer.weights[1].path, "basic/dense1/bias") + self.assertEqual(layer.weights[1].shape, (5,)) + self.assertEqual(layer.weights[2].path, "basic/dense2/kernel") + self.assertEqual(layer.weights[2].shape, (5, 4)) + self.assertEqual(layer.weights[3].path, "basic/dense2/bias") + self.assertEqual(layer.weights[3].shape, (4,)) + + # variables + self.assertEqual(len(layer.variables), 4) + self.assertEqual(layer.variables[0].path, "basic/dense1/kernel") + self.assertEqual(layer.variables[0].shape, (3, 5)) + self.assertEqual(layer.variables[1].path, "basic/dense1/bias") + self.assertEqual(layer.variables[1].shape, (5,)) + self.assertEqual(layer.variables[2].path, "basic/dense2/kernel") + self.assertEqual(layer.variables[2].shape, (5, 4)) + self.assertEqual(layer.variables[3].path, "basic/dense2/bias") + self.assertEqual(layer.variables[3].shape, (4,)) + + # Symbolic call + test_input = Input(shape=(3,), batch_size=2) + test_output = layer(test_input) + self.assertEqual(test_output.shape, (2, 4)) + self.assertTrue(layer.built) # Should be built now + + def test_basic_flow_as_a_sublayer(self): + # Build sublayer + sublayer = CompositeLayer([layers.Flatten()]) + + inputs = Input((None, 4, 5)) + outputs = layers.TimeDistributed(sublayer)(inputs) + model = Model(inputs=inputs, outputs=outputs) + + x = np.random.random((2, 3, 4, 5)) + y = model(x) + self.assertEqual(y.shape, (2, 3, 4 * 5)) + + def test_basic_class_flow(self): + class MyCompositeLayer(CompositeLayer): + @staticmethod + def my_layer_fn(inputs): + x = layers.Dense(5)(inputs) + return layers.Dense(4)(x) + + def 
__init__(self, **kwargs): + super().__init__(MyCompositeLayer.my_layer_fn, **kwargs) + + layer = MyCompositeLayer(name="func_subclass") + + self.assertEqual(layer.name, "func_subclass") + self.assertIsInstance(layer, CompositeLayer) + self.assertIsInstance(layer, layers.Layer) + self.assertFalse(layer.built) # Should be lazily built + + # Eager call - should trigger build + in_val = np.random.random((2, 3)) + out_val = layer(in_val) + self.assertEqual(out_val.shape, (2, 4)) + self.assertTrue(layer.built) # Should be built now + + # Symbolic call + test_input = Input(shape=(3,), batch_size=2) + test_output = layer(test_input) + self.assertEqual(test_output.shape, (2, 4)) + self.assertTrue(layer.built) # Should be built now + + def test_scalar_handling(self): + def scalar_layer_fn(inputs): + # Handle scalar input + return inputs + 1.0 + + layer = CompositeLayer(scalar_layer_fn) + + # Test with scalar added to tensor + in_val = np.zeros((2, 3)) + out_val = layer(in_val) + self.assertAllClose(out_val, np.ones((2, 3))) + + def test_non_mutable_state(self): + def layer_fn(inputs): + x = layers.Dense(5)(inputs) + outputs = layers.Dense(5)(x) + return outputs + + layer = CompositeLayer(layer_fn) + layer.build([2, 3]) + with self.assertRaisesRegex( + ValueError, "You cannot add new elements of state*" + ): + layer.extra_layer = layers.Dense(5) + + def test_multi_output(self): + def multi_output_fn(inputs): + x = layers.Dense(5)(inputs) + output_a = layers.Dense(4)(x) + output_b = layers.Dense(5)(x) + return [output_a, output_b] + + layer = CompositeLayer(multi_output_fn) + + # Eager call + in_val = np.random.random((2, 3)) + out_val = layer(in_val) + self.assertIsInstance(out_val, list) + self.assertEqual(len(out_val), 2) + self.assertEqual(out_val[0].shape, (2, 4)) + self.assertEqual(out_val[1].shape, (2, 5)) + + # Symbolic call + out_val = layer(Input(shape=(3,), batch_size=2)) + self.assertIsInstance(out_val, list) + self.assertEqual(len(out_val), 2) + self.assertEqual(out_val[0].shape, (2, 4)) + self.assertEqual(out_val[1].shape, (2, 5)) + + def test_dict_io(self): + def dict_io_fn(inputs): + # Inputs is expected to be a dict with keys 'a' and 'b' + x = inputs["a"] + inputs["b"] + x = layers.Dense(5)(x) + return layers.Dense(4)(x) + + layer = CompositeLayer(dict_io_fn) + + # Test with dictionary input + in_val = {"a": np.random.random((2, 3)), "b": np.random.random((2, 3))} + out_val = layer(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + def test_layer_fn_init(self): + # Test initialization with a layer function + def my_layer_fn(inputs): + x = layers.Dense(64, activation="relu")(inputs) + return layers.Dense(32)(x) + + layer = CompositeLayer(my_layer_fn, name="layer_fn_composite") + + self.assertEqual(layer.name, "layer_fn_composite") + self.assertIsInstance(layer, CompositeLayer) + self.assertIsInstance(layer, layers.Layer) + self.assertFalse(layer.built) # Should be lazily built + + # Eager call - should trigger build + in_val = np.random.random((2, 32)) + out_val = layer(in_val) + self.assertEqual(out_val.shape, (2, 32)) + self.assertTrue(layer.built) # Should be built now + + # Check that the layers are properly created after building + self.assertEqual(len(layer.layers), 3) # Exactly 3 layers, incl. 
Input + + # Symbolic call + test_input = Input(shape=(32,), batch_size=2) + test_output = layer(test_input) + self.assertEqual(test_output.shape, (2, 32)) + + def test_sequential_init(self): + # Test initialization with a list of layers + layer = CompositeLayer( + [ + layers.Dense(64, activation="relu"), + layers.Dense(32), + ], + name="sequential_composite", + ) + + self.assertEqual(layer.name, "sequential_composite") + self.assertIsInstance(layer, CompositeLayer) + self.assertIsInstance(layer, layers.Layer) + + # Check that the layers are properly stored + layer.build(input_shape=(2, 32)) + self.assertEqual(len(layer.layers), 3) # 3 layers incl. Input + self.assertIsInstance(layer.layers[0], layers.InputLayer) + self.assertIsInstance(layer.layers[1], layers.Dense) + self.assertIsInstance(layer.layers[2], layers.Dense) + + # Eager call + in_val = np.random.random((2, 32)) + out_val = layer(in_val) + self.assertEqual(out_val.shape, (2, 32)) + + # Symbolic call + test_input = Input(shape=(32,), batch_size=2) + test_output = layer(test_input) + self.assertEqual(test_output.shape, (2, 32)) + + def test_multi_layer_sequential_init(self): + # Test with more complex sequential architecture + layer = CompositeLayer( + [ + layers.Dense(64, activation="relu", input_shape=(32,)), + layers.Dropout(0.5), + layers.BatchNormalization(), + layers.Dense(32, activation="relu"), + layers.Dense(16), + ] + ) + + # Test forward pass + in_val = np.random.random((4, 32)) + out_val = layer(in_val) + self.assertEqual(out_val.shape, (4, 16)) + + # Check layer composition, incl. InputLayer + self.assertEqual(len(layer.layers), 6) + + def test_initialization_errors(self): + # Test with invalid layers parameter type + with self.assertRaisesRegex( + ValueError, "is either a function .* or a .* list of layers." + ): + CompositeLayer("not_valid") + + # Test error when layers list is empty + with self.assertRaisesRegex( + ValueError, "is either a function .* or a .* list of layers." + ): + CompositeLayer([]) + + x = y = 0 + def layer_fn(): + return x + y + + with self.assertRaisesRegex( + ValueError, "must take the layer's inputs as its first argument" + ): + CompositeLayer(layer_fn) + + # This is allowed + def layer_fn(inputs, constant_bias=0.1): + return inputs + constant_bias + + CompositeLayer(layer_fn)(np.random.random((2, 32))) + + def test_serialization(self): + def run_compsite_layer_serialization_test(layer): + # this torch-specific attribute is not supposed + # to be serialized / deserialized so delete it + # before checking serialization / deserialization. 
+ if backend.backend() == "torch": + if hasattr(layer, "_torch_params"): + del layer._torch_params + self.run_class_serialization_test(layer) + + # Test basic model + def layer_fn(x): + return layers.Dense(3)(x) + + layer = CompositeLayer(layer_fn, trainable=False) + inputs = Input(shape=(3,), batch_size=2) + layer(inputs) # build the layer + if backend.backend() == "torch": + del layer._torch_params + run_compsite_layer_serialization_test(layer) + + # Test multi-io model + def layer_fn(inputs): + input_a, input_b = inputs + xa = layers.Dense(5, name="middle_a")(input_a) + xb = layers.Dense(5, name="middle_b")(input_b) + output_a = layers.Dense(4, name="output_a")(xa) + output_b = layers.Dense(4, name="output_b")(xb) + return (output_a, output_b) + + layer = CompositeLayer(layer_fn, name="func") + input_a = Input(shape=(3,), batch_size=2, name="input_a") + input_b = Input(shape=(3,), batch_size=2, name="input_b") + layer([input_a, input_b]) # build the layer + if backend.backend() == "torch": + del layer._torch_params + run_compsite_layer_serialization_test(layer) + + # Test model that includes floating ops + def layer_fn(inputs): + input_a, input_b = inputs + x = input_a + input_b + x = layers.Dense(5, name="middle")(x) + output_a = layers.Dense(4, name="output_a")(x) + output_b = layers.Dense(4, name="output_b")(x) + return (output_a, output_b) + + layer = CompositeLayer(layer_fn, name="func") + input_a = Input(shape=(3,), batch_size=2, name="input_a") + input_b = Input(shape=(3,), batch_size=2, name="input_b") + layer([input_a, input_b]) # build the layer + if backend.backend() == "torch": + del layer._torch_params + run_compsite_layer_serialization_test(layer) + + # Test model with dict i/o + def layer_fn(inputs): + input_a = inputs["a"] + input_b = inputs["b"] + x = input_a + input_b + x = layers.Dense(5)(x) + return layers.Dense(4)(x) + + layer = CompositeLayer(layer_fn, name="func") + input_a = Input(shape=(3,), batch_size=2, name="a") + input_b = Input(shape=(3,), batch_size=2, name="b") + layer({"a": input_a, "b": input_b}) # build the layer + if backend.backend() == "torch": + del layer._torch_params + run_compsite_layer_serialization_test(layer) + + def test_config_serialization(self): + # Test serialization of sequential initialization + original_layer = CompositeLayer( + layers=[ + layers.Dense(64, activation="relu", input_shape=(32,)), + layers.Dense(32), + ] + ) + original_layer.build([None, 16]) + + config = original_layer.get_config() + self.assertEqual(config["name"], "composite_layer") + self.assertEqual(len(config["layers"]), 3) + + # Recreate from config + recreated_layer = CompositeLayer.from_config(config) + self.assertEqual(config["name"], "composite_layer") + self.assertEqual(len(recreated_layer.layers), 3) + + # Test the recreated layer works + # Eager call + in_val = np.random.random((2, 16)) + out_val = recreated_layer(in_val) + self.assertEqual(out_val.shape, (2, 32)) + + # Symbolic call + input = Input(shape=(16,), batch_size=5) + out_val = recreated_layer(input) + self.assertEqual(out_val.shape, (5, 32)) + + # Test serialization of layer_fn initialization + def test_layer_fn(inputs): + x = layers.Dense(64)(inputs) + return layers.Dense(10)(x) + + composite = CompositeLayer(test_layer_fn) + # Build it first by calling it + in_val = np.random.random((2, 20)) + composite(in_val) + + # Save and recreate from config + config = composite.get_config() + recreated_layer = CompositeLayer.from_config(config) + self.assertTrue(recreated_layer.built) + + # Test the 
recreated layer works + # Eager call + out_val = recreated_layer(in_val) + self.assertEqual(out_val.shape, (2, 10)) + + # Symbolic call + input = Input(shape=(20,), batch_size=8) + out_val = recreated_layer(input) + # Test with a different batch size: + # The layer was first called with batch_size=2 + # and is now called with batch_size=8 + self.assertEqual(out_val.shape, (8, 10)) + + def test_class_serialization(self): + class MyCompositeLayer(CompositeLayer): + @staticmethod + def my_layer_fn(inputs): + x = layers.Dense(5)(inputs) + return layers.Dense(4)(x) + + def __init__(self, **kwargs): + super().__init__(MyCompositeLayer.my_layer_fn, **kwargs) + + layer = MyCompositeLayer(name="func_subclass") + layer(Input(shape=(3,), batch_size=2)) + + self.assertTrue(layer.built) + + config = layer.get_config() + restored_layer = MyCompositeLayer.from_config(config) + + self.assertEqual(restored_layer.name, "func_subclass") + self.assertIsInstance(restored_layer, MyCompositeLayer) + self.assertIsInstance(restored_layer, CompositeLayer) + self.assertIsInstance(restored_layer, layers.Layer) + self.assertTrue(restored_layer.built) + + # Eager call - should trigger build + in_val = np.random.random((2, 3)) + out_val = restored_layer(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + # Symbolic call + test_input = Input(shape=(3,), batch_size=2) + test_output = layer(test_input) + self.assertEqual(test_output.shape, (2, 4)) + + def test_input_dict_with_extra_field(self): + def layer_fn(inputs): + x = inputs["a"] * 5 + outputs = x + 2 + return outputs + + layer = CompositeLayer(layer_fn) + layer({"a": Input((3,))}) # build the layer + + # Eager call with extra value in dict + in_val = { + "a": np.random.random((2, 3)), + "b": np.random.random((2, 1)), + } + with pytest.warns() as record: + out_val = layer(in_val) + self.assertLen(record, 1) + self.assertStartsWith( + str(record[0].message), + "The structure of `inputs` doesn't match " + "the expected structure", + ) + self.assertEqual(out_val.shape, (2, 3)) + + def test_warning_for_mismatched_inputs_structure(self): + def is_input_warning(w): + return str(w.message).startswith( + "The structure of `inputs` doesn't match the expected structure" + ) + + def layer_fn(inputs): + i1 = inputs["i1"] + i2 = inputs["i2"] + return layers.Add()([i1, i2]) + + composite_layer = CompositeLayer(layer_fn) + composite_layer( + {"i1": Input((2,)), "i2": Input((2,))} + ) # build the layer + with pytest.warns() as warning_logs: + composite_layer([np.ones((2, 2)), np.ones((2, 2))]) + self.assertLen(list(filter(is_input_warning, warning_logs)), 1) + + # No warning for mismatched tuples and lists. 
+ def layer_fn2(inputs): + i1, i2 = inputs + return layers.Add()([i1, i2]) + + composite_layer = CompositeLayer(layer_fn2) + composite_layer( + [Input((2,)), Input((2,))] + ) # build the layer with a list + with warnings.catch_warnings(record=True) as warning_logs: + # call the layer with a tuple + composite_layer((np.ones((2, 2)), np.zeros((2, 2)))) + self.assertLen(list(filter(is_input_warning, warning_logs)), 0) + + @parameterized.named_parameters( + ("list", list), + ("tuple", tuple), + ("dict", dict), + ) + def test_restored_multi_output_type(self, out_type): + def layer_fn(inputs): + x = layers.Dense(5)(inputs) + output_a = layers.Dense(4)(x) + output_b = layers.Dense(5)(x) + if out_type is dict: + outputs = {"a": output_a, "b": output_b} + else: + outputs = out_type([output_a, output_b]) + return outputs + + layer = CompositeLayer(layer_fn) + layer.build(input_shape=(2, 3)) + + config = layer.get_config() + layer_restored = CompositeLayer.from_config(config) + + # Eager call + in_val = np.random.random((2, 3)) + out_val = layer_restored(in_val) + self.assertIsInstance(out_val, out_type) + + # Symbolic call + out_val = layer_restored(Input(shape=(3,), batch_size=2)) + self.assertIsInstance(out_val, out_type) + + def test_layer_getters(self): + def layer_fn(inputs): + # Test mixing ops and layers + input_a = inputs["a"] + input_b = inputs["b"] + x = input_a + input_b + x = layers.Dense(5, name="dense_1")(x) + outputs = layers.Dense(4, name="dense_2")(x) + return outputs + + layer = CompositeLayer(layer_fn) + + layer.build({"a": (2, 3), "b": (2, 3)}) + + # Check layer composition, incl. InputLayer(s) + self.assertEqual(len(layer.layers), 4) + self.assertEqual(len(layer._function._operations), 5) + self.assertEqual(layer.get_layer(index=2).name, "dense_1") + self.assertEqual(layer.get_layer(index=3).name, "dense_2") + self.assertEqual(layer.get_layer(name="dense_1").name, "dense_1") + + def test_training_arg(self): + class Canary(layers.Layer): + def call(self, x, training=False): + assert training + return x + + def compute_output_spec(self, x, training=False): + return ops.KerasTensor(x.shape, dtype=x.dtype) + + # Test with layer_fn initialization + def layer_fn(inputs): + return Canary()(inputs) + + layer_fn_layer = CompositeLayer(layer_fn) + layer_fn_layer(np.random.random((2, 3)), training=True) + + # Test with sequential initialization + sequential_layer = CompositeLayer([Canary()]) + sequential_layer(np.random.random((2, 3)), training=True) + + def test_mask_arg(self): + # TODO (same as in functional_test.py) + pass + + def test_rank_standardization(self): + def layer_fn(x): + return layers.Dense(4)(x) + + # Downranking + layer = CompositeLayer(layer_fn) + layer.build((8, 10)) + out_val = layer(np.random.random((8, 10, 1))) + self.assertEqual(out_val.shape, (8, 4)) + + # Upranking + layer = CompositeLayer(layer_fn) + layer.build((8, 10, 1)) + out_val = layer(np.random.random((8, 10))) + self.assertEqual(out_val.shape, (8, 10, 4)) + + @pytest.mark.requires_trainable_backend + def test_dtype_standardization(self): + def layer_fn(x): + float_input = x["float"] + int_input = x["int"] + float_output = float_input + 2 + int_output = int_input + 2 + return (float_output, int_output) + + # Contrary to a Functional Model, a CompositeLayer has + # only one input dtype. All of its inputs will be created + # in build() with the same dtype. If multiple inputs with + # different dtypes are needed, use a Functional Model. 
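+ # Illustrative sketch only (not exercised by this test): to preserve + # per-input dtypes, wrap the same graph in a Functional model instead: + # f32 = Input((2, 2), dtype="float32") + # i32 = Input((2, 2), dtype="int32") + # model = Model({"float": f32, "int": i32}, layer_fn({"float": f32, "int": i32}))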
+ + # layer with dtype: forces inputs to that dtype + layer = CompositeLayer(layer_fn, dtype="float16") + + # symbilic call + float_data, int_data = layer( + { + "float": Input((2, 2), dtype="float32"), + "int": Input((2, 2), dtype="int32"), + } + ) + + self.assertEqual(backend.standardize_dtype(float_data.dtype), "float16") + self.assertEqual(backend.standardize_dtype(int_data.dtype), "float16") + + # eager call + float_data, int_data = layer( + { + "float": np.ones((8, 2, 2), dtype="float32"), + "int": np.ones((8, 2, 2), dtype="int32"), + } + ) + + self.assertEqual(backend.standardize_dtype(float_data.dtype), "float16") + self.assertEqual(backend.standardize_dtype(int_data.dtype), "float16") + + def test_bad_input_spec(self): + # Single input + def layer_fn(x): + return layers.Dense(2)(x) + + layer = CompositeLayer(layer_fn) + layer.build((None, 4)) + with self.assertRaisesRegex( + ValueError, + r"Input .* is incompatible .* " + r"expected shape=\(None, 4\), found shape=\(2, 3\)", + ): + layer(np.zeros((2, 3))) + with self.assertRaisesRegex(ValueError, "expects 1 input"): + layer([np.zeros((2, 4)), np.zeros((2, 4))]) + + # List input + def layer_fn(inputs): + input_a, input_b = inputs + x = input_a + input_b + return layers.Dense(2)(x) + + layer = CompositeLayer(layer_fn) + input_a = Input(shape=(4,), name="a") + input_b = Input(shape=(4,), name="b") + layer([input_a, input_b]) # build the layer + with self.assertRaisesRegex(ValueError, r"expects 2 input\(s\)"): + layer(np.zeros((2, 3))) + with self.assertRaisesRegex( + ValueError, r"expected shape=\(None, 4\), found shape=\(2, 3\)" + ): + layer([np.zeros((2, 3)), np.zeros((2, 4))]) + + # Dict input + def layer_fn(inputs): + input_a = inputs["a"] + input_b = inputs["b"] + y = input_a + input_b + return layers.Dense(2)(y) + + layer = CompositeLayer(layer_fn) + input_a = Input(shape=(4,), name="a") + input_b = Input(shape=(4,)) + layer({"a": input_a, "b": input_b}) # build the layer + with self.assertRaisesRegex( + ValueError, r"expects 2 input\(s\), but it received 1 input" + ): + layer(np.zeros((2, 3))) + with self.assertRaisesRegex( + ValueError, r"expected shape=\(None, 4\), found shape=\(2, 3\)" + ): + layer({"a": np.zeros((2, 3)), "b": np.zeros((2, 4))}) + + def test_manual_input_spec(self): + def layer_fn(x): + return layers.Dense(2)(x) + + layer = CompositeLayer(layer_fn) + layer.input_spec = InputSpec(shape=(None, 4, 3)) + with self.assertRaisesRegex( + ValueError, + r"expected shape=\(None, 4, 3\), found shape=\(8, 3, 3\)", + ): + layer(np.zeros((8, 3, 3))) + layer(np.zeros((8, 4, 3))) + + def test_deeply_nested_composite_layer(self): + def layer_fn(x): + # input x is: {"1": i1, "others": {"2": i2, "3": i3}} + i1 = x["1"] + i2 = x["others"]["2"] + i3 = x["others"]["3"] + o1, o2, o3 = ( + layers.Dense(1)(i1), + layers.Dense(2)(i2), + layers.Dense(3)(i3), + ) + return {"1": o1, "others": {"2": o2, "3": o3}} + + composite_layer = CompositeLayer(layer_fn) + out_eager = composite_layer( + { + "1": np.ones((8, 1)), + "others": {"2": np.ones((8, 2)), "3": np.ones((8, 3))}, + } + ) + out_symbolic = composite_layer( + { + "1": Input((1,), batch_size=8), + "others": { + "2": Input((2,), batch_size=8), + "3": Input((3,), batch_size=8), + }, + } + ) + for out in [out_eager, out_symbolic]: + self.assertIsInstance(out, dict) + self.assertEqual(set(out.keys()), {"1", "others"}) + self.assertEqual(out["1"].shape, (8, 1)) + self.assertIsInstance(out["others"], dict) + self.assertEqual(set(out["others"].keys()), {"2", "3"}) + 
self.assertEqual(out["others"]["2"].shape, (8, 2)) + self.assertEqual(out["others"]["3"].shape, (8, 3)) + + @pytest.mark.requires_trainable_backend + def test_model_with_composite_layers_serialization(self): + def layer_fn(x): + # input x is: {"1": i1, "others": {"2": i2, "3": i3}} + i1 = x["1"] + i2 = x["others"]["2"] + i3 = x["others"]["3"] + o1, o2, o3 = ( + layers.Dense(1)(i1), + layers.Dense(2)(i2), + layers.Dense(3)(i3), + ) + return {"1": o1, "others": {"2": o2, "3": o3}} + + composite_layer = CompositeLayer(layer_fn) + symbolic_input = { + "1": Input((1,)), + "others": {"2": Input((2,)), "3": Input((3,))}, + } + y = composite_layer(symbolic_input) + y = layers.Concatenate()([y["1"], y["others"]["2"], y["others"]["3"]]) + output = layers.Dense(4)(y) + model = Model(symbolic_input, output) + + temp_filepath = os.path.join(self.get_temp_dir(), "deeply_nested.keras") + model.save(temp_filepath) + loaded_model = saving.load_model(temp_filepath) + + num_input = { + "1": np.ones((8, 1)), + "others": {"2": np.ones((8, 2)), "3": np.ones((8, 3))}, + } + out_eager = model(num_input) + new_out_eager = loaded_model(num_input) + self.assertAllClose(out_eager, new_out_eager) + + def test_for_composite_layer_in_sequential(self): + if backend.image_data_format() == "channels_first": + image_size = (3, 256, 256) + else: + image_size = (256, 256, 3) + base_model = applications.mobilenet.MobileNet( + include_top=False, weights=None + ) + layer = CompositeLayer( + [ + layers.Conv2D(32, (3, 3)), + layers.Conv2D(64, (3, 3)), + layers.Conv2D(128, (3, 3)), + ] + ) + model = Sequential() + model.add(layers.Input(shape=image_size)) + model.add(base_model) + model.add(layer) + model.add(layers.GlobalAveragePooling2D()) + model.add(layers.Dense(7, activation="softmax")) + # eager call + model(np.random.random((4,) + image_size)) + # symbolic call + model(Input(shape=image_size)) + # serialization + config = model.get_config() + model = Sequential.from_config(config) + # eager call + model(np.random.random((4,) + image_size)) + # symbolic call + model(Input(shape=image_size)) + + def test_add_loss(self): + # TODO (same as in functional_test.py) + pass + + def test_layers_setter(self): + layer = CompositeLayer([layers.Dense(4)]) + + with self.assertRaisesRegex(AttributeError, "attribute is reserved"): + layer.layers = [layers.Dense(5)] + + def test_list_input_with_dict_build(self): + def layer_fn(inputs): + x1 = inputs["IT"] + x2 = inputs["IS"] + return layers.subtract([x1, x2]) + + layer = CompositeLayer(layer_fn) + x1 = Input((10,)) + x2 = Input((10,)) + layer({"IT": x1, "IS": x2}) # build the layer + x1 = ops.ones((1, 10)) + x2 = ops.zeros((1, 10)) + # eager call works + layer({"IT": x1, "IS": x2}) + # Note: the test fails here only because the order of dict + # keys "IT", "IS" is different from the sorted order of the + # keys "IS", "IT". Otherwise, passing a list of inputs to + # a model expecting a dictionary of inputs seems to be allowed, + # as long as flattening the dict does not result in reordering. 
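+ # Roughly (illustrative): keras.src.tree flattens dicts in sorted key + # order, so tree.flatten({"IT": x1, "IS": x2}) yields [x2, x1] ("IS" + # before "IT"), while the list [x1, x2] keeps call order, and the two + # flat structures no longer match.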
+ with self.assertRaisesRegex( + ValueError, + "The structure of `inputs` doesn't match the expected structure", + ): + layer([x1, x2]) + + @pytest.mark.requires_trainable_backend + def test_functional_subclass_serialization(self): + class FuncSubclass(CompositeLayer): + @staticmethod + def layer_fn(x): + y = layers.Dense(8)(x) + return layers.Dense(4)(y) + + def __init__(self, name=None, **kwargs): + super().__init__(FuncSubclass.layer_fn, name, **kwargs) + + inputs = Input((4,), name="input") + y = FuncSubclass()(inputs) + model = Model(inputs, y) + data = ops.ones((8, 4)) + output1 = model(data) # build the model + temp_filepath = os.path.join(self.get_temp_dir(), "comp_subclass.keras") + model.save(temp_filepath) + + # Note: this recreates the layer by calling FuncSubclass.__init__ + # and does *not* test the functional.function_from_config method. + + loaded_model = saving.load_model( + temp_filepath, custom_objects={"FuncSubclass": FuncSubclass} + ) + + output2 = loaded_model(data) + self.assertAllClose(output1, output2) + + def test_functional_in_functional_with_reuse_serialization(self): + ini = initializers.Ones() + + # sub-functional + def layer_fn(inputs): + y = layers.Dense(6, kernel_initializer=ini)(inputs) + return layers.Dense(8, kernel_initializer=ini)(y) + + sub_layer1 = CompositeLayer(layer_fn) + + comp_layer = CompositeLayer( + [ + sub_layer1, + sub_layer1, # reuse + ] + ) + + data = ops.ones((2, 8)) + output1 = comp_layer(data) + + # this recreates the model from the saved functional graph + # and *does* test the functional.function_from_config method. + + config = comp_layer.get_config() + loaded_model = CompositeLayer.from_config(config) + + # check the model works + output2 = loaded_model(data) + # check both models return the same results + # (weights were initialized deterministically) + self.assertAllClose(output1, output2) + + @pytest.mark.requires_trainable_backend + def test_functional_in_functional_with_reuse_saving(self): + # sub-functional + def layer_fn(inputs): + y = layers.Dense(6)(inputs) + return layers.Dense(8)(y) + + sub_layer1 = CompositeLayer(layer_fn) + + inputs = Input((8,)) + y1 = sub_layer1(inputs) + outputs = sub_layer1(y1) # reuse + + model = Model(inputs, outputs) + data = ops.ones((2, 8)) + output1 = model(data) # build the model + + # this recreates the model from the saved functional graph + # and *does* test the functional.function_from_config method. 
+ + temp_filepath = os.path.join( + self.get_temp_dir(), "func_subclass_reuse.keras" + ) + model.save(temp_filepath) + loaded_model = saving.load_model(temp_filepath) + + # check the model works + output2 = loaded_model(data) + # check both models return the same results + # (weights were initialized deterministically) + self.assertAllClose(output1, output2) + + @pytest.mark.requires_trainable_backend + def test_composite_in_functional_model(self): + class ConvStack(CompositeLayer): + def __init__(self, **kwargs): + @staticmethod + def layer_fn(x): + y = layers.Conv2D( + 12, + kernel_size=(3, 3), + padding="same", + activation="relu", + name="c1", + )(x) + y = layers.Conv2D( + 16, (3, 3), padding="same", activation="relu", name="c2" + )(y) + z = layers.Conv2D( + 16, kernel_size=(1, 1), activation="relu", name="c3" + )(x) + return y + z + + super().__init__(layer_fn, **kwargs) + + class RegulStack(CompositeLayer): + def __init__(self, **kwargs): + super().__init__( + [layers.MaxPooling2D(pool_size=(2, 2)), layers.Flatten()], + **kwargs, + ) + + class DenseStack(CompositeLayer): + def __init__(self, **kwargs): + super().__init__( + [ + layers.Dense(128, activation="relu"), + layers.Dropout(0.25), + ], + **kwargs, + ) + + if backend.config.image_data_format() == "channels_first": + input = Input(shape=(1, 28, 28)) + data = np.ones((2, 1, 28, 28)) + else: + input = Input(shape=(28, 28, 1)) + data = np.ones((2, 28, 28, 1)) + composite_layer1 = ConvStack(name="c1") + x = composite_layer1(input) + composite_layer2 = ConvStack(name="c2") + x = composite_layer2(x) + composite_layer3 = RegulStack(name="r1") + x = composite_layer3(x) + composite_layer4 = DenseStack(name="d1") + x = composite_layer4(x) + output = layers.Dense(10, activation="softmax", name="head")(x) + model = Model(input, output, name="func_model") + model(input) # check the model builds + + def is_spec_equal(spec1, spec2): + tree.assert_same_structure(spec1, spec2) + + def compare_spec(s1, s2): + return s1.shape == s2.shape and s1.dtype == s2.dtype + + result = tree.map_structure(compare_spec, spec1, spec2) + return all(tree.flatten(result)) + + def get_image_format_shape(channeles_first_shape, channels_last_shape): + if backend.config.image_data_format() == "channels_first": + return channeles_first_shape + else: + return channels_last_shape + + # check the layers were built correctly + shape = get_image_format_shape((None, 1, 28, 28), (None, 28, 28, 1)) + self.assertTrue( + is_spec_equal(composite_layer1.input_spec, InputSpec(shape=shape)) + ) + shape = get_image_format_shape((None, 16, 28, 28), (None, 28, 28, 16)) + self.assertTrue( + is_spec_equal(composite_layer2.input_spec, InputSpec(shape=shape)) + ) + shape = get_image_format_shape((None, 16, 28, 28), (None, 28, 28, 16)) + self.assertTrue( + is_spec_equal(composite_layer3.input_spec, InputSpec(shape=shape)) + ) + self.assertTrue( + is_spec_equal( + composite_layer4.input_spec, + InputSpec(shape=(None, 28 * 28 * 16 // 4)), + ) + ) + + # test model serialization with weights + output1 = model(data) + temp_filepath = os.path.join(self.get_temp_dir(), "func_nested.keras") + model.save(temp_filepath) + loaded_model = saving.load_model( + temp_filepath, + custom_objects={ + "ConvStack": ConvStack, + "RegulStack": RegulStack, + "DenseStack": DenseStack, + }, + ) + + output2 = loaded_model(data) + self.assertAllClose(output1, output2) + + def test_optional_inputs(self): + class OptionalInputLayer(layers.Layer): + def call(self, x, y=None): + if y is not None: + return x + y + return 
x + + def compute_output_shape(self, x_shape): + return x_shape + + def layer_fn(x): + x1 = x[0] + x2 = x[1] + return OptionalInputLayer()(x2, x1) + + layer = CompositeLayer(layer_fn) + + # declare the first arg as optional + input_spec = [ + InputSpec(shape=(None, 2), optional=True), + InputSpec(shape=(None, 2)), + ] + layer.input_spec = input_spec + + # symbolic test + x1 = Input((2,)) + x2 = Input((2,)) + out = layer([x1, x2]) + # Eager test + data = np.ones((8, 2)) + out = layer([None, data]) + self.assertAllClose(out, data) + out = layer([data, data]) + self.assertAllClose(out, data * 2) + + # Error message when passing None to a non-optional input + with self.assertRaisesRegex( + ValueError, "Optional inputs must be declared" + ): + out = layer([data, None]) + + # This is still problematic because of spurius + # wrapping of _manual_input_spec in ListWrapper etc... + # def test_list_vs_tuple_input_spec2(self): + # def layer_fn(inputs): + # a, l = inputs + # b, c = l + # a = Dense(8)(a) + # b = Dense(8)(b) + # c = Dense(8)(c) + # return a+b+c + + # layer = CompositeLayer(layer_fn) + # layer.input_spec = (InputSpec(shape=(None, 8)), + # [InputSpec(shape=(None, 8)), + # InputSpec(shape=(None, 8))]) + # sym_inputs = (Input(shape=(None, 8)), + # [Input(shape=(None, 8)), + # Input(shape=(None, 8))]) + # layer(sym_inputs) + + def test_list_vs_tuple_input_spec(self): + def layer_fn(inputs): + a, b = inputs + a = Dense(8)(a) + b = Dense(8)(b) + return a+b + + layer = CompositeLayer(layer_fn) + layer.input_spec = (InputSpec(shape=(None, 8)), # tuple + InputSpec(shape=(None, 8))) + sym_inputs = [Input(batch_shape=(None, 8)), # list + Input(batch_shape=(None, 8))] + layer(sym_inputs) + layer.input_spec = [InputSpec(shape=(None, 8)), # list + InputSpec(shape=(None, 8))] + sym_inputs = (Input(batch_shape=(None, 8)), # tuple + Input(batch_shape=(None, 8))) + layer(sym_inputs) + + def test_no_batch_input(self): + + def layer_fn(inputs): + x, y, idx = inputs # idx does not have a batch size + # this happens in LLM text generation + # with cache_update_idx + start = [0, idx, 0, 0] + return ops.slice_update(x, start, y) + + layer = CompositeLayer(layer_fn) + + spec = (InputSpec(shape=(None, 16, 4, 8)), + InputSpec(shape=(None, 1, 4, 8)), + InputSpec(shape=(), dtype="int32")) # no batch size + # input with no batch size defined through manual input_spec + layer.input_spec = spec + + #layer(inputs) + layer((np.zeros((32, 16, 4, 8)), + np.zeros((32, 1, 4, 8)), + backend.convert_to_tensor(0))) + + def layer_fn(inputs): + x, mul = inputs # mul does not have a batch size + return x * mul + + layer = CompositeLayer(layer_fn) + + # input with no batch size defined by calling the layer + inputs = (Input(shape=(16,)), + Input(batch_shape=())) + + layer(inputs) # symbolic call + # numerical call + result = layer((np.ones((32, 16)), + backend.convert_to_tensor(2))) + self.assertAllClose(result, np.ones((32, 16)) * 2) + + # Keeping this as a manual test so that pydot + # and graphvizare not required for testing. 
+ # def test_plot(self): + # def layer_fn(x): + # y = layers.Dense(8)(x) + # y = layers.Dense(4)(y) + # return y + + # layer = CompositeLayer(layer_fn) + + # x = Input((4,)) + # y = layers.Dense(8)(x) + # y = layers.Dense(4)(y) + # model = Model(x,y) + + # model2 = Sequential([ + # layers.Dense(8), + # layers.Dense(8) + # ]) + + # x = Input((12,)) + # y = layer(x) + # y = model(y) + # y = model2(y) + # model = Model(x,y) + + # data = np.random.uniform(size=(4, 12)) + # model(data) + + # utils.plot_model(model, expand_nested=True) + + def test_clone_model(self): + const_init = initializers.Ones() + zero_init = initializers.Zeros() + + # alternative dense implementation with dict output + class AltDense(layers.Layer): + def __init__(self, units, **kwargs): + super().__init__(**kwargs) + self.units = units + + def build(self, input_shape): + self.w = self.add_weight( + shape=(input_shape[-1], self.units), + initializer=const_init, + trainable=True, + ) + self.b = self.add_weight( + shape=(self.units,), + initializer=zero_init, + trainable=True, + ) + + def call(self, inputs): + return ops.matmul(inputs, self.w) + self.b + + data = np.random.uniform(size=(4, 12)) + + # CompositeLayer using regular dense layers + def layer_fn(x): + y = layers.Dense( + 8, kernel_initializer=const_init, name="original1" + )(x) + y = layers.Dense( + 8, kernel_initializer=const_init, name="original2" + )(y) + return y + + layer = CompositeLayer(layer_fn) + + # Composite layer instatiated as a subclass of CompositeLayer + # It is cloned as a vanilla CompositeLayer for now. + class FuncSub(CompositeLayer): + def __init__(self, name=None, **kwargs): + super().__init__(layer_fn, name, **kwargs) + + flayer = FuncSub() + + x = Input((12,)) + y = layers.Dense(8, kernel_initializer=const_init, name="original3")(x) + y = layer(y) + y = flayer(y) # subclass layer + y = layer(y) # shared layer + model = Model(x, y) + # build the model + model(data) + + for variable in model.variables: + self.assertContainsSubsequence(variable.path, "original") + + # replace regular dense layers with alternative + # dense layers and rewire dict output + def replace_fn(layer, *args, **kwargs): + if isinstance(layer, layers.Dense): + return AltDense(layer.units)(*args, **kwargs) + else: + return layer(*args, **kwargs) # pass-through + + # clone function thas does not do any layer cloning + # but only creates a new layer graph. 
+ model2 = clone_model( + model, + input_tensors=x, + # everyhting is done in call_function + clone_function=lambda x: x, + call_function=replace_fn, + recursive=True, + ) + + model2(data) + + # original model is unchanged + for variable in model.variables: + self.assertContainsSubsequence(variable.path, "original") + + # new model has new AltDense layers + for variable in model2.variables: + self.assertContainsSubsequence(variable.path, "alt_dense") + + self.assertEqual(len(model.layers), len(model2.layers)) + for layer1, layer2 in zip(model.layers, model2.layers): + if isinstance(layer1, layers.Dense): + self.assertTrue(layer2.__class__ is AltDense) + # A subclass of CompositeLayer is cloned as CompositeLayer for now + elif isinstance(layer1, FuncSub): + self.assertTrue( + layer2.__class__ is CompositeLayer + or layer2.__class__ is FuncSub + ) + else: + self.assertEqual(layer1.__class__, layer2.__class__) + + self.assertAllClose(model(data), model2(data)) + + def test_clone_recursive(self): + def layer_fn1(inputs): + return layers.Dense(32, name="dense_2")(inputs) + + layer1 = CompositeLayer(layer_fn1, name="sub") + + def layer_fn2(inputs): + return layer1(inputs) + + slayer = CompositeLayer(layer_fn2, name="subfunc") + + inputs = layers.Input(shape=(16, 32)) + outputs = slayer(inputs) + model = Model(inputs, outputs) + + def call_fn(layer, *args, **kwargs): + if isinstance(layer, layers.Dense): + new_layer = layers.Dense(layer.units, name="dense_modified") + return new_layer(*args, **kwargs) + return layer(*args, **kwargs) + + new_model = clone_model(model, call_function=call_fn, recursive=True) + sub = new_model.get_layer("subfunc").get_layer("sub") + self.assertEqual(sub.layers[1].name, "dense_modified") + + def test_clone_with_input_spec(self): + def layer_fn(inputs): + return layers.Dense(12)(inputs) + + layer1 = CompositeLayer(layer_fn) + layer1.input_spec = InputSpec(shape=(None, 12), optional=True) + + def layer_fn2(inputs): + x = layers.Dense(12)(inputs) + return layer1(x) + + layer2 = CompositeLayer(layer_fn2) + layer2.input_spec = InputSpec(shape=(None, 12), optional=True) + + layer2(ops.ones(shape=(8,12))) # build it + + def call_fn(layer, *args, **kwargs): + return layer(*args, **kwargs) + + new_layer = clone_model(layer2, clone_function=lambda x:x, + call_function=call_fn, + recursive=True) + self.assertEqual(new_layer.input_spec, layer2.input_spec) + self.assertEqual(new_layer.layers[2].input_spec, layer1.input_spec) + + def test_clone_before_build(self): + def layer_fn(inputs): + return layers.Dense(12)(inputs) + + layer = CompositeLayer(layer_fn) + + def call_fn(layer, *args, **kwargs): + return layer(*args, **kwargs) + + with self.assertRaisesRegex( + ValueError, "model has no graph of layers. 
It is probably not built" + ): + clone_model(layer, clone_function=lambda x:x, + call_function=call_fn, + recursive=True) + + + def test_build_twice(self): + def layer_fn(inputs): + return layers.Dense(5)(inputs) + + layer = CompositeLayer(layer_fn) + layer.build([2, 3]) + + id1 = id(layer.layers[0]) + id2 = id(layer.layers[1]) + + # calling build a second time should do nothing + layer.build([2, 3]) + + self.assertEqual(id1, id(layer.layers[0])) + self.assertEqual(id2, id(layer.layers[1])) diff --git a/keras/src/layers/input_spec.py b/keras/src/layers/input_spec.py index 25e4c8d9cda4..9376691fcf00 100644 --- a/keras/src/layers/input_spec.py +++ b/keras/src/layers/input_spec.py @@ -100,6 +100,10 @@ def __repr__(self): ("max_ndim=" + str(self.max_ndim)) if self.max_ndim else "", ("min_ndim=" + str(self.min_ndim)) if self.min_ndim else "", ("axes=" + str(self.axes)) if self.axes else "", + ("optional=" + str(self.optional)) if self.optional else "", + ("allow_last_axis_squeeze=" + str(self.allow_last_axis_squeeze)) + if self.allow_last_axis_squeeze + else "", ] return f"InputSpec({', '.join(x for x in spec if x)})" @@ -167,6 +171,11 @@ def assert_input_compatibility(input_spec, inputs, layer_name): continue if x is None and spec.optional: continue + if x is None and not spec.optional: + raise ValueError( + f"Optional inputs must be declared in the input spec " + f"of a layer. Received input None for {spec}" + ) # Having a shape/dtype is the only commonality of the various # tensor-like objects that may be passed. The most common kind of diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 3aef24f2716c..05ee43c81f6b 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1748,7 +1748,13 @@ def get_shapes_dict(call_spec): continue if k in call_spec.nested_tensor_argument_names: shapes_dict[f"{k}_shape"] = tree.map_structure( - lambda x: backend.standardize_shape(x.shape), v + lambda x: ( + backend.standardize_shape(x.shape) + # Handle optional inputs returning None(s) as shapes + if x is not None + else None + ), + v, ) else: shapes_dict[f"{k}_shape"] = backend.standardize_shape(v.shape) diff --git a/keras/src/models/cloning.py b/keras/src/models/cloning.py index 30bc8940bd4b..ad45340a0aaa 100644 --- a/keras/src/models/cloning.py +++ b/keras/src/models/cloning.py @@ -4,9 +4,10 @@ from keras.src.api_export import keras_export from keras.src.layers import Input from keras.src.layers import InputLayer +from keras.src.layers.core.composite_layer import CompositeLayer from keras.src.models.functional import Functional -from keras.src.models.functional import functional_like_constructor from keras.src.models.sequential import Sequential +from keras.src.ops.function import Function from keras.src.saving import serialization_lib @@ -161,7 +162,7 @@ def call_function(layer, *args, **kwargs): clone_function=clone_function, input_tensors=input_tensors, ) - if isinstance(model, Functional): + if _is_functional(model): # Wrap clone_function to handle recursiveness and layer sharing. 
clone_function = _wrap_clone_function( clone_function, @@ -170,22 +171,64 @@ def call_function(layer, *args, **kwargs): cache=cache, ) + if isinstance(model, CompositeLayer): + cloned_inputs, cloned_outputs = _clone_function_object( + model._function, + clone_function=clone_function, + call_function=call_function, + input_tensors=input_tensors, + ) + # Create a Keras Function from the graph between inputs and outputs + function = Function( + cloned_inputs, cloned_outputs, model._function.name + ) + + # Create a new CompositeLayer from the cloned function + # Note: A functional subclass of CompositeLayer will be + # cloned as a vanilla CompositeLayer. This could be changed + # in the future to; + # inst = layer.__class__.__new__, then + # CompositeLayer.__init__(inst, function, layer.name) + # It would represent the cloned CompositeLayer with + # the correct class name but not call the __init__ + # method of the subclass which could create problems. + new_model = CompositeLayer(function, model.name) + + # copy manual input spec if any + if hasattr(model, "_manual_input_spec"): + new_model.input_spec = model.input_spec + return new_model + # If the get_config() method is the same as a regular Functional - # model, we're safe to use _clone_functional_model (which relies + # model, we're safe to use _clone_function_object (which relies # on a Functional constructor). In the case where the get_config # is custom, this may not necessarily work, but if clone_function - # or input_tensors are passed, we attempt it anyway + # or call_function or input_tensors are passed, we attempt it anyway # in order to preserve backwards compatibility. if utils.is_default(model.get_config) or ( - clone_function or input_tensors + clone_function or call_function or input_tensors ): - return _clone_functional_model( - model, + cloned_inputs, cloned_outputs = _clone_function_object( + model, # the model is a Function clone_function=clone_function, call_function=call_function, input_tensors=input_tensors, ) + # A subclassed Functional model is always cloned + # as a vanilla Functional model. + new_model = Functional( + cloned_inputs, cloned_outputs, name=model.name + ) + # copy compiled config if any + if model.compiled: + compiled_config = model.get_compile_config() + new_model.compile_from_config(compiled_config) + # copy manual input spec if any + if hasattr(model, "_manual_input_spec"): + new_model.input_spec = model.input_spec + return new_model + # Case of a custom model class if clone_function or input_tensors: raise ValueError( @@ -236,12 +279,13 @@ def wrapped_clone_function(layer): ) cache[id(layer)] = clone return clone - elif isinstance(layer, Functional): + elif _is_functional(layer): clone = clone_model( layer, clone_function=clone_function, call_function=call_function, cache=cache, + recursive=True, ) cache[id(layer)] = clone return clone @@ -334,19 +378,21 @@ def _clone_sequential_model(model, clone_function, input_tensors=None): return cloned_model -def _clone_functional_model( - model, clone_function, input_tensors=None, call_function=None +def _clone_function_object( + function_obj, clone_function, input_tensors=None, call_function=None ): - """Clone a `Functional` model instance. + """Clone a `Function` object instance. - Model cloning is similar to calling a model on new inputs, - except that it creates new layers (and thus new weights) instead - of sharing the weights of the existing layers. + Cloning is similar to calling a Function on new inputs. 
+ Depending on clone_function and call_function, + layers (and thus weights) can be shared or cloned + (which creates new layers and weights). See 'clone_model' + for details. Input layers are always cloned. Args: - model: Instance of `Functional`. + function_obj: Instance of `Function`. input_tensors: optional list of input tensors to build the model upon. If not provided, placeholders will be created. @@ -354,9 +400,9 @@ def _clone_functional_model( By default, it clones the layer (without copying the weights). Returns: - An instance of `Functional` reproducing the behavior - of the original model, on top of new inputs tensors, - using newly instantiated weights. + New input_tensors, output_tensors which can be used to instantiate + a new `Function` corresponding to the graph of the original function, + with the changes specified by clone_function and call_function. """ if not callable(clone_function): @@ -365,10 +411,16 @@ def _clone_functional_model( f"Received: clone_function={clone_function}" ) - if not isinstance(model, Functional): + if function_obj is None: + raise ValueError( + "The model has no graph of layers. It is probably not built yet. " + "Please build it by calling it on a batch of data before calling " + "clone_model." + ) + if not isinstance(function_obj, Function): raise ValueError( "Expected `model` argument " - f"to be a Functional Model instance. Received: model={model}" + f"to be a Functional model or Keras Function. Received: {function_obj}" ) if input_tensors is not None: @@ -381,39 +433,33 @@ def _clone_functional_model( f"Received invalid values: inputs_tensors={input_tensors}" ) try: - tree.assert_same_structure(input_tensors, model.input) + tree.assert_same_structure( + input_tensors, function_obj._inputs_struct + ) except ValueError as e: raise ValueError( "`input_tensors` must have the same structure as model.input" - f"\nReference structure: {model.input}" + f"\nReference structure: {function_obj._inputs_struct}" f"\nReceived structure: {input_tensors}" ) from e else: input_tensors = tree.map_structure( lambda x: Input(batch_shape=x.shape, dtype=x.dtype, name=x.name), - model.input, + function_obj._inputs_struct, ) def operation_fn(layer): new_layer = clone_function(layer) return new_layer - output_tensors = model._run_through_graph( + output_tensors = function_obj._run_through_graph( input_tensors, operation_fn=operation_fn, call_fn=call_function, ) - if functional_like_constructor(model.__class__): - new_model = model.__class__( - input_tensors, output_tensors, name=model.name - ) - else: - # This may be incorrect: the new model will end up having a different - # class than the original. However various existing models rely - # on this behavior, so we keep it.
- new_model = Functional(input_tensors, output_tensors, name=model.name) - if model.compiled: - compiled_config = model.get_compile_config() - new_model.compile_from_config(compiled_config) - return new_model + return input_tensors, output_tensors + + +def _is_functional(layer): + return isinstance(layer, Functional) or isinstance(layer, CompositeLayer) diff --git a/keras/src/models/cloning_test.py b/keras/src/models/cloning_test.py index b7e576798591..fee5d5d9ba5b 100644 --- a/keras/src/models/cloning_test.py +++ b/keras/src/models/cloning_test.py @@ -2,12 +2,15 @@ import pytest from absl.testing import parameterized +from keras.src import initializers from keras.src import layers from keras.src import models from keras.src import ops from keras.src import testing from keras.src import tree +from keras.src.layers.input_spec import InputSpec from keras.src.models.cloning import clone_model +from keras.src.models.functional import Functional def get_mlp_functional_model(shared_layers=False): @@ -199,6 +202,40 @@ def call_function(layer, *args, **kwargs): ref_input = np.random.random((2, 3)) self.assert_models_equal(model, new_model, ref_input) + def test_call_fn_custom_layer_replace(self): + # alternative dense implementation using the same weights + class AltDense(layers.Layer): + def __init__(self, layer, **kwargs): + super().__init__(**kwargs) + self.dense_layer = layer + + def build(self, input_shape): + self.w = self.dense_layer.kernel + self.b = self.dense_layer.bias + + def call(self, inputs): + result = ops.matmul(inputs, self.w) + self.b + return result + + model = get_mlp_functional_model(shared_layers=False) + + def call_function(layer, *args, **kwargs): + if isinstance(layer, layers.Dense): + new_layer = AltDense(layer) + return new_layer(*args, **kwargs) + else: + return layer(*args, **kwargs) + + new_model = clone_model( + model, + clone_function=lambda x: x, # everything happense in call_function. 
+ call_function=call_function, + ) + self.assertLen(model.layers, 3) + self.assertLen(new_model.layers, 3) + ref_input = np.random.random((2, 3)) + self.assert_models_equal(model, new_model, ref_input) + def test_recursive(self): model = get_nested_functional_model() @@ -243,6 +280,29 @@ def clone_function(layer): self.assertFalse(hasattr(l1, "flag")) self.assertTrue(hasattr(l2, "flag")) + def test_recursive_level_2(self): + inputs = layers.Input(shape=(16, 32)) + outputs = layers.Dense(32, name="dense_2")(inputs) + layer1 = models.Model(inputs, outputs, name="sub") + + inputs = layers.Input(shape=(16, 32)) + outputs = layer1(inputs) + slayer = models.Model(inputs, outputs, name="subfunc") + + inputs = layers.Input(shape=(16, 32)) + outputs = slayer(inputs) + model = models.Model(inputs, outputs) + + def call_fn(layer, *args, **kwargs): + if isinstance(layer, layers.Dense): + new_layer = layers.Dense(layer.units, name="dense_modified") + return new_layer(*args, **kwargs) + return layer(*args, **kwargs) + + new_model = clone_model(model, call_function=call_fn, recursive=True) + sub = new_model.get_layer("subfunc").get_layer("sub") + self.assertEqual(sub.layers[1].name, "dense_modified") + def test_compiled_model_cloning(self): model = models.Sequential() model.add(layers.Input((3,))) @@ -251,3 +311,273 @@ def test_compiled_model_cloning(self): model.compile(optimizer="adam", loss="binary_crossentropy") cloned_model = clone_model(model) self.assertEqual(model.compiled, cloned_model.compiled) + + def test_func_subclass(self): + const_init = initializers.Ones() + zero_init = initializers.Zeros() + + # alternative dense implementation + class AltDense(layers.Layer): + def __init__(self, units, **kwargs): + super().__init__(**kwargs) + self.units = units + + def build(self, input_shape): + self.w = self.add_weight( + shape=(input_shape[-1], self.units), + initializer=const_init, + trainable=True, + ) + self.b = self.add_weight( + shape=(self.units,), + initializer=zero_init, + trainable=True, + ) + + def call(self, inputs): + return ops.matmul(inputs, self.w) + self.b + + class FuncSubclassModel(models.Model): + def __init__(self, **kwargs): + inputs = layers.Input(shape=(8,)) + y = layers.Dense( + 4, kernel_initializer=const_init, name="original1" + )(inputs) + outputs = layers.Dense( + 8, kernel_initializer=const_init, name="original2" + )(y) + super().__init__(inputs, outputs, **kwargs) + + inputs = layers.Input(shape=(12,)) + y = layers.Dense(8, kernel_initializer=const_init, name="original3")( + inputs + ) + funcsub = FuncSubclassModel() + y = funcsub(y) + outputs = funcsub(y) # reused layer + model = models.Model(inputs, outputs) + + data = np.random.uniform(size=(4, 12)) + model(data) + + def replace_fn(layer, *args, **kwargs): + if isinstance(layer, layers.Dense): + return AltDense(layer.units)(*args, **kwargs) + else: + return layer(*args, **kwargs) # pass-through + + model2 = clone_model( + model, + input_tensors=inputs, + # everything happense in call_function. 
+ clone_function=lambda x: x, + call_function=replace_fn, + recursive=True, + ) + + model2(data) + + # original model is unchanged + for variable in model.variables: + self.assertContainsSubsequence(variable.path, "original") + + # new model has new AltDense layers + for variable in model2.variables: + self.assertContainsSubsequence(variable.path, "alt_dense") + + self.assertEqual(len(model.layers), len(model2.layers)) + for layer1, layer2 in zip(model.layers, model2.layers): + if isinstance(layer1, layers.Dense): + self.assertTrue(layer2.__class__ is AltDense) + # A subclass of Functional is cloned as vanilla Functional for now + # unless it has an explicit functional constructor + elif isinstance(layer1, FuncSubclassModel): + self.assertTrue( + layer2.__class__ is Functional + or layer2.__class__ is FuncSubclassModel + ) + else: + self.assertEqual(layer1.__class__, layer2.__class__) + + self.assertAllClose(model(data), model2(data)) + + def test_parametrized_func_subclass(self): + # alternative dense implementation + class AltDense(layers.Layer): + def __init__(self, units, **kwargs): + super().__init__(**kwargs) + self.units = units + + def build(self, input_shape): + self.w = self.add_weight(shape=(input_shape[-1], self.units)) + self.b = self.add_weight(shape=(self.units,)) + + def call(self, inputs): + return ops.matmul(inputs, self.w) + self.b + + class FuncSubclassParametrizedModel(models.Model): + def __init__(self, *args, param=4, **kwargs): + inputs = layers.Input(shape=(8,)) + y = layers.Dense(param)(inputs) + outputs = layers.Dense(param)(y) + super().__init__(inputs, outputs, *args, **kwargs) + self.param = param + + def replace_fn(layer, *args, **kwargs): + if isinstance(layer, layers.Dense): + return AltDense(layer.units)(*args, **kwargs) + else: + return layer(*args, **kwargs) # pass-through + + model = FuncSubclassParametrizedModel(param=11) + self.assertEqual(model.param, 11) + + model2 = clone_model( + model, + clone_function=lambda x: x, + call_function=replace_fn, + recursive=True, + ) + # A subclass of Functional is cloned as vanilla Functional for now + self.assertFalse(model2.__class__ == FuncSubclassParametrizedModel) + self.assertTrue(model2.__class__ == Functional) + # test that the layers were replaced + self.assertTrue(isinstance(model2.layers[0], layers.InputLayer)) + self.assertTrue(isinstance(model2.layers[1], AltDense)) + # Even though the cloned FuncSubclassParametrizedModel is now + # a valilla Functional, test that the underlying AltDense layers + # have the correct param size, as set by the param value. 
+ self.assertEqual(model2.layers[1].w.shape[1], 11) + + def test_clone_passthrough_subfunctional(self): + class SubFunctional(models.Model): + pass + + inputs = layers.Input(shape=(8,)) + y = layers.Dense(4)(inputs) + outputs = layers.Dense(8)(y) + model = SubFunctional(inputs, outputs) + + model2 = clone_model(model) + # cloned as a vanilla Functional + self.assertTrue(model2.__class__ == Functional) + + def test_clone_passthrough_subfunctional_recursive(self): + class SubFunctional(models.Model): + pass + + inputs = layers.Input(shape=(8,)) + outputs = layers.Dense(8)(inputs) + sublayer = SubFunctional(inputs, outputs) + + inputs = layers.Input(shape=(8,)) + outputs = sublayer(inputs) + model = models.Model(inputs, outputs) + + model2 = clone_model(model, recursive=True) + # cloned as a vanilla Functional + self.assertTrue(model2.__class__ == Functional) + self.assertTrue(model2.layers[1].__class__ == Functional) + + def test_clone_functional_subclass(self): + class SubFunctional(models.Model): + def __init__(self, *args, **kwargs): + inputs = layers.Input(shape=(8,)) + outputs = layers.Dense(8)(inputs) + return super().__init__(inputs, outputs, *args, **kwargs) + + model = SubFunctional() + + model2 = clone_model(model) + # cloned as a vanilla Functional + self.assertTrue(model2.__class__ == Functional) + + def test_clone_functional_subclass_non_recursive(self): + class SubFunctional(models.Model): + def __init__(self, *args, **kwargs): + inputs = layers.Input(shape=(8,)) + outputs = layers.Dense(4)(inputs) + return super().__init__(inputs, outputs, *args, **kwargs) + + inputs = layers.Input(shape=(8,)) + outputs = SubFunctional()(inputs) + model = models.Model(inputs, outputs) + + model2 = clone_model(model) + self.assertTrue(model2.__class__ == Functional) + # not touched in non-recursive mode + self.assertTrue(model2.layers[1].__class__ == SubFunctional) + + def test_clone_functional_subclass_recursive(self): + class SubFunctional(models.Model): + def __init__(self, *args, **kwargs): + inputs = layers.Input(shape=(8,)) + outputs = layers.Dense(4)(inputs) + return super().__init__(inputs, outputs, *args, **kwargs) + + inputs = layers.Input(shape=(8,)) + outputs = SubFunctional()(inputs) + model = models.Model(inputs, outputs) + + model2 = clone_model(model, clone_function=lambda x: x, recursive=True) + self.assertTrue(model2.__class__ == Functional) + # cloned as a vanilla Functional + self.assertTrue(model2.layers[1].__class__ == Functional) + + def test_clone_functional_subclass_non_recursive2(self): + class SubFunctional(models.Model): + def __init__(self, *args, **kwargs): + inputs = layers.Input(shape=(8,)) + outputs = layers.Dense(4)(inputs) + return super().__init__(inputs, outputs, *args, **kwargs) + + inputs = layers.Input(shape=(8,)) + outputs = SubFunctional()(inputs) + model = models.Model(inputs, outputs) + + model2 = clone_model(model, clone_function=lambda x: x, recursive=False) + self.assertTrue(model2.__class__ == Functional) + # not touched in non-recursive mode + self.assertTrue(model2.layers[1].__class__ == SubFunctional) + + def test_clone_passthrough_subfunctional_with_params(self): + class SubFunctional(models.Model): + def __init__(self, inputs, outputs, param, *args, **kwargs): + super().__init__(inputs, outputs, *args, **kwargs) + self.param = param + + inputs = layers.Input(shape=(8,)) + y = layers.Dense(4)(inputs) + outputs = layers.Dense(8)(y) + model = SubFunctional(inputs, outputs, 8) + + # cloned as a vanilla Functional + model2 = clone_model(model) + 
self.assertTrue(model2.__class__ == Functional) + + def test_clone_with_input_spec(self): + def layer_fn(inputs): + return layers.Dense(12)(inputs) + + inputs = layers.Input(shape=(12,)) + model1 = models.Model(inputs, layer_fn(inputs)) + model1.input_spec = InputSpec(shape=(None, 12), optional=True) + + def layer_fn2(inputs): + x = layers.Dense(12)(inputs) + return model1(x) + + inputs2 = layers.Input(shape=(12,)) + model2 = models.Model(inputs2, layer_fn2(inputs2)) + model2.input_spec = InputSpec(shape=(None, 12), optional=True) + + model2(ops.ones(shape=(8,12))) + + def call_fn(layer, *args, **kwargs): + return layer(*args, **kwargs) + + new_model = clone_model(model2, clone_function=lambda x:x, + call_function=call_fn, + recursive=True) + self.assertEqual(new_model.input_spec, model2.input_spec) + self.assertEqual(new_model.layers[2].input_spec, model1.input_spec) diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index d98318b5dd40..a28631c97790 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -1,5 +1,4 @@ import copy -import inspect import typing import warnings @@ -14,6 +13,7 @@ from keras.src.legacy.saving import saving_utils from keras.src.legacy.saving import serialization as legacy_serialization from keras.src.models.model import Model +from keras.src.models.model import strict_functional_like_constructor from keras.src.ops.function import Function from keras.src.ops.function import _build_map from keras.src.ops.function import make_node_key @@ -171,19 +171,9 @@ def layers(self, _): ) def call(self, inputs, training=None, mask=None): - # Add support for training, masking - inputs = self._standardize_inputs(inputs) - if mask is None: - masks = [None] * len(inputs) - else: - masks = tree.flatten(mask) - for x, mask in zip(inputs, masks): - if mask is not None: - backend.set_keras_mask(x, mask) - outputs = self._run_through_graph( - inputs, operation_fn=lambda op: operation_fn(op, training=training) + return run_through_graph_with_training_and_mask( + self, inputs, training, mask ) - return unpack_singleton(outputs) def compute_output_spec(self, inputs, training=None, mask=None): # From Function @@ -213,114 +203,6 @@ def output_shape(self): def _assert_input_compatibility(self, *args): return super(Model, self)._assert_input_compatibility(*args) - def _maybe_warn_inputs_struct_mismatch(self, inputs, raise_exception=False): - try: - # We first normalize to tuples before performing the check to - # suppress warnings when encountering mismatched tuples and lists. 
- tree.assert_same_structure( - tree.lists_to_tuples(inputs), - tree.lists_to_tuples(self._inputs_struct), - ) - except: - model_inputs_struct = tree.map_structure( - lambda x: x.name, self._inputs_struct - ) - inputs_struct = tree.map_structure( - lambda x: f"Tensor(shape={x.shape})", inputs - ) - msg = ( - "The structure of `inputs` doesn't match the expected " - f"structure.\nExpected: {model_inputs_struct}\n" - f"Received: inputs={inputs_struct}" - ) - if raise_exception: - raise ValueError(msg) - warnings.warn(msg) - - def _convert_inputs_to_tensors(self, flat_inputs): - converted = [] - for x, input in zip(flat_inputs, self._inputs): - if x is None: # TODO: check if optional - converted.append(x) - else: - converted.append( - ops.convert_to_tensor( - x, dtype=input.dtype, sparse=input.sparse - ) - ) - return converted - - def _adjust_input_rank(self, flat_inputs): - flat_ref_shapes = [x.shape for x in self._inputs] - adjusted = [] - for x, ref_shape in zip(flat_inputs, flat_ref_shapes): - if x is None: - adjusted.append(x) - continue - x_rank = len(x.shape) - ref_rank = len(ref_shape) - if x_rank == ref_rank: - adjusted.append(x) - continue - if x_rank == ref_rank + 1: - if x.shape[-1] == 1: - adjusted.append(ops.squeeze(x, axis=-1)) - continue - if x_rank == ref_rank - 1: - if ref_shape[-1] == 1: - adjusted.append(ops.expand_dims(x, axis=-1)) - continue - raise ValueError( - f"Invalid input shape for input {x}. Expected shape " - f"{ref_shape}, but input has incompatible shape {x.shape}" - ) - # Add back metadata. - for i in range(len(flat_inputs)): - if hasattr(flat_inputs[i], "_keras_history"): - adjusted[i]._keras_history = flat_inputs[i]._keras_history - mask = backend.get_keras_mask(flat_inputs[i]) - if mask is not None: - backend.set_keras_mask(adjusted[i], mask) - return adjusted - - def _standardize_inputs(self, inputs): - raise_exception = False - if isinstance(inputs, dict) and not isinstance( - self._inputs_struct, dict - ): - # This is to avoid warning - # when we have reconciable dict/list structs - if hasattr(self._inputs_struct, "__len__") and all( - isinstance(i, backend.KerasTensor) for i in self._inputs_struct - ): - expected_keys = set(i.name for i in self._inputs_struct) - keys = set(inputs.keys()) - if expected_keys.issubset(keys): - inputs = [inputs[i.name] for i in self._inputs_struct] - else: - raise_exception = True - elif isinstance(self._inputs_struct, backend.KerasTensor): - if self._inputs_struct.name in inputs: - inputs = [inputs[self._inputs_struct.name]] - else: - raise_exception = True - else: - raise_exception = True - if ( - isinstance(self._inputs_struct, dict) - and not isinstance(inputs, dict) - and list(self._inputs_struct.keys()) - != sorted(self._inputs_struct.keys()) - ): - raise_exception = True - self._maybe_warn_inputs_struct_mismatch( - inputs, raise_exception=raise_exception - ) - - flat_inputs = tree.flatten(inputs) - flat_inputs = self._convert_inputs_to_tensors(flat_inputs) - return self._adjust_input_rank(flat_inputs) - @property def input(self): # For backwards compatibility, @@ -341,124 +223,308 @@ def input_spec(self): if hasattr(self, "_manual_input_spec"): return self._manual_input_spec - def shape_with_no_batch_size(x): - x = list(x) - if x: - x[0] = None - return tuple(x) - - def make_spec_for_tensor(x, name=None): - optional = False - if isinstance(x._keras_history[0], InputLayer): - if x._keras_history[0].optional: - optional = True - return InputSpec( - shape=shape_with_no_batch_size(x.shape), - 
allow_last_axis_squeeze=True, - name=x._keras_history[0].name if name is None else name, - optional=optional, - ) - - if isinstance(self._inputs_struct, dict): - if all( - isinstance(x, backend.KerasTensor) - for x in self._inputs_struct.values() - ): - # Case where `_nested_inputs` is a plain dict of Inputs. - names = sorted(self._inputs_struct.keys()) - return [ - make_spec_for_tensor(self._inputs_struct[name], name=name) - for name in names - ] - return None # Deeply nested dict: skip checks. - return [make_spec_for_tensor(x) for x in self.inputs] + return compute_input_spec(self._inputs_struct, self._inputs) @input_spec.setter def input_spec(self, value): self._manual_input_spec = value def get_config(self): - if not functional_like_constructor(self.__class__): + if not strict_functional_like_constructor(self.__class__): # Subclassed networks are not serializable # (unless serialization is implemented by - # the author of the subclassed network). + # the author of the subclassed network) + # and unless the author has implemented a + # Functional-like constructor. return Model.get_config(self) - config = { - "name": self.name, - "trainable": self.trainable, - } - # Build a map from a layer unique name (make_node_key) - # to the index of the nodes that are saved in the config. - # Only nodes in network_nodes are saved. - node_reindexing_map = {} - for operation in self.operations: - if issubclass(operation.__class__, Functional): - # Functional models start with a pre-existing node - # linking their input to output. - kept_nodes = 1 + return serialize_functional_config(self, self) + + +def compute_input_spec(inputs_struct, inputs): + """Compute the input spec for a Function-based layer or Model.""" + + def shape_with_no_batch_size(x): + x = list(x) + if x: + x[0] = None + return tuple(x) + + def make_spec_for_tensor(x, name=None): + optional = False + if isinstance(x._keras_history[0], InputLayer): + if x._keras_history[0].optional: + optional = True + return InputSpec( + shape=shape_with_no_batch_size(x.shape), + allow_last_axis_squeeze=True, + name=x._keras_history[0].name if name is None else name, + optional=optional, + ) + + if isinstance(inputs_struct, dict): + if all( + isinstance(x, backend.KerasTensor) for x in inputs_struct.values() + ): + # Case where `_nested_inputs` is a plain dict of Inputs. + names = sorted(inputs_struct.keys()) + return [ + make_spec_for_tensor(inputs_struct[name], name=name) + for name in names + ] + return None # Deeply nested dict: skip checks. + computed_spec = [make_spec_for_tensor(x) for x in inputs] + return unpack_singleton(computed_spec) + + +def run_through_graph_with_training_and_mask( + function_obj, inputs, training=None, mask=None +): + """Run inputs through a Function graph, injecting training and mask arks. 
+ + Args: + function_obj: Function object to execute + inputs: Input values + training: Training mode boolean + mask: Input mask values + + Returns: + Output values from the Function + """ + inputs = _standardize_inputs(function_obj, inputs) + + # Set masks if provided + if mask is not None: + masks = tree.flatten(mask) + for x, m in zip(inputs, masks): + if m is not None and x is not None: + from keras.src.backend import set_keras_mask + + set_keras_mask(x, m) + + # Run the function with modified operation_fn that handles training + outputs = function_obj._run_through_graph( + inputs, + operation_fn=lambda op: operation_fn_with_training( + op, training=training + ), + ) + + # Unpack singleton result if needed (for API consistency) + return unpack_singleton(outputs) + + +def _adjust_input_rank(function_obj, flat_inputs): + flat_ref_shapes = [x.shape for x in function_obj._inputs] + adjusted = [] + for x, ref_shape in zip(flat_inputs, flat_ref_shapes): + if x is None: + adjusted.append(x) + continue + x_rank = len(x.shape) + ref_rank = len(ref_shape) + if x_rank == ref_rank: + adjusted.append(x) + continue + if x_rank == ref_rank + 1: + if x.shape[-1] == 1: + adjusted.append(ops.squeeze(x, axis=-1)) + continue + if x_rank == ref_rank - 1: + if ref_shape[-1] == 1: + adjusted.append(ops.expand_dims(x, axis=-1)) + continue + raise ValueError( + f"Invalid input shape for input {x}. Expected shape " + f"{ref_shape}, but input has incompatible shape {x.shape}" + ) + + # Add back metadata. + for i in range(len(flat_inputs)): + if hasattr(flat_inputs[i], "_keras_history"): + adjusted[i]._keras_history = flat_inputs[i]._keras_history + mask = backend.get_keras_mask(flat_inputs[i]) + if mask is not None: + backend.set_keras_mask(adjusted[i], mask) + + return adjusted + + +def _convert_inputs_to_tensors(function_obj, flat_inputs): + converted = [] + for x, input in zip(flat_inputs, function_obj._inputs): + if x is None: # TODO: check if optional + converted.append(x) + else: + converted.append( + ops.convert_to_tensor(x, dtype=input.dtype, sparse=input.sparse) + ) + return converted + + +def _maybe_warn_inputs_struct_mismatch( + function_obj, inputs, raise_exception=False +): + try: + # We first normalize to tuples before performing the check to + # suppress warnings when encountering mismatched tuples and lists. + tree.assert_same_structure( + tree.lists_to_tuples(inputs), + tree.lists_to_tuples(function_obj._inputs_struct), + ) + except: + model_inputs_struct = tree.map_structure( + lambda x: x.name, function_obj._inputs_struct + ) + inputs_struct = tree.map_structure( + lambda x: f"Tensor(shape={x.shape})", inputs + ) + msg = ( + "The structure of `inputs` doesn't match the expected " + f"structure.\nExpected: {model_inputs_struct}\n" + f"Received: inputs={inputs_struct}" + ) + if raise_exception: + raise ValueError(msg) + warnings.warn(msg) + + +def _standardize_inputs(function_obj, inputs): + """Convert and validate inputs to match function's expected inputs. 
+ + Args: + function_obj: Function object with _inputs_struct and _inputs attributes + inputs: Input values to standardize + + Returns: + Standardized inputs as a flat list + """ + raise_exception = False + if isinstance(inputs, dict) and not isinstance( + function_obj._inputs_struct, dict + ): + # This is to avoid warning when we have reconciable dict/list structs + if hasattr(function_obj._inputs_struct, "__len__") and all( + isinstance(i, backend.KerasTensor) + for i in function_obj._inputs_struct + ): + expected_keys = set(i.name for i in function_obj._inputs_struct) + keys = set(inputs.keys()) + if expected_keys.issubset(keys): + inputs = [inputs[i.name] for i in function_obj._inputs_struct] else: - kept_nodes = 0 - for original_node_index, node in enumerate( - operation._inbound_nodes - ): - node_key = make_node_key(operation, original_node_index) - if node_key in self._nodes: - # i.e. we mark it to be saved - node_reindexing_map[node_key] = kept_nodes - kept_nodes += 1 - - # serialize and save the layers in layer_configs - layer_configs = [] - for operation in self.operations: # From the earliest layers on. - filtered_inbound_nodes = [] - for original_node_index, node in enumerate( - operation._inbound_nodes - ): - node_key = make_node_key(operation, original_node_index) - if node_key in self._nodes: - # The node is relevant to the model: - # add to filtered_inbound_nodes. - node_data = serialize_node(node, own_nodes=self._nodes) - if node_data is not None: - filtered_inbound_nodes.append(node_data) - - serialize_obj_fn = serialization_lib.serialize_keras_object - if global_state.get_global_attribute("use_legacy_config", False): - # Legacy format serialization used for H5 and SavedModel - serialize_obj_fn = legacy_serialization.serialize_keras_object - layer_config = serialize_obj_fn(operation) - layer_config["name"] = operation.name - layer_config["inbound_nodes"] = filtered_inbound_nodes - layer_configs.append(layer_config) - config["layers"] = layer_configs - - # Gather info about inputs and outputs. - def get_tensor_config(tensor): - operation = tensor._keras_history[0] - node_index = tensor._keras_history[1] - tensor_index = tensor._keras_history[2] - node_key = make_node_key(operation, node_index) - assert node_key in self._nodes - new_node_index = node_reindexing_map[node_key] - return [operation.name, new_node_index, tensor_index] - - def map_tensors(tensors): - if isinstance(tensors, backend.KerasTensor): - return [get_tensor_config(tensors)] - return tree.map_structure(get_tensor_config, tensors) - - config["input_layers"] = map_tensors(self._inputs_struct) - config["output_layers"] = map_tensors(self._outputs_struct) - return copy.deepcopy(config) - - -def functional_from_config(cls, config, custom_objects=None): - """Instantiates a Functional model from its config (from `get_config()`). 
+ raise_exception = True + elif isinstance(function_obj._inputs_struct, backend.KerasTensor): + if function_obj._inputs_struct.name in inputs: + inputs = [inputs[function_obj._inputs_struct.name]] + else: + raise_exception = True + else: + raise_exception = True + + if ( + isinstance(function_obj._inputs_struct, dict) + and not isinstance(inputs, dict) + and list(function_obj._inputs_struct.keys()) + != sorted(function_obj._inputs_struct.keys()) + ): + raise_exception = True + + _maybe_warn_inputs_struct_mismatch( + function_obj, inputs, raise_exception=raise_exception + ) + + flat_inputs = tree.flatten(inputs) + flat_inputs = _convert_inputs_to_tensors(function_obj, flat_inputs) + return _adjust_input_rank(function_obj, flat_inputs) + + +def serialize_functional_config(obj, function_obj): + """Serialize a Function-based object to config. + + Args: + obj: A Functional model or a CompositeLayer instance. + function_obj: its Function instance. + + Returns: + Dictionary containing serialized configuration. + """ + config = { + "name": obj.name, + "trainable": obj.trainable, + } + + # Build a map from a layer unique name (make_node_key) + # to the index of the nodes that are saved in the config. + # Only nodes in nodes parameter are saved. + node_reindexing_map = {} + for operation in function_obj.operations: + if issubclass(operation.__class__, Function): + # Functional models start with a pre-existing node + # linking their input to output. + kept_nodes = 1 + else: + kept_nodes = 0 + for original_node_index, node in enumerate(operation._inbound_nodes): + node_key = make_node_key(operation, original_node_index) + if node_key in function_obj._nodes: + # i.e. we mark it to be saved + node_reindexing_map[node_key] = kept_nodes + kept_nodes += 1 + + # serialize and save the layers in layer_configs + layer_configs = [] + for operation in function_obj.operations: # From the earliest layers on. + filtered_inbound_nodes = [] + for original_node_index, node in enumerate(operation._inbound_nodes): + node_key = make_node_key(operation, original_node_index) + if node_key in function_obj._nodes: + # The node is relevant to the model: + # add to filtered_inbound_nodes. + node_data = serialize_node(node, own_nodes=function_obj._nodes) + if node_data is not None: + filtered_inbound_nodes.append(node_data) + + serialize_obj_fn = serialization_lib.serialize_keras_object + if global_state.get_global_attribute("use_legacy_config", False): + # Legacy format serialization used for H5 and SavedModel + serialize_obj_fn = legacy_serialization.serialize_keras_object + layer_config = serialize_obj_fn(operation) + layer_config["name"] = operation.name + layer_config["inbound_nodes"] = filtered_inbound_nodes + layer_configs.append(layer_config) + config["layers"] = layer_configs + + # Gather info about inputs and outputs. 
+ def get_tensor_config(tensor): + operation = tensor._keras_history[0] + node_index = tensor._keras_history[1] + tensor_index = tensor._keras_history[2] + node_key = make_node_key(operation, node_index) + assert node_key in function_obj._nodes + new_node_index = node_reindexing_map[node_key] + return [operation.name, new_node_index, tensor_index] + + def map_tensors(tensors): + if isinstance(tensors, backend.KerasTensor): + return get_tensor_config(tensors) + return tree.map_structure(get_tensor_config, tensors) + + config["input_layers"] = map_tensors(function_obj._inputs_struct) + config["output_layers"] = map_tensors(function_obj._outputs_struct) + return copy.deepcopy(config) + + +def function_from_config(cls, config, custom_objects=None): + """Instantiates a Function from its config. + A valid config for this can be produced by: + - a Functional model: Functional.get_config() + - a Function-based layer: CompositeLayer.get_config() Args: - cls: Class of the model, e.g. a custom subclass of `Model`. - config: Output of `get_config()` for the original model instance. + cls: Class to restore, e.g. a custom subclass of `Model` or + CompositeLayer. + config: Output of `get_config()` for the original instance. custom_objects: Optional dict of custom objects. Returns: @@ -535,16 +601,11 @@ def process_layer(layer_data): add_unprocessed_node(layer, node_data) # Extract config used to instantiate Functional model from the config. The - # remaining config will be passed as keyword arguments to the Model - # constructor. + # remaining config will be passed as keyword arguments to the Functional + # Model or Function constructor. functional_config = {} for key in ["layers", "input_layers", "output_layers"]: functional_config[key] = config.pop(key) - for key in ["name", "trainable"]: - if key in config: - functional_config[key] = config.pop(key) - else: - functional_config[key] = None # First, we create all layers and enqueue nodes to be processed for layer_data in functional_config["layers"]: @@ -584,13 +645,11 @@ def process_layer(layer_data): del unprocessed_nodes[layer] # Create list of input and output tensors and return new class - name = functional_config["name"] - trainable = functional_config["trainable"] def get_tensor(layer_name, node_index, tensor_index): assert layer_name in created_layers layer = created_layers[layer_name] - if isinstance(layer, Functional): + if isinstance(layer, Function): # Functional models start out with a built-in node.
node_index -= 1 layer_output_tensors = layer._inbound_nodes[node_index].output_tensors @@ -618,13 +677,11 @@ def map_tensors(tensors): return cls( inputs=input_tensors, outputs=output_tensors, - name=name, - trainable=trainable, - **config, + **config, # "name" or "trainable" are in the config if relevant ) -def operation_fn(operation, training): +def operation_fn_with_training(operation, training): def call(*args, **kwargs): if ( hasattr(operation, "_call_has_training_arg") @@ -637,14 +694,6 @@ def call(*args, **kwargs): return call -def functional_like_constructor(cls): - init_args = inspect.getfullargspec(cls.__init__).args[1:] - functional_init_args = inspect.getfullargspec(Functional.__init__).args[1:] - if init_args == functional_init_args: - return True - return False - - def unpack_singleton(x): if isinstance(x, (list, tuple)) and len(x) == 1: return x[0] diff --git a/keras/src/models/functional_test.py b/keras/src/models/functional_test.py index 5aabc5094145..e6f0327901e6 100644 --- a/keras/src/models/functional_test.py +++ b/keras/src/models/functional_test.py @@ -7,6 +7,7 @@ from keras.src import applications from keras.src import backend +from keras.src import initializers from keras.src import layers from keras.src import ops from keras.src import saving @@ -261,8 +262,9 @@ def test_restored_multi_output_type(self, out_type): outputs = {"a": output_a, "b": output_b} else: outputs = out_type([output_a, output_b]) - model = Functional(inputs, outputs) + model = Functional(inputs, outputs, name="funcmodel") model_restored = Functional.from_config(model.get_config()) + self.assertEqual(model_restored.name, model.name) # Eager call in_val = np.random.random((2, 3)) @@ -738,6 +740,11 @@ def test_list_input_with_dict_build(self): x1 = Input((10,), name="IT") x2 = Input((10,), name="IS") y = layers.subtract([x1, x2]) + # Note: the test fails here only because the order of dict + # keys "IT", "IS" is different from the sorted order of the + # keys "IS", "IT". Otherwise, passing a list of inputs to + # a model expecting a dictionary of inputs seems to be allowed. + # as long as flattening the dict does not result in reordering. model = Model(inputs={"IT": x1, "IS": x2}, outputs=y) x1 = ops.ones((1, 10)) x2 = ops.zeros((1, 10)) @@ -748,3 +755,109 @@ def test_list_input_with_dict_build(self): "The structure of `inputs` doesn't match the expected structure", ): model([x1, x2]) + + def test_functional_subclass_serialization(self): + class FuncSubclass(Functional): + def __init__(self, name=None, **kwargs): + inputs = Input((4,), name="input") + y = layers.Dense(8)(inputs) + outputs = layers.Dense(4)(y) + super().__init__(inputs, outputs, name, **kwargs) + + model = FuncSubclass() + data = ops.ones((8, 4)) + output1 = model(data) # build the model + temp_filepath = os.path.join(self.get_temp_dir(), "func_subclass.keras") + model.save(temp_filepath) + + # Note: this recreates the model by calling FuncSubclass.__init__ + # and does *not* test the functional.function_from_config method. 
+ + loaded_model = saving.load_model( + temp_filepath, custom_objects={"FuncSubclass": FuncSubclass} + ) + + output2 = loaded_model(data) + self.assertAllClose(output1, output2) + + def test_functional_subclass_cfg_serialization(self): + class FuncSubclass(Model): + def __init__(self, name=None, **kwargs): + inputs = Input((4,), name="input") + y = layers.Dense(8)(inputs) + outputs = layers.Dense(4)(y) + # during deserilization, **kwargs already + # has "inputs" and "outputs" + super().__init__(inputs, outputs, name=name, **kwargs) + + model = FuncSubclass() + data = ops.ones((8, 4)) + model(data) # build the model + + # Note: this recreates the model by calling FuncSubclass.__init__ + # and does *not* test the functional.function_from_config method. + + config = model.get_config() + loaded_model = FuncSubclass.from_config(config) + + # check the model works (weights were not saved + # so outputs cannot be compared) + loaded_model(data) + + def test_functional_in_functional_with_reuse_serialization(self): + ini = initializers.Ones() + + # sub-functional + inputs = Input((8,)) + y = layers.Dense(6, kernel_initializer=ini)(inputs) + outputs = layers.Dense(8, kernel_initializer=ini)(y) + sub_model1 = Model(inputs, outputs, name="f1") + + inputs = Input((8,)) + y1 = sub_model1(inputs) + outputs = sub_model1(y1) # reuse + + model = Model(inputs, outputs) + data = ops.ones((2, 8)) + output1 = model(data) # build the model + + # this recreates the model from the saved functional graph + # and *does* test the functional.function_from_config method. + config = model.get_config() + loaded_model = Model.from_config(config) + + # check the model works + output2 = loaded_model(data) + # check both models return the same results + # (weights were initialized deterministically) + self.assertAllClose(output1, output2) + + def test_functional_in_functional_with_reuse_saving(self): + # sub-functional + inputs = Input((8,)) + y = layers.Dense(6)(inputs) + outputs = layers.Dense(8)(y) + sub_model1 = Model(inputs, outputs, name="f1") + + inputs = Input((8,)) + y1 = sub_model1(inputs) + outputs = sub_model1(y1) # reuse + + model = Model(inputs, outputs) + data = ops.ones((2, 8)) + output1 = model(data) # build the model + + # this recreates the model from the saved functional graph + # and *does* test the functional.function_from_config method. 
+ + temp_filepath = os.path.join( + self.get_temp_dir(), "func_subclass_reuse.keras" + ) + model.save(temp_filepath) + loaded_model = saving.load_model(temp_filepath) + + # check the model works + output2 = loaded_model(data) + # check both models return the same results + # (weights were initialized deterministically) + self.assertAllClose(output1, output2) diff --git a/keras/src/models/model.py b/keras/src/models/model.py index d80a9ecada5b..0d6e2d2dc5f8 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -634,21 +634,17 @@ def from_config(cls, config, custom_objects=None): is_functional_config = all( key in config for key in functional_config_keys ) - argspec = inspect.getfullargspec(cls.__init__) - functional_init_args = inspect.getfullargspec(Functional.__init__).args[ - 1: - ] - revivable_as_functional = ( - cls in {Functional, Model} - or argspec.args[1:] == functional_init_args - or (argspec.varargs == "args" and argspec.varkw == "kwargs") - ) + revivable_as_functional = cls in { + Functional, + Model, + } or functional_like_constructor(cls) if is_functional_config and revivable_as_functional: # Revive Functional model - # (but not Functional subclasses with a custom __init__) - from keras.src.models.functional import functional_from_config + # (but not Functional subclasses with + # a custom non-functional __init__) + from keras.src.models.functional import function_from_config - return functional_from_config( + return function_from_config( cls, config, custom_objects=custom_objects ) @@ -880,13 +876,40 @@ def model_from_json(json_string, custom_objects=None): def functional_init_arguments(args, kwargs): + # This test is permissive. Any argument combination that + # could be a Functional init is allowed. This test will be + # followed by an actual call of the Functional constructor + # so the worst case is that args are not what they should + # be and the constructor fails with an explicit error message. return ( - (len(args) == 2) + (len(args) >= 2) or (len(args) == 1 and "outputs" in kwargs) or ("inputs" in kwargs and "outputs" in kwargs) ) + +def functional_like_constructor(cls): + # This test is permissive. Any constructor that could be passed + # inputs and outputs is accepted. This test triggers Functional + # deserialization when we know we have a functional config so + # it's OK to try anything that could work. + init_args = inspect.signature(cls.__init__).parameters + funct_init_args = ("inputs" in init_args and "outputs" in init_args) or ( + "args" in init_args or "kwargs" in init_args + ) + return funct_init_args + + +def strict_functional_like_constructor(cls): + # This test is conservative. Only explicit "inputs" and "outputs" + # arguments with those names are accepted. This test triggers Functional + # serialization and we want to do that in a subclass only when an explicit + # functional __init__(inputs, outputs) constructor exists in the subclass.
+ init_args = inspect.signature(cls.__init__).parameters + funct_init_args = "inputs" in init_args and "outputs" in init_args + return funct_init_args + + def inject_functional_model_class(cls): """Inject `Functional` into the hierarchy of this class if needed.""" from keras.src.models import functional diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index 6ed7d3c6543e..3288bb8dbea9 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -9,6 +9,7 @@ from keras.src import backend from keras.src import layers from keras.src import losses +from keras.src import ops from keras.src import testing from keras.src import tree from keras.src.layers.core.input_layer import Input @@ -164,6 +165,11 @@ def _get_variable_value_by_path(variables, path): raise ValueError(f"No variable was find with path = {path}") +def has_functional_config_keys(config): + functional_config_keys = ["name", "layers", "input_layers", "output_layers"] + return all(key in config for key in functional_config_keys) + + @pytest.mark.requires_trainable_backend class ModelTest(testing.TestCase): def test_functional_rerouting(self): @@ -216,8 +222,8 @@ def call(self, x): def test_reviving_functional_from_config_custom_model(self): class CustomModel(Model): - def __init__(self, *args, param=1, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, inputs, outputs, *args, param=1, **kwargs): + super().__init__(inputs, outputs, *args, **kwargs) self.param = param def get_config(self): @@ -232,6 +238,244 @@ def get_config(self): new_model = CustomModel.from_config(model.get_config()) self.assertEqual(new_model.param, 3) + def test_reviving_functional_from_config_custom_model0(self): + # Functional custom model, true Functional-like constructor + # CustomModel.__init__ called with true functional-like args + # super().__init__ called with true functional-like args + class CustomModel0(Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + inputs = layers.Input((3,)) + outputs = layers.Dense(5)(inputs) + model = CustomModel0(inputs=inputs, outputs=outputs) + + self.assertTrue(isinstance(model, CustomModel0)) + self.assertTrue(isinstance(model, Functional)) + # No way to detect that this can be serialized functionnally + # since the graph could have been created inside the custom + # __init__ with the same __init__ args. + config = model.get_config() + self.assertFalse(has_functional_config_keys(config)) + with self.assertRaisesRegex(TypeError, "Unable to revive model"): + CustomModel0.from_config(config) + + # Same thing when inputs and outputs are initially + # passed as args rather than kwargs. 
+ model = CustomModel0(inputs, outputs) + config = model.get_config() + self.assertFalse(has_functional_config_keys(config)) + with self.assertRaisesRegex(TypeError, "Unable to revive model"): + CustomModel0.from_config(config) + + def test_reviving_functional_from_config_custom_model1(self): + # Functional custom model, true Functional-like constructor + # CustomModel.__init__ called with true functional-like args + # super().__init__ called with true functional-like args + class CustomModel1(Model): + def __init__(self, *args, param=1, **kwargs): + super().__init__(*args, **kwargs) + self.param = param + + def get_config(self): + base_config = super().get_config() + config = {"param": self.param} + return base_config | config + + inputs = layers.Input((3,)) + outputs = layers.Dense(5)(inputs) + model = CustomModel1(inputs=inputs, outputs=outputs, param=3) + + self.assertEqual(model.param, 3) + self.assertTrue(isinstance(model, CustomModel1)) + self.assertTrue(isinstance(model, Functional)) + # No way to detect that this can be serialized functionnally + # since the graph could have been created inside the custom + # __init__ with the same __init__ args. + config = model.get_config() + self.assertFalse(has_functional_config_keys(config)) + with self.assertRaisesRegex(TypeError, "Unable to revive model"): + CustomModel1.from_config(config) + + def test_reviving_functional_from_config_custom_model2(self): + # Functional custom model, true Functional-like constructor + # CustomModel.__init__ called with true functional-like args + # super().__init__ called with true functional-like args + # Explicit 'inputs' and 'outputs' args. + class CustomModel2(Model): + def __init__(self, inputs, outputs, *args, param=1, **kwargs): + super().__init__(inputs, outputs, *args, **kwargs) + self.param = param + + def get_config(self): + base_config = super().get_config() + config = {"param": self.param} + return base_config | config + + inputs = layers.Input((3,)) + outputs = layers.Dense(5)(inputs) + model = CustomModel2(inputs=inputs, outputs=outputs, param=3) + + self.assertEqual(model.param, 3) + self.assertTrue(isinstance(model, CustomModel2)) + self.assertTrue(isinstance(model, Functional)) + # Here, the model has an explicit Functional-like constructor + # with inputs and outputs arguments: we detect that it can be + # serialized functionally + config = model.get_config() + self.assertTrue(has_functional_config_keys(config)) + new_model = CustomModel2.from_config(config) + self.assertEqual(new_model.param, 3) + self.assertTrue(isinstance(new_model.layers[0], layers.InputLayer)) + self.assertTrue(isinstance(new_model.layers[1], layers.Dense)) + + # However, if the inputs and outputs parameters in the constructr are + # not called that, there is no way to detect the functional constructor + # so we conservatively decline to serialize functionally. 
+ class CustomModel2B(CustomModel2): + def __init__(self, a, b, *args, param=1, **kwargs): + super().__init__(a, b, *args, param=param, **kwargs) + self.param = param + + model = CustomModel2B(inputs, outputs, param=3) + self.assertEqual(model.param, 3) + self.assertTrue(isinstance(model, CustomModel2B)) + self.assertTrue(isinstance(model, Functional)) + config = model.get_config() + self.assertFalse(has_functional_config_keys(config)) + with self.assertRaisesRegex(TypeError, "Unable to revive model"): + CustomModel2B.from_config(config) + + def test_reviving_functional_from_config_custom_model3(self): + # Functional custom model, false Functional-like constructor + # CustomModel.__init__ called with false functional-like args + # super().__init__ called with true functional-like args + class CustomModel3(Model): + def __init__(self, **kwargs): + inputs = layers.Input((5,)) + outputs = layers.Dense(12)(inputs) + super().__init__(inputs, outputs, **kwargs) + + model = CustomModel3() + self.assertTrue(isinstance(model, CustomModel3)) + self.assertTrue(isinstance(model, Functional)) + config = model.get_config() + self.assertFalse(has_functional_config_keys(config)) + # Model can still be revived by calling its constructor + new_model = CustomModel3.from_config(config) + self.assertTrue(isinstance(new_model.layers[0], layers.InputLayer)) + self.assertTrue(isinstance(new_model.layers[1], layers.Dense)) + + def test_reviving_functional_from_config_custom_model4(self): + # Functional custom model, false Functional-like constructor + # CustomModel.__init__ called with false functional-like args + # super().__init__ called with true functional-like args + class CustomModel4(Model): + def __init__(self, a, b, *args, **kwargs): + inputs = layers.Input((a,)) + outputs = layers.Dense(b)(inputs) + (self.a, self.b) = (a, b) + super().__init__(inputs, outputs, *args, **kwargs) + + def get_config(self): + base_config = super().get_config() + config = {"a": self.a, "b": self.b} + return base_config | config + + model = CustomModel4(5, 12) + self.assertEqual(model.a, 5) + self.assertEqual(model.b, 12) + self.assertTrue(isinstance(model, CustomModel4)) + self.assertTrue(isinstance(model, Functional)) + config = model.get_config() + self.assertFalse(has_functional_config_keys(config)) + # Model can still be revived by calling its constructor + new_model = CustomModel4.from_config(config) + self.assertEqual(new_model.a, 5) + self.assertEqual(new_model.b, 12) + + def test_reviving_functional_from_config_custom_model5(self): + # non-functional custom model + # CustomModel.__init__ called with false functional-like args + # super().__init__ called with non-functional-like args + class CustomModel5(Model): + def __init__(self, a, b, *args, **kwargs): + super().__init__(*args, **kwargs) + self.w = self.add_weight(shape=(a, b)) + (self.a, self.b) = (a, b) + + def call(self, inputs): + result = ops.matmul(inputs, self.w) + return result + + def get_config(self): + base_config = super().get_config() + config = {"a": self.a, "b": self.b} + return base_config | config + + model = CustomModel5(5, 12) + self.assertEqual(model.a, 5) + self.assertEqual(model.b, 12) + self.assertTrue(isinstance(model, CustomModel5)) + self.assertFalse(isinstance(model, Functional)) + config = model.get_config() + self.assertFalse(has_functional_config_keys(config)) + # Model can still be revived by calling its constructor + new_model = CustomModel5.from_config(config) + self.assertEqual(new_model.a, 5) + self.assertEqual(new_model.b, 12) + + 
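For context, the `custom_model*` tests above and below pin down when a `Model` subclass can be revived from its config: if the subclass `__init__` exposes explicit `inputs` and `outputs` parameters, `get_config()` keeps the functional graph and `from_config()` rebuilds it through `function_from_config`; otherwise Keras falls back to calling the subclass constructor with the saved keyword arguments. A minimal sketch of the two cases (the class names here are illustrative, not taken from this diff):

```python
from keras import Input, Model, layers

class ExplicitIOModel(Model):
    # Explicit `inputs`/`outputs` parameters: the config keeps the functional
    # graph, and `from_config()` rebuilds the model layer by layer.
    def __init__(self, inputs, outputs, **kwargs):
        super().__init__(inputs, outputs, **kwargs)

class SelfBuildingModel(Model):
    # The graph is built inside `__init__`: `from_config()` simply calls this
    # constructor again with the saved keyword arguments.
    def __init__(self, units=12, **kwargs):
        inputs = Input((5,))
        outputs = layers.Dense(units)(inputs)
        super().__init__(inputs, outputs, **kwargs)
        self.units = units

    def get_config(self):
        return {**super().get_config(), "units": self.units}

inputs = Input((5,))
m1 = ExplicitIOModel(inputs, layers.Dense(3)(inputs))
m1_restored = ExplicitIOModel.from_config(m1.get_config())  # functional revival

m2 = SelfBuildingModel(units=7)
m2_restored = SelfBuildingModel.from_config(m2.get_config())  # constructor revival
```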
def test_reviving_functional_from_config_custom_model6(self): + # non-functional custom model + # CustomModel.__init__ called with false functional-like args + # super().__init__ called with non-functional-like args + class CustomModel6(Model): + def __init__(self, a, b): + super().__init__() # will fail when "name" is passed + # during deserialization + self.w = self.add_weight(shape=(a, b)) + (self.a, self.b) = (a, b) + + def call(self, inputs): + result = ops.matmul(inputs, self.w) + return result + + def get_config(self): + base_config = super().get_config() + config = {"a": self.a, "b": self.b} + return base_config | config + + model = CustomModel6(5, 12) + self.assertEqual(model.a, 5) + self.assertEqual(model.b, 12) + self.assertTrue(isinstance(model, CustomModel6)) + self.assertFalse(isinstance(model, Functional)) + config = model.get_config() + self.assertFalse(has_functional_config_keys(config)) + with self.assertRaisesRegex( + TypeError, "unexpected keyword argument 'name'" + ): + CustomModel6.from_config(config) + + def test_reviving_functional_from_config_custom_model7(self): + # non-functional custom model + # CustomModel.__init__ called with functional-like args + # super().__init__ called with functional-like args + class CustomModel7(Model): + def __init__(self, a, b): + super().__init__(a, b) # nonsensical call that will fail + self.w = self.add_weight(shape=(a, b)) + (self.a, self.b) = (a, b) + + def call(self, inputs): + result = ops.matmul(inputs, self.w) + return result + + with self.assertRaisesRegex( + ValueError, "All `inputs` values must be KerasTensors" + ): + CustomModel7(5, 12) + @parameterized.named_parameters( ("single_output_1", _get_model_single_output), ("single_output_2", _get_model_single_output), diff --git a/keras/src/ops/function.py b/keras/src/ops/function.py index 18088cd3f5d9..8cd37c899f46 100644 --- a/keras/src/ops/function.py +++ b/keras/src/ops/function.py @@ -188,7 +188,12 @@ def _run_through_graph(self, inputs, operation_fn, call_fn=None): def _assert_input_compatibility(self, inputs): try: - tree.assert_same_structure(inputs, self._inputs_struct) + # We first normalize to tuples before performing the check to + # suppress warnings when encountering mismatched tuples and lists. + tree.assert_same_structure( + tree.lists_to_tuples(inputs), + tree.lists_to_tuples(self._inputs_struct), + ) except ValueError: raise ValueError( "Function was called with an invalid input structure. " diff --git a/keras/src/utils/model_visualization.py b/keras/src/utils/model_visualization.py index 1fd539961ba6..92498b7811bc 100644 --- a/keras/src/utils/model_visualization.py +++ b/keras/src/utils/model_visualization.py @@ -190,7 +190,13 @@ def make_node(layer, **kwargs): def remove_unused_edges(dot): - nodes = [v.get_name() for v in dot.get_nodes()] + nodes = [] + for sub in dot.get_subgraph_list(): + for node in sub.get_nodes(): + nodes.append(node.get_name()) + for node in dot.get_nodes(): + nodes.append(node.get_name()) + # nodes = [v.get_name() for v in dot.get_nodes()] for edge in dot.get_edges(): if edge.get_destination() not in nodes: dot.del_edge(edge.get_source(), edge.get_destination()) @@ -243,9 +249,32 @@ def model_to_dot( "the model on a batch of data." ) + from keras.src.layers.core.composite_layer import CompositeLayer from keras.src.models import functional from keras.src.models import sequential + # temporary workarounds until CompositeLayer is becomes + # the base class of all models and layers that have a + # "Keras functional" bahavior. 
+ def is_functional(layer): + return ( + isinstance(layer, functional.Functional) + or isinstance(layer, sequential.Sequential) + or isinstance(layer, CompositeLayer) + ) + + def keras_function(layer): + if isinstance(layer, functional.Functional): + return layer + elif isinstance(layer, sequential.Sequential): + return layer + elif isinstance(layer, CompositeLayer): + return layer._function + else: + raise ValueError( + "Layer is not a Keras Functional model or CompositeLayer." + ) + # from keras.src.layers import Wrapper if not check_pydot(): @@ -281,20 +310,22 @@ def model_to_dot( if isinstance(model, sequential.Sequential): layers = model.layers - elif not isinstance(model, functional.Functional): + elif not is_functional(model): # We treat subclassed models as a single node. node = make_node(model, **kwargs) dot.add_node(node) return dot else: - layers = model._operations + # functional case + function = keras_function(model) + layers = function._operations # Create graph nodes. sub_n_first_node = {} sub_n_last_node = {} for i, layer in enumerate(layers): # Process nested functional models. - if expand_nested and isinstance(layer, functional.Functional): + if expand_nested and is_functional(layer): submodel = model_to_dot( layer, show_shapes, @@ -330,7 +361,7 @@ def model_to_dot( layer_id = str(id(layer)) for i, node in enumerate(layer._inbound_nodes): node_key = make_node_key(layer, i) - if node_key in model._nodes: + if node_key in function._nodes: for parent_node in node.parent_nodes: inbound_layer = parent_node.operation inbound_layer_id = str(id(inbound_layer)) @@ -340,25 +371,25 @@ def model_to_dot( add_edge(dot, inbound_layer_id, layer_id) else: # if inbound_layer is not Functional - if not isinstance(inbound_layer, functional.Functional): + if not is_functional(inbound_layer): # if current layer is not Functional - if not isinstance(layer, functional.Functional): + if not is_functional(layer): assert dot.get_node(inbound_layer_id) assert dot.get_node(layer_id) add_edge(dot, inbound_layer_id, layer_id) # if current layer is Functional - elif isinstance(layer, functional.Functional): + elif is_functional(layer): add_edge( dot, inbound_layer_id, sub_n_first_node[layer.name].get_name(), ) # if inbound_layer is Functional - elif isinstance(inbound_layer, functional.Functional): + elif is_functional(inbound_layer): name = sub_n_last_node[ inbound_layer.name ].get_name() - if isinstance(layer, functional.Functional): + if is_functional(layer): output_name = sub_n_first_node[ layer.name ].get_name()
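Taken together, the cloning changes in this diff let `clone_model` rewrite layers inside nested functional models (and `CompositeLayer`s) through `call_function` with `recursive=True`. A minimal sketch of that editing flow, assuming only the public API exercised by the tests in this diff (the `*_wide` replacement layer is invented for illustration):

```python
import numpy as np
from keras import Input, Model, layers
from keras.models import clone_model

# A small functional model nested inside another one.
inner_in = Input((4,))
inner_out = layers.Dense(8, name="inner_dense")(inner_in)
inner = Model(inner_in, inner_out, name="inner")

outer_in = Input((4,))
model = Model(outer_in, inner(outer_in))

def call_fn(layer, *args, **kwargs):
    # Swap every Dense for a wider one; pass every other layer through.
    if isinstance(layer, layers.Dense):
        wide = layers.Dense(layer.units * 2, name=f"{layer.name}_wide")
        return wide(*args, **kwargs)
    return layer(*args, **kwargs)

# recursive=True also rewrites layers inside the nested `inner` model;
# clone_function is the identity, so the graph edit happens in call_fn only.
edited = clone_model(
    model,
    clone_function=lambda layer: layer,
    call_function=call_fn,
    recursive=True,
)
edited(np.ones((2, 4)))  # the rebuilt graph is callable like the original
```

The same call also works when the nested piece is a `CompositeLayer`, since `_is_functional` treats `Functional` and `CompositeLayer` alike during recursion.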