From ee6287b69979f2a303ad48f1191001953a35da66 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Sat, 13 Sep 2025 13:58:35 +0400 Subject: [PATCH 1/5] init: Add initial version --- keras_hub/src/models/mobilenetv5/__init__.py | 0 .../mobilenetv5/mobilenetv5_attention.py | 576 ++++++++++++++ .../mobilenetv5/mobilenetv5_backbone.py | 237 ++++++ .../mobilenetv5/mobilenetv5_backbone_test.py | 46 ++ .../models/mobilenetv5/mobilenetv5_blocks.py | 716 ++++++++++++++++++ .../models/mobilenetv5/mobilenetv5_builder.py | 335 ++++++++ .../mobilenetv5_image_classifier.py | 139 ++++ ...bilenetv5_image_classifier_preprocessor.py | 16 + .../mobilenetv5_image_classifier_test.py | 66 ++ .../mobilenetv5_image_converter.py | 10 + .../models/mobilenetv5/mobilenetv5_layers.py | 453 +++++++++++ .../models/mobilenetv5/mobilenetv5_utils.py | 146 ++++ .../convert_mobilenetv5_checkpoints.py | 469 ++++++++++++ 13 files changed, 3209 insertions(+) create mode 100644 keras_hub/src/models/mobilenetv5/__init__.py create mode 100644 keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py create mode 100644 keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py create mode 100644 keras_hub/src/models/mobilenetv5/mobilenetv5_backbone_test.py create mode 100644 keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py create mode 100644 keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py create mode 100644 keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py create mode 100644 keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py create mode 100644 keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_test.py create mode 100644 keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py create mode 100644 keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py create mode 100644 keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py create mode 100644 tools/checkpoint_conversion/convert_mobilenetv5_checkpoints.py diff --git a/keras_hub/src/models/mobilenetv5/__init__.py b/keras_hub/src/models/mobilenetv5/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py new file mode 100644 index 0000000000..f46a188ecb --- /dev/null +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py @@ -0,0 +1,576 @@ +import keras + +from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import DropPath +from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import LayerScale2d +from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import RmsNorm2d + + +class MultiQueryAttention2d(keras.layers.Layer): + """Implements 2D Multi-Query Attention. + + This layer performs attention on 2D spatial inputs. It uses a multi-query + attention mechanism where multiple query heads attend to a single key and + value. + + Args: + dim: int. The input and output channel dimension. + dim_out: int. The output channel dimension. If `None`, it is set to + `dim`. + num_heads: int. The number of attention heads. + key_dim: int. The dimension of the key. If `None`, it is calculated as + `dim // num_heads`. + value_dim: int. The dimension of the value. If `None`, it is calculated + as `dim // num_heads`. + query_strides: int or tuple. The stride for downsampling the query. + kv_stride: int. The stride for downsampling the key and value. + dw_kernel_size: int. The kernel size for the depthwise convolution used + for downsampling. + dilation: int. 
The dilation rate for the depthwise convolution. + padding: str. The padding type for convolutions. + attn_drop: float. The dropout rate for the attention weights. + proj_drop: float. The dropout rate for the output projection. + norm_layer: keras.layers.Layer. The normalization layer to use. + use_bias: bool. If `True`, bias terms are used in convolutions. + channel_axis: int. The axis representing the channels in the input + tensor. + data_format: str. The format of the input data, either + `"channels_last"` or `"channels_first"`. + """ + + def __init__( + self, + dim, + dim_out=None, + num_heads=8, + key_dim=None, + value_dim=None, + query_strides=1, + kv_stride=1, + dw_kernel_size=3, + dilation=1, + padding="", + attn_drop=0.0, + proj_drop=0.0, + norm_layer=keras.layers.BatchNormalization, + use_bias=False, + channel_axis=None, + data_format=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.data_format = data_format + self.channel_axis = channel_axis + dim_out = dim_out or dim + self.num_heads = num_heads + self.key_dim = key_dim or dim // num_heads + self.value_dim = value_dim or dim // num_heads + self.query_strides = ( + query_strides + if isinstance(query_strides, (list, tuple)) + else (query_strides, query_strides) + ) + self.kv_stride = kv_stride + self.has_query_strides = any([s > 1 for s in self.query_strides]) + self.scale = self.key_dim**-0.5 + self.keras_padding = "same" if padding == "" else "valid" + self.conv_kernel_initializer = keras.initializers.VarianceScaling( + scale=2.0, mode="fan_out", distribution="untruncated_normal" + ) + self.bias_initializer = "zeros" + query_layers = [] + if self.has_query_strides: + pool_padding = "valid" if self.keras_padding == "valid" else "same" + query_layers.append( + keras.layers.AveragePooling2D( + pool_size=self.query_strides, + strides=self.query_strides, + padding=pool_padding, + data_format=self.data_format, + name="query_down_pool", + dtype=self.dtype_policy, + ) + ) + query_layers.append( + norm_layer( + axis=self.channel_axis, + name="query_norm", + gamma_initializer="ones", + beta_initializer="zeros", + dtype=self.dtype_policy, + ) + ) + query_layers.append( + keras.layers.Conv2D( + filters=self.num_heads * self.key_dim, + kernel_size=1, + use_bias=use_bias, + data_format=self.data_format, + name="query_proj", + kernel_initializer=self.conv_kernel_initializer, + bias_initializer=self.bias_initializer, + dtype=self.dtype_policy, + ) + ) + self.query_layers = query_layers + key_layers = [] + if kv_stride > 1: + if self.keras_padding == "same": + key_layers.append( + keras.layers.ZeroPadding2D( + padding=dw_kernel_size // 2, + data_format=self.data_format, + name="key_down_pad", + dtype=self.dtype_policy, + ) + ) + key_layers.append( + keras.layers.DepthwiseConv2D( + kernel_size=dw_kernel_size, + strides=kv_stride, + dilation_rate=dilation, + padding="valid", + data_format=self.data_format, + name="key_down_conv", + depthwise_initializer=self.conv_kernel_initializer, + bias_initializer=self.bias_initializer, + use_bias=False, + dtype=self.dtype_policy, + ) + ) + key_layers.append( + norm_layer( + axis=self.channel_axis, + gamma_initializer="ones", + beta_initializer="zeros", + name="key_norm", + dtype=self.dtype_policy, + ) + ) + key_layers.append( + keras.layers.Conv2D( + filters=self.key_dim, + kernel_size=1, + padding="valid", + use_bias=use_bias, + data_format=self.data_format, + name="key_proj", + kernel_initializer=self.conv_kernel_initializer, + bias_initializer=self.bias_initializer, 
+ dtype=self.dtype_policy, + ) + ) + self.key_layers = key_layers + value_layers = [] + if kv_stride > 1: + if self.keras_padding == "same": + value_layers.append( + keras.layers.ZeroPadding2D( + padding=dw_kernel_size // 2, + data_format=self.data_format, + name="value_down_pad", + dtype=self.dtype_policy, + ) + ) + value_layers.append( + keras.layers.DepthwiseConv2D( + kernel_size=dw_kernel_size, + strides=kv_stride, + dilation_rate=dilation, + padding="valid", + data_format=self.data_format, + name="value_down_conv", + depthwise_initializer=self.conv_kernel_initializer, + bias_initializer=self.bias_initializer, + use_bias=False, + dtype=self.dtype_policy, + ) + ) + value_layers.append( + norm_layer( + axis=self.channel_axis, + gamma_initializer="ones", + beta_initializer="zeros", + name="value_norm", + dtype=self.dtype_policy, + ) + ) + value_layers.append( + keras.layers.Conv2D( + filters=self.value_dim, + kernel_size=1, + padding="valid", + use_bias=use_bias, + data_format=self.data_format, + name="value_proj", + kernel_initializer=self.conv_kernel_initializer, + bias_initializer=self.bias_initializer, + dtype=self.dtype_policy, + ) + ) + self.value_layers = value_layers + self.attn_drop = keras.layers.Dropout( + attn_drop, dtype=self.dtype_policy + ) + output_layers = [] + if self.has_query_strides: + output_layers.append( + keras.layers.UpSampling2D( + size=self.query_strides, + interpolation="bilinear", + data_format=self.data_format, + name="output_upsample", + dtype=self.dtype_policy, + ) + ) + output_layers.append( + keras.layers.Conv2D( + filters=dim_out, + kernel_size=1, + use_bias=use_bias, + data_format=self.data_format, + name="output_proj", + kernel_initializer=self.conv_kernel_initializer, + bias_initializer=self.bias_initializer, + dtype=self.dtype_policy, + ) + ) + output_layers.append( + keras.layers.Dropout(proj_drop, dtype=self.dtype_policy) + ) + self.output_proj_layers = output_layers + + def call(self, x, training=False): + B = keras.ops.shape(x)[0] + q = x + for layer in self.query_layers: + try: + q = layer(q, training=training) + except TypeError: + q = layer(q) + k = x + for layer in self.key_layers: + try: + k = layer(k, training=training) + except TypeError: + k = layer(k) + v = x + for layer in self.value_layers: + try: + v = layer(v, training=training) + except TypeError: + v = layer(v) + if self.data_format == "channels_last": + q = keras.ops.transpose(q, (0, 3, 1, 2)) + k = keras.ops.transpose(k, (0, 3, 1, 2)) + v = keras.ops.transpose(v, (0, 3, 1, 2)) + s_q = keras.ops.shape(q) + h_q, w_q = s_q[2], s_q[3] + q = keras.ops.reshape(q, (B, self.num_heads, self.key_dim, -1)) + q = keras.ops.transpose(q, (0, 1, 3, 2)) + k = keras.ops.reshape(k, (B, self.key_dim, -1)) + k = keras.ops.transpose(k, (0, 2, 1)) + k = keras.ops.expand_dims(k, axis=1) + v = keras.ops.reshape(v, (B, self.value_dim, -1)) + v = keras.ops.transpose(v, (0, 2, 1)) + v = keras.ops.expand_dims(v, axis=1) + q = q * self.scale + attn = keras.ops.matmul(q, keras.ops.transpose(k, (0, 1, 3, 2))) + attn = keras.ops.softmax(attn, axis=-1) + attn = self.attn_drop(attn, training=training) + o = keras.ops.matmul(attn, v) + o = keras.ops.transpose(o, (0, 2, 1, 3)) + feat_dim = self.num_heads * self.value_dim + o = keras.ops.reshape(o, (B, h_q, w_q, feat_dim)) + if self.data_format == "channels_first": + o = keras.ops.transpose(o, (0, 3, 1, 2)) + x_out = o + for layer in self.output_proj_layers: + try: + x_out = layer(x_out, training=training) + except TypeError: + x_out = layer(x_out) + return x_out + 
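A minimal shape-check sketch for the multi-query attention block above. The sizes and argument values here are illustrative (not taken from any preset), and `data_format`/`channel_axis` are passed explicitly, as the backbone builder does:

```python
import keras

from keras_hub.src.models.mobilenetv5.mobilenetv5_attention import (
    MultiQueryAttention2d,
)

# Illustrative sizes: 64 input channels, 4 heads of width 16.
attn = MultiQueryAttention2d(
    dim=64,
    num_heads=4,
    key_dim=16,
    value_dim=16,
    query_strides=2,  # queries are average-pooled 2x before attention
    kv_stride=2,  # keys/values are downsampled via a strided depthwise conv
    data_format="channels_last",
    channel_axis=-1,
)
x = keras.ops.ones((1, 32, 32, 64))
y = attn(x)
# Attention runs on the 16x16 downsampled grid, and the result is
# upsampled back to the input resolution: (1, 32, 32, 64).
print(y.shape)
```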
+ +class Attention2d(keras.layers.Layer): + """Implements 2D Multi-Head Attention. + + This layer performs multi-head self-attention on 2D spatial inputs. + + Args: + dim: int. The input and output channel dimension. + dim_out: int. The output channel dimension. If `None`, it is set to + `dim`. + num_heads: int. The number of attention heads. + bias: bool. If `True`, bias terms are used in the qkv and projection + convolutions. + attn_drop: float. The dropout rate for the attention weights. + proj_drop: float. The dropout rate for the output projection. + channel_axis: int. The axis representing the channels in the input + tensor. + data_format: str. The format of the input data, either + `"channels_last"` or `"channels_first"`. + """ + + def __init__( + self, + dim, + dim_out=None, + num_heads=32, + bias=True, + attn_drop=0.0, + proj_drop=0.0, + channel_axis=None, + data_format=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.data_format = data_format + self.channel_axis = channel_axis + dim_out = dim_out or dim + self.dim = dim + self.dim_out = dim_out + self.num_heads = num_heads + self.bias = bias + self.head_dim = dim // num_heads + self.conv_kernel_initializer = keras.initializers.VarianceScaling( + scale=2.0, mode="fan_out", distribution="untruncated_normal" + ) + self.bias_initializer = "zeros" + self.qkv = keras.layers.Conv2D( + dim * 3, + kernel_size=1, + use_bias=bias, + data_format=self.data_format, + name="qkv", + dtype=self.dtype_policy, + kernel_initializer=self.conv_kernel_initializer, + bias_initializer=self.bias_initializer, + ) + self.attn_drop = keras.layers.Dropout( + attn_drop, dtype=self.dtype_policy + ) + self.proj = keras.layers.Conv2D( + dim_out, + kernel_size=1, + use_bias=bias, + data_format=self.data_format, + name="proj", + dtype=self.dtype_policy, + kernel_initializer=self.conv_kernel_initializer, + bias_initializer=self.bias_initializer, + ) + self.proj_drop = keras.layers.Dropout( + proj_drop, dtype=self.dtype_policy + ) + + def call(self, x, attn_mask=None, training=False): + if self.data_format == "channels_first": + B, C, H, W = keras.ops.shape(x) + else: + B, H, W, C = keras.ops.shape(x) + qkv = self.qkv(x) + if self.data_format == "channels_last": + qkv = keras.ops.transpose(qkv, (0, 3, 1, 2)) + q, k, v = keras.ops.unstack( + keras.ops.reshape( + qkv, + (B, 3, self.num_heads, self.head_dim, H * W), + ), + axis=1, + ) + q = keras.ops.transpose(q, (0, 1, 3, 2)) + k = keras.ops.transpose(k, (0, 1, 2, 3)) + v = keras.ops.transpose(v, (0, 1, 3, 2)) + attn = keras.ops.matmul(q, k) * (self.head_dim**-0.5) + if attn_mask is not None: + attn = attn + attn_mask + attn = keras.ops.softmax(attn, axis=-1) + attn = self.attn_drop(attn, training=training) + x = keras.ops.matmul(attn, v) + x = keras.ops.transpose(x, (0, 1, 3, 2)) + if self.data_format == "channels_first": + x = keras.ops.reshape(x, (B, -1, H, W)) + else: + x = keras.ops.reshape(x, (B, H, W, -1)) + x = self.proj(x) + x = self.proj_drop(x, training=training) + return x + + +class MobileAttention(keras.layers.Layer): + """MobileNetV5 attention block. + + This block combines attention with depthwise convolutions for efficiency. + It can use either standard Multi-Head Attention or Multi-Query Attention. + + Args: + in_chs: int. The number of input channels. + out_chs: int. The number of output channels. + stride: int. The stride for the block. + dw_kernel_size: int. The kernel size for the depthwise convolution in + Multi-Query Attention. + dilation: int. 
The dilation rate for convolutions. + pad_type: str. The padding type for convolutions. + num_heads: int. The number of attention heads. + key_dim: int. The dimension of the key. + value_dim: int. The dimension of the value. + use_multi_query: bool. If `True`, use `MultiQueryAttention2d`, + otherwise use `Attention2d`. + query_strides: tuple. The strides for the query downsampling. + kv_stride: int. The stride for key/value downsampling. + cpe_dw_kernel_size: int. The kernel size for the conditional position + encoding depthwise convolution. + noskip: bool. If `True`, the skip connection is disabled. + norm_layer: str. The normalization layer to use (`"batch_norm"` or + `"rms_norm"`). + drop_path_rate: float. The stochastic depth rate. + attn_drop: float. The dropout rate for the attention weights. + proj_drop: float. The dropout rate for the output projection. + layer_scale_init_value: float. The initial value for layer scale. If + `None`, layer scale is not used. + use_bias: bool. If `True`, bias terms are used in convolutions. + use_cpe: bool. If `True`, a conditional position encoding is added. + channel_axis: int. The axis representing the channels in the input + tensor. + data_format: str. The format of the input data, either + `"channels_last"` or `"channels_first"`. + """ + + def __init__( + self, + in_chs, + out_chs, + stride=1, + dw_kernel_size=3, + dilation=1, + pad_type="", + num_heads=8, + key_dim=64, + value_dim=64, + use_multi_query=False, + query_strides=(1, 1), + kv_stride=1, + cpe_dw_kernel_size=3, + noskip=False, + norm_layer="batch_norm", + drop_path_rate=0.0, + attn_drop=0.0, + proj_drop=0.0, + layer_scale_init_value=1e-5, + use_bias=False, + use_cpe=False, + channel_axis=None, + data_format=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.data_format = data_format + self.channel_axis = channel_axis + self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip + self.conv_kernel_initializer = keras.initializers.VarianceScaling( + scale=2.0, mode="fan_out", distribution="untruncated_normal" + ) + self.bias_initializer = "zeros" + if use_cpe: + self.conv_cpe_dw = keras.layers.DepthwiseConv2D( + kernel_size=cpe_dw_kernel_size, + strides=1, + padding="same", + dilation_rate=dilation, + use_bias=True, + data_format=self.data_format, + name="conv_cpe_dw", + depthwise_initializer=self.conv_kernel_initializer, + bias_initializer=self.bias_initializer, + dtype=self.dtype_policy, + ) + else: + self.conv_cpe_dw = None + if norm_layer == "batch_norm": + self.norm = keras.layers.BatchNormalization( + axis=self.channel_axis, + name="norm", + gamma_initializer="ones", + beta_initializer="zeros", + dtype=self.dtype_policy, + ) + elif norm_layer == "rms_norm": + self.norm = RmsNorm2d( + in_chs, + data_format=self.data_format, + gamma_initializer="ones", + channel_axis=self.channel_axis, + name="norm", + dtype=self.dtype_policy, + ) + else: + raise ValueError(f"Unsupported norm_layer: {norm_layer}") + if num_heads is None: + assert in_chs % key_dim == 0 + num_heads = in_chs // key_dim + if use_multi_query: + self.attn = MultiQueryAttention2d( + dim=in_chs, + dim_out=out_chs, + num_heads=num_heads, + key_dim=key_dim, + value_dim=value_dim, + query_strides=query_strides, + kv_stride=kv_stride, + dw_kernel_size=dw_kernel_size, + dilation=dilation, + padding=pad_type, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=keras.layers.BatchNormalization, + use_bias=use_bias, + channel_axis=self.channel_axis, + data_format=self.data_format, + 
name="attn", + dtype=self.dtype_policy, + ) + else: + self.attn = Attention2d( + dim=in_chs, + dim_out=out_chs, + num_heads=num_heads, + attn_drop=attn_drop, + proj_drop=proj_drop, + bias=use_bias, + channel_axis=self.channel_axis, + data_format=self.data_format, + name="attn", + dtype=self.dtype_policy, + ) + if layer_scale_init_value is not None: + self.layer_scale = LayerScale2d( + out_chs, + layer_scale_init_value, + name="layer_scale", + channel_axis=self.channel_axis, + data_format=self.data_format, + dtype=self.dtype_policy, + ) + else: + self.layer_scale = lambda x: x + self.drop_path = ( + DropPath(drop_path_rate, dtype=self.dtype_policy) + if drop_path_rate > 0.0 + else lambda x, training: x + ) + + def call(self, x, training=False): + if self.conv_cpe_dw is not None: + x = x + self.conv_cpe_dw(x) + shortcut = x + x_normed = self.norm(x, training=training) + x_attn = self.attn(x_normed, training=training) + x_scaled = self.layer_scale(x_attn) + if self.has_skip: + return self.drop_path(x_scaled, training=training) + shortcut + else: + return x_scaled diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py new file mode 100644 index 0000000000..4c2632953d --- /dev/null +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_backbone.py @@ -0,0 +1,237 @@ +import keras +from keras.src import saving + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.mobilenet.mobilenet_backbone import SqueezeAndExcite2D +from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import ( + MobileNetV5MultiScaleFusionAdapter, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import ( + MobileNetV5Builder, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import ConvNormAct +from keras_hub.src.models.mobilenetv5.mobilenetv5_utils import ( + feature_take_indices, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_utils import round_channels +from keras_hub.src.utils.keras_utils import standardize_data_format + + +@keras_hub_export("keras_hub.models.MobileNetV5Backbone") +class MobileNetV5Backbone(Backbone): + """MobileNetV5 backbone network. + + This class represents the backbone of the MobileNetV5 architecture, which + can be used as a feature extractor for various downstream tasks. + + Args: + block_args: list. A list of lists, where each inner list contains the + arguments for the blocks in a stage. + in_chans: int. The number of input channels. + stem_size: int. The number of channels in the stem convolution. + stem_bias: bool. If `True`, a bias term is used in the stem + convolution. + fix_stem: bool. If `True`, the stem size is not rounded. + num_features: int. The number of output features, used when `use_msfa` + is `True`. + pad_type: str. The padding type for convolutions. + use_msfa: bool. If `True`, the Multi-Scale Fusion Adapter is used. + msfa_indices: tuple. The indices of the feature maps to be used by the + MSFA. + msfa_output_resolution: int. The output resolution of the MSFA. + act_layer: str. The activation function to use. + norm_layer: str. The normalization layer to use. + se_layer: keras.layers.Layer. The Squeeze-and-Excitation layer to use. + se_from_exp: bool. If `True`, SE channel reduction is based on the + expanded channels. + round_chs_fn: callable. A function to round the number of channels. + drop_path_rate: float. The stochastic depth rate. + layer_scale_init_value: float. 
The initial value for layer scale.
+        image_shape: tuple. The shape of the input image.
+        data_format: str. The data format of the image channels. Can be either
+            `"channels_first"` or `"channels_last"`. If `None` is specified,
+            it will use the `image_data_format` value found in your Keras
+            config file at `~/.keras/keras.json`. Defaults to `None`.
+        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
+            to use for the model's computations and weights. Defaults to `None`.
+
+    Example:
+    ```python
+    import keras
+    import keras_hub
+    from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import (
+        decode_arch_def
+    )
+
+    arch_def = [["er_r1_k3_s2_e4_c24"], ["uir_r2_k5_s2_e6_c48"]]
+    block_args = decode_arch_def(arch_def)
+    model = keras_hub.models.MobileNetV5Backbone(block_args=block_args)
+    # Create a dummy input.
+    input_data = keras.ops.ones((1, 224, 224, 3))
+    output = model(input_data)
+    ```
+    """
+
+    def __init__(
+        self,
+        block_args,
+        in_chans=3,
+        stem_size=16,
+        stem_bias=True,
+        fix_stem=False,
+        num_features=2048,
+        pad_type="",
+        use_msfa=True,
+        msfa_indices=(-2, -1),
+        msfa_output_resolution=16,
+        act_layer="gelu",
+        norm_layer="rms_norm",
+        se_layer=SqueezeAndExcite2D,
+        se_from_exp=True,
+        round_chs_fn=round_channels,
+        drop_path_rate=0.0,
+        layer_scale_init_value=None,
+        image_shape=(None, None, 3),
+        data_format=None,
+        dtype=None,
+        **kwargs,
+    ):
+        data_format = standardize_data_format(data_format)
+        channel_axis = -1 if data_format == "channels_last" else 1
+
+        # === Layers ===
+        if not fix_stem:
+            stem_size = round_chs_fn(stem_size)
+        conv_stem = ConvNormAct(
+            stem_size,
+            kernel_size=3,
+            stride=2,
+            pad_type=pad_type,
+            bias=stem_bias,
+            norm_layer=norm_layer,
+            act_layer=act_layer,
+            name="conv_stem",
+            data_format=data_format,
+            channel_axis=channel_axis,
+            dtype=dtype,
+        )
+        builder = MobileNetV5Builder(
+            output_stride=32,
+            pad_type=pad_type,
+            round_chs_fn=round_chs_fn,
+            se_from_exp=se_from_exp,
+            act_layer=act_layer,
+            norm_layer=norm_layer,
+            se_layer=se_layer,
+            drop_path_rate=drop_path_rate,
+            layer_scale_init_value=layer_scale_init_value,
+            data_format=data_format,
+            channel_axis=channel_axis,
+            dtype=dtype,
+        )
+        blocks = builder(stem_size, block_args)
+        feature_info = builder.features
+        msfa = None
+        if use_msfa:
+            msfa_indices_calc, _ = feature_take_indices(
+                len(feature_info), msfa_indices
+            )
+            msfa_in_chs = [
+                feature_info[mi]["num_chs"] for mi in msfa_indices_calc
+            ]
+            msfa = MobileNetV5MultiScaleFusionAdapter(
+                in_chs=msfa_in_chs,
+                out_chs=num_features,
+                output_resolution=msfa_output_resolution,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                name="msfa",
+                channel_axis=channel_axis,
+                data_format=data_format,
+                dtype=dtype,
+            )
+
+        # === Functional Model ===
+        image_input = keras.layers.Input(shape=image_shape)
+        x = conv_stem(image_input)
+        if use_msfa:
+            intermediates = []
+            feat_idx = 0
+            if feat_idx in msfa_indices_calc:
+                intermediates.append(x)
+
+            for stage in blocks:
+                for block in stage:
+                    x = block(x)
+                feat_idx += 1
+                if feat_idx in msfa_indices_calc:
+                    intermediates.append(x)
+            x = msfa(intermediates)
+        else:
+            for stage in blocks:
+                for block in stage:
+                    x = block(x)
+
+        super().__init__(inputs=image_input, outputs=x, dtype=dtype, **kwargs)
+
+        # === Config ===
+        self.block_args = block_args
+        self.in_chans = in_chans
+        self.stem_size = stem_size
+        self.stem_bias = stem_bias
+        self.fix_stem = fix_stem
+        self.num_features = num_features
+        self.pad_type = pad_type
+        self.use_msfa = use_msfa
+        self.msfa_indices = 
msfa_indices + self.msfa_output_resolution = msfa_output_resolution + self.act_layer = act_layer + self.norm_layer = norm_layer + self.se_layer = se_layer + self.se_from_exp = se_from_exp + self.round_chs_fn = round_chs_fn + self.drop_path_rate = drop_path_rate + self.layer_scale_init_value = layer_scale_init_value + self.image_shape = image_shape + self.data_format = data_format + self.channel_axis = channel_axis + + def get_config(self): + config = { + "block_args": self.block_args, + "in_chans": self.in_chans, + "stem_size": self.stem_size, + "stem_bias": self.stem_bias, + "fix_stem": self.fix_stem, + "num_features": self.num_features, + "pad_type": self.pad_type, + "use_msfa": self.use_msfa, + "msfa_indices": self.msfa_indices, + "msfa_output_resolution": self.msfa_output_resolution, + "act_layer": self.act_layer, + "norm_layer": self.norm_layer, + "se_from_exp": self.se_from_exp, + "drop_path_rate": self.drop_path_rate, + "layer_scale_init_value": self.layer_scale_init_value, + "image_shape": self.image_shape, + "data_format": self.data_format, + } + if self.round_chs_fn is not round_channels: + config["round_chs_fn"] = saving.serialize_keras_object( + self.round_chs_fn + ) + if self.se_layer is not SqueezeAndExcite2D: + config["se_layer"] = saving.serialize_keras_object(self.se_layer) + return config + + @classmethod + def from_config(cls, config): + if "round_chs_fn" in config: + config["round_chs_fn"] = saving.deserialize_keras_object( + config["round_chs_fn"] + ) + if "se_layer" in config: + config["se_layer"] = saving.deserialize_keras_object( + config["se_layer"] + ) + return cls(**config) diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_backbone_test.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_backbone_test.py new file mode 100644 index 0000000000..d8c74109e4 --- /dev/null +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_backbone_test.py @@ -0,0 +1,46 @@ +import keras +import pytest + +from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import ( + MobileNetV5Backbone, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import decode_arch_def +from keras_hub.src.tests.test_case import TestCase + + +class MobileNetV5BackboneTest(TestCase): + def setUp(self): + arch_def = [ + ["er_r1_k3_s2_e4_c24"], + ["uir_r2_k5_s2_e6_c48"], + ] + block_args = decode_arch_def(arch_def) + + self.init_kwargs = { + "block_args": block_args, + "image_shape": (32, 32, 3), + "stem_size": 16, + "use_msfa": False, + } + self.input_data = keras.ops.ones((2, 32, 32, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_vision_backbone_test( + cls=MobileNetV5Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=( + 2, + 4, + 4, + 48, + ), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MobileNetV5Backbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py new file mode 100644 index 0000000000..72d76c72d6 --- /dev/null +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py @@ -0,0 +1,716 @@ +import keras + +from keras_hub.src.models.mobilenet.util import adjust_channels +from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import ConvNormAct +from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import DropPath +from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import LayerScale2d +from 
keras_hub.src.models.mobilenetv5.mobilenetv5_layers import RmsNorm2d +from keras_hub.src.models.mobilenetv5.mobilenetv5_utils import num_groups + + +class UniversalInvertedResidual(keras.layers.Layer): + """Universal Inverted Residual block. + + This block is a flexible and universal version of the inverted residual + block, which can be configured to behave like different variants of mobile + convolutional blocks. + + Args: + in_chs: int. The number of input channels. + out_chs: int. The number of output channels. + dw_kernel_size_start: int. The kernel size for the initial depthwise + convolution. If 0, this layer is skipped. + dw_kernel_size_mid: int. The kernel size for the middle depthwise + convolution. If 0, this layer is skipped. + dw_kernel_size_end: int. The kernel size for the final depthwise + convolution. If 0, this layer is skipped. + stride: int. The stride for the block. + dilation: int. The dilation rate for convolutions. + pad_type: str. The padding type for convolutions. + noskip: bool. If `True`, the skip connection is disabled. + exp_ratio: float. The expansion ratio for the middle channels. + act_layer: str. The activation function to use. + norm_layer: str. The normalization layer to use. + se_layer: keras.layers.Layer. The Squeeze-and-Excitation layer to use. + drop_path_rate: float. The stochastic depth rate. + layer_scale_init_value: float. The initial value for layer scale. If + `None`, layer scale is not used. + data_format: str. The format of the input data, either + `"channels_last"` or `"channels_first"`. + channel_axis: int. The axis representing the channels in the input + tensor. + """ + + def __init__( + self, + in_chs, + out_chs, + dw_kernel_size_start=0, + dw_kernel_size_mid=3, + dw_kernel_size_end=0, + stride=1, + dilation=1, + pad_type="", + noskip=False, + exp_ratio=1.0, + act_layer="relu", + norm_layer="batch_norm", + se_layer=None, + drop_path_rate=0.0, + layer_scale_init_value=1e-5, + data_format=None, + channel_axis=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.has_skip = (in_chs == out_chs and stride == 1) and not noskip + self.in_chs = in_chs + self.out_chs = out_chs + self.data_format = data_format + self.channel_axis = channel_axis + keras_pad_type = "same" if pad_type == "" else "valid" + use_bias = norm_layer == "rms_norm" + + if dw_kernel_size_start: + self.dw_start = ConvNormAct( + in_chs, + dw_kernel_size_start, + stride=stride if not dw_kernel_size_mid else 1, + dilation=dilation, + groups=in_chs, + pad_type=keras_pad_type, + apply_act=False, + act_layer=act_layer, + norm_layer=norm_layer, + bias=use_bias, + data_format=self.data_format, + channel_axis=self.channel_axis, + dtype=self.dtype_policy, + ) + else: + self.dw_start = lambda x, training=False: x + + mid_chs = adjust_channels(in_chs * exp_ratio) + self.pw_exp = ConvNormAct( + mid_chs, + 1, + pad_type=keras_pad_type, + act_layer=act_layer, + norm_layer=norm_layer, + bias=use_bias, + data_format=self.data_format, + channel_axis=self.channel_axis, + dtype=self.dtype_policy, + ) + + if dw_kernel_size_mid: + self.dw_mid = ConvNormAct( + mid_chs, + dw_kernel_size_mid, + stride=stride, + dilation=dilation, + groups=mid_chs, + pad_type=keras_pad_type, + act_layer=act_layer, + norm_layer=norm_layer, + bias=use_bias, + data_format=self.data_format, + channel_axis=self.channel_axis, + dtype=self.dtype_policy, + ) + else: + self.dw_mid = lambda x, training=False: x + self.se = ( + se_layer( + filters=mid_chs, + 
bottleneck_filters=adjust_channels(mid_chs * 0.25), + squeeze_activation=act_layer, + excite_activation="sigmoid", + data_format=self.data_format, + channel_axis=self.channel_axis, + dtype=self.dtype_policy, + ) + if se_layer + else (lambda x, training=False: x) + ) + self.pw_proj = ConvNormAct( + out_chs, + 1, + pad_type=keras_pad_type, + apply_act=False, + act_layer=act_layer, + norm_layer=norm_layer, + bias=use_bias, + data_format=self.data_format, + channel_axis=self.channel_axis, + dtype=self.dtype_policy, + ) + + if dw_kernel_size_end: + self.dw_end = ConvNormAct( + out_chs, + dw_kernel_size_end, + stride=stride + if not dw_kernel_size_start and not dw_kernel_size_mid + else 1, + dilation=dilation, + groups=out_chs, + pad_type=keras_pad_type, + apply_act=False, + act_layer=act_layer, + norm_layer=norm_layer, + bias=use_bias, + data_format=self.data_format, + channel_axis=self.channel_axis, + dtype=self.dtype_policy, + ) + else: + self.dw_end = lambda x, training=False: x + + self.layer_scale = ( + LayerScale2d( + out_chs, + layer_scale_init_value, + data_format=self.data_format, + channel_axis=self.channel_axis, + dtype=self.dtype_policy, + ) + if layer_scale_init_value is not None + else lambda x: x + ) + self.drop_path = ( + DropPath(drop_path_rate, dtype=self.dtype_policy) + if drop_path_rate > 0.0 + else (lambda x, training=False: x) + ) + + def build(self, input_shape): + current_shape = input_shape + if hasattr(self.dw_start, "build"): + self.dw_start.build(current_shape) + current_shape = self.dw_start.compute_output_shape(current_shape) + self.pw_exp.build(current_shape) + current_shape = self.pw_exp.compute_output_shape(current_shape) + if hasattr(self.dw_mid, "build"): + self.dw_mid.build(current_shape) + current_shape = self.dw_mid.compute_output_shape(current_shape) + if hasattr(self.se, "build"): + self.se.build(current_shape) + self.pw_proj.build(current_shape) + current_shape = self.pw_proj.compute_output_shape(current_shape) + if hasattr(self.dw_end, "build"): + self.dw_end.build(current_shape) + current_shape = self.dw_end.compute_output_shape(current_shape) + if hasattr(self.layer_scale, "build"): + self.layer_scale.build(current_shape) + self.built = True + + def call(self, x, training=False): + shortcut = x + x = self.dw_start(x, training=training) + x = self.pw_exp(x, training=training) + x = self.dw_mid(x, training=training) + x = self.se(x, training=training) + x = self.pw_proj(x, training=training) + x = self.dw_end(x, training=training) + x = self.layer_scale(x) + if self.has_skip: + x = self.drop_path(x, training=training) + shortcut + return x + + def compute_output_shape(self, input_shape): + current_shape = input_shape + if hasattr(self.dw_start, "compute_output_shape"): + current_shape = self.dw_start.compute_output_shape(current_shape) + current_shape = self.pw_exp.compute_output_shape(current_shape) + if hasattr(self.dw_mid, "compute_output_shape"): + current_shape = self.dw_mid.compute_output_shape(current_shape) + current_shape = self.pw_proj.compute_output_shape(current_shape) + if hasattr(self.dw_end, "compute_output_shape"): + current_shape = self.dw_end.compute_output_shape(current_shape) + return current_shape + + +class EdgeResidual(keras.layers.Layer): + """Edge Residual block. + + This block is designed for efficiency on edge devices. It is a variant of + the inverted residual block that uses a single expansion convolution. + + Args: + in_chs: int. The number of input channels. + out_chs: int. The number of output channels. 
+ exp_kernel_size: int. The kernel size for the expansion convolution. + stride: int. The stride for the block. + dilation: int. The dilation rate for convolutions. + group_size: int. The group size for grouped convolutions. + pad_type: str. The padding type for convolutions. + force_in_chs: int. If greater than 0, forces the number of input + channels for the expansion. + noskip: bool. If `True`, the skip connection is disabled. + exp_ratio: float. The expansion ratio for the middle channels. + pw_kernel_size: int. The kernel size for the pointwise convolution. + act_layer: str. The activation function to use. + norm_layer: str. The normalization layer to use. + se_layer: keras.layers.Layer. The Squeeze-and-Excitation layer to use. + drop_path_rate: float. The stochastic depth rate. + data_format: str. The format of the input data, either + `"channels_last"` or `"channels_first"`. + channel_axis: int. The axis representing the channels in the input + tensor. + """ + + def __init__( + self, + in_chs, + out_chs, + exp_kernel_size=3, + stride=1, + dilation=1, + group_size=0, + pad_type="", + force_in_chs=0, + noskip=False, + exp_ratio=1.0, + pw_kernel_size=1, + act_layer="relu", + norm_layer="batch_norm", + se_layer=None, + drop_path_rate=0.0, + data_format=None, + channel_axis=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.has_skip = (in_chs == out_chs and stride == 1) and not noskip + self.in_chs = in_chs + self.out_chs = out_chs + self.data_format = data_format + self.channel_axis = channel_axis + keras_pad_type = "same" if pad_type == "" else "valid" + if force_in_chs > 0: + mid_chs = adjust_channels(force_in_chs * exp_ratio) + else: + mid_chs = adjust_channels(in_chs * exp_ratio) + groups = num_groups(group_size, mid_chs) + use_bias = norm_layer == "rms_norm" + self.conv_exp = ConvNormAct( + mid_chs, + exp_kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + pad_type=keras_pad_type, + norm_layer=norm_layer, + act_layer=act_layer, + bias=use_bias, + data_format=self.data_format, + channel_axis=self.channel_axis, + dtype=self.dtype_policy, + ) + self.se = ( + se_layer( + filters=mid_chs, + bottleneck_filters=adjust_channels(mid_chs * 0.25), + squeeze_activation=act_layer, + excite_activation="sigmoid", + data_format=self.data_format, + channel_axis=self.channel_axis, + dtype=self.dtype_policy, + ) + if se_layer + else (lambda x, training=False: x) + ) + self.conv_pwl = ConvNormAct( + out_chs, + pw_kernel_size, + pad_type=keras_pad_type, + apply_act=False, + norm_layer=norm_layer, + act_layer=act_layer, + bias=use_bias, + data_format=self.data_format, + channel_axis=self.channel_axis, + dtype=self.dtype_policy, + ) + self.drop_path = ( + DropPath(drop_path_rate, dtype=self.dtype_policy) + if drop_path_rate > 0.0 + else (lambda x, training=False: x) + ) + + def build(self, input_shape): + self.conv_exp.build(input_shape) + conv_exp_output_shape = self.conv_exp.compute_output_shape(input_shape) + if hasattr(self.se, "build"): + self.se.build(conv_exp_output_shape) + self.conv_pwl.build(conv_exp_output_shape) + self.built = True + + def call(self, x, training=False): + shortcut = x + x = self.conv_exp(x, training=training) + x = self.se(x, training=training) + x = self.conv_pwl(x, training=training) + if self.has_skip: + x = self.drop_path(x, training=training) + shortcut + return x + + +class CondConvResidual(keras.layers.Layer): + """Conditionally Parameterized Convolutional Residual block. 
+ + This block uses a routing function to dynamically select and combine + different convolutional experts based on the input. + + Args: + in_chs: int. The number of input channels. + out_chs: int. The number of output channels. + dw_kernel_size: int. The kernel size for the depthwise convolution. + stride: int. The stride for the block. + dilation: int. The dilation rate for convolutions. + pad_type: str. The padding type for convolutions. + noskip: bool. If `True`, the skip connection is disabled. + exp_ratio: float. The expansion ratio for the middle channels. + exp_kernel_size: int. The kernel size for the expansion convolution. + pw_kernel_size: int. The kernel size for the pointwise convolution. + act_layer: str. The activation function to use. + se_layer: keras.layers.Layer. The Squeeze-and-Excitation layer to use. + num_experts: int. The number of experts to use. + drop_path_rate: float. The stochastic depth rate. + data_format: str. The format of the input data, either + `"channels_last"` or `"channels_first"`. + channel_axis: int. The axis representing the channels in the input + tensor. + """ + + def __init__( + self, + in_chs, + out_chs, + dw_kernel_size=3, + stride=1, + dilation=1, + pad_type="", + noskip=False, + exp_ratio=1.0, + exp_kernel_size=1, + pw_kernel_size=1, + act_layer="relu", + se_layer=None, + num_experts=0, + drop_path_rate=0.0, + data_format=None, + channel_axis=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.has_skip = (in_chs == out_chs and stride == 1) and not noskip + self.num_experts = num_experts + self.data_format = data_format + self.channel_axis = channel_axis + self.conv_kernel_initializer = keras.initializers.VarianceScaling( + scale=2.0, mode="fan_out", distribution="untruncated_normal" + ) + self.dense_kernel_initializer = keras.initializers.VarianceScaling( + scale=1.0, mode="fan_in", distribution="uniform" + ) + self.bias_initializer = "zeros" + mid_chs = adjust_channels(in_chs * exp_ratio) + keras_pad_type = "same" if pad_type == "" else "valid" + self.routing_fn = keras.layers.Dense( + self.num_experts, + dtype=self.dtype_policy, + kernel_initializer=self.dense_kernel_initializer, + bias_initializer=self.bias_initializer, + ) + self.pool = keras.layers.GlobalAveragePooling2D( + data_format=self.data_format, dtype=self.dtype_policy + ) + self.conv_pw_experts = [ + keras.layers.Conv2D( + filters=mid_chs, + kernel_size=exp_kernel_size, + padding=keras_pad_type, + use_bias=True, + data_format=self.data_format, + name=f"conv_pw_expert_{i}", + kernel_initializer=self.conv_kernel_initializer, + bias_initializer=self.bias_initializer, + dtype=self.dtype_policy, + ) + for i in range(self.num_experts) + ] + self.conv_dw_experts = [ + keras.layers.DepthwiseConv2D( + kernel_size=dw_kernel_size, + strides=stride, + padding=keras_pad_type, + dilation_rate=dilation, + use_bias=True, + data_format=self.data_format, + name=f"conv_dw_expert_{i}", + depthwise_initializer=self.conv_kernel_initializer, + bias_initializer=self.bias_initializer, + dtype=self.dtype_policy, + ) + for i in range(self.num_experts) + ] + self.conv_pwl_experts = [ + keras.layers.Conv2D( + filters=out_chs, + kernel_size=pw_kernel_size, + padding=keras_pad_type, + use_bias=True, + data_format=self.data_format, + name=f"conv_pwl_expert_{i}", + kernel_initializer=self.conv_kernel_initializer, + bias_initializer=self.bias_initializer, + dtype=self.dtype_policy, + ) + for i in range(self.num_experts) + ] + self.bn1 = keras.layers.BatchNormalization( + 
axis=self.channel_axis, + dtype=self.dtype_policy, + gamma_initializer="ones", + beta_initializer="zeros", + ) + self.act1 = keras.layers.Activation(act_layer, dtype=self.dtype_policy) + self.bn2 = keras.layers.BatchNormalization( + axis=self.channel_axis, + dtype=self.dtype_policy, + gamma_initializer="ones", + beta_initializer="zeros", + ) + self.act2 = keras.layers.Activation(act_layer, dtype=self.dtype_policy) + self.bn3 = keras.layers.BatchNormalization( + axis=self.channel_axis, + dtype=self.dtype_policy, + gamma_initializer="ones", + beta_initializer="zeros", + ) + self.se = ( + se_layer( + filters=mid_chs, + bottleneck_filters=adjust_channels(mid_chs * 0.25), + squeeze_activation=act_layer, + excite_activation="sigmoid", + data_format=self.data_format, + channel_axis=self.channel_axis, + dtype=self.dtype_policy, + ) + if se_layer + else (lambda x, training=False: x) + ) + self.drop_path = ( + DropPath(drop_path_rate, dtype=self.dtype_policy) + if drop_path_rate > 0.0 + else (lambda x, training=False: x) + ) + + def build(self, input_shape): + pooled_shape = self.pool.compute_output_shape(input_shape) + self.routing_fn.build(pooled_shape) + for expert in self.conv_pw_experts: + expert.build(input_shape) + pw_out_shape = self.conv_pw_experts[0].compute_output_shape(input_shape) + self.bn1.build(pw_out_shape) + for expert in self.conv_dw_experts: + expert.build(pw_out_shape) + dw_out_shape = self.conv_dw_experts[0].compute_output_shape( + pw_out_shape + ) + self.bn2.build(dw_out_shape) + if hasattr(self.se, "build"): + self.se.build(dw_out_shape) + for expert in self.conv_pwl_experts: + expert.build(dw_out_shape) + pwl_out_shape = self.conv_pwl_experts[0].compute_output_shape( + dw_out_shape + ) + self.bn3.build(pwl_out_shape) + self.built = True + + def _apply_cond_conv(self, x, experts, routing_weights): + outputs = [] + for i, expert in enumerate(experts): + expert_out = expert(x) + weight = keras.ops.reshape(routing_weights[:, i], (-1, 1, 1, 1)) + outputs.append(expert_out * weight) + return keras.ops.sum(outputs, axis=0) + + def call(self, x, training=False): + shortcut = x + pooled_inputs = self.pool(x) + routing_weights = keras.activations.sigmoid( + self.routing_fn(pooled_inputs) + ) + x = self._apply_cond_conv(x, self.conv_pw_experts, routing_weights) + x = self.bn1(x, training=training) + x = self.act1(x) + x = self._apply_cond_conv(x, self.conv_dw_experts, routing_weights) + x = self.bn2(x, training=training) + x = self.act2(x) + x = self.se(x, training=training) + x = self._apply_cond_conv(x, self.conv_pwl_experts, routing_weights) + x = self.bn3(x, training=training) + if self.has_skip: + x = self.drop_path(x, training=training) + shortcut + return x + + +class MobileNetV5MultiScaleFusionAdapter(keras.layers.Layer): + """Multi-Scale Fusion Adapter for MobileNetV5. + + This layer fuses feature maps from different scales of the backbone, + concatenates them, processes them through a FFN (Feed-Forward Network), + and then resizes the output to a target resolution. + + Args: + in_chs: list of int. A list of channel counts for each input feature + map. + out_chs: int. The number of output channels. + output_resolution: int or tuple. The target output resolution. + expansion_ratio: float. The expansion ratio for the FFN. + interpolation_mode: str. The interpolation mode for upsampling feature + maps. + layer_scale_init_value: float. The initial value for layer scale. If + `None`, layer scale is not used. + noskip: bool. If `True`, the skip connection in the FFN is disabled. 
+ act_layer: str. The activation function to use. + norm_layer: str. The normalization layer to use. + data_format: str. The format of the input data, either + `"channels_last"` or `"channels_first"`. + channel_axis: int. The axis representing the channels in the input + tensor. + """ + + def __init__( + self, + in_chs, + out_chs, + output_resolution, + expansion_ratio=2.0, + interpolation_mode="nearest", + layer_scale_init_value=None, + noskip=True, + act_layer="gelu", + norm_layer="rms_norm", + data_format=None, + channel_axis=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.in_channels = sum(in_chs) + self.out_channels = out_chs + self.data_format = data_format + self.channel_axis = channel_axis + if isinstance(output_resolution, int): + self.output_resolution = (output_resolution, output_resolution) + else: + self.output_resolution = output_resolution + self.interpolation_mode = interpolation_mode + self.ffn = UniversalInvertedResidual( + in_chs=self.in_channels, + out_chs=self.out_channels, + dw_kernel_size_mid=0, + exp_ratio=expansion_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + noskip=noskip, + layer_scale_init_value=layer_scale_init_value, + data_format=self.data_format, + channel_axis=self.channel_axis, + dtype=self.dtype_policy, + ) + if norm_layer == "rms_norm": + self.norm = RmsNorm2d( + self.out_channels, + data_format=self.data_format, + gamma_initializer="ones", + channel_axis=self.channel_axis, + dtype=self.dtype_policy, + ) + else: + self.norm = keras.layers.BatchNormalization( + axis=self.channel_axis, + gamma_initializer="ones", + beta_initializer="zeros", + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + ffn_input_shape = list(input_shape[0]) + if self.data_format == "channels_first": + ffn_input_shape[1] = self.in_channels + else: + ffn_input_shape[-1] = self.in_channels + self.ffn.build(tuple(ffn_input_shape)) + norm_input_shape = self.ffn.compute_output_shape(tuple(ffn_input_shape)) + self.norm.build(norm_input_shape) + self.built = True + + def call(self, inputs, training=False): + shape_hr = keras.ops.shape(inputs[0]) + if self.data_format == "channels_first": + high_resolution = (shape_hr[2], shape_hr[3]) + else: + high_resolution = (shape_hr[1], shape_hr[2]) + resized_inputs = [] + for img in inputs: + if self.data_format == "channels_first": + img_transposed = keras.ops.transpose(img, (0, 2, 3, 1)) + else: + img_transposed = img + img_resized = keras.ops.image.resize( + img_transposed, + size=high_resolution, + interpolation=self.interpolation_mode, + ) + if self.data_format == "channels_first": + resized_inputs.append( + keras.ops.transpose(img_resized, (0, 3, 1, 2)) + ) + else: + resized_inputs.append(img_resized) + channel_cat_imgs = keras.ops.concatenate( + resized_inputs, axis=self.channel_axis + ) + img = self.ffn(channel_cat_imgs, training=training) + if self.data_format == "channels_first": + img_transposed = keras.ops.transpose(img, (0, 2, 3, 1)) + else: + img_transposed = img + img_resized = keras.ops.image.resize( + img_transposed, + size=self.output_resolution, + interpolation="bilinear", + ) + if self.data_format == "channels_first": + img = keras.ops.transpose(img_resized, (0, 3, 1, 2)) + else: + img = img_resized + img = self.norm(img, training=training) + return img + + def compute_output_shape(self, input_shape): + batch_size = input_shape[0][0] + if self.data_format == "channels_first": + return ( + batch_size, + self.out_channels, + self.output_resolution[0], + 
self.output_resolution[1], + ) + else: + return ( + batch_size, + self.output_resolution[0], + self.output_resolution[1], + self.out_channels, + ) diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py new file mode 100644 index 0000000000..eb4d7034e0 --- /dev/null +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py @@ -0,0 +1,335 @@ +import re +from copy import deepcopy + +from keras_hub.src.models.mobilenet.mobilenet_backbone import ConvBnActBlock +from keras_hub.src.models.mobilenet.mobilenet_backbone import DepthwiseConvBlock +from keras_hub.src.models.mobilenet.mobilenet_backbone import ( + InvertedResidualBlock, +) +from keras_hub.src.models.mobilenet.util import adjust_channels +from keras_hub.src.models.mobilenetv5.mobilenetv5_attention import ( + MobileAttention, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import CondConvResidual +from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import EdgeResidual +from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import ( + UniversalInvertedResidual, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_utils import parse_ksize +from keras_hub.src.models.mobilenetv5.mobilenetv5_utils import round_channels + + +def decode_block_str(block_str): + assert isinstance(block_str, str) + ops = block_str.split("_") + block_type = ops[0] + ops = ops[1:] + options = {} + skip = None + for op in ops: + if op == "noskip": + skip = False + elif op == "skip": + skip = True + elif op.startswith("n"): + key = op[0] + v = op[1:] + options[key] = v if v else "relu" + else: + splits = re.split(r"(\d.*)", op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + act_layer = options.get("n") + num_repeat = int(options["r"]) + + block_args = dict( + block_type=block_type, + out_chs=int(options["c"]), + stride=int(options["s"]), + act_layer=act_layer, + ) + + if block_type == "uir": + start_kernel_size = parse_ksize(options.get("a", "0")) + end_kernel_size = parse_ksize(options.get("p", "0")) + block_args.update( + dict( + dw_kernel_size_start=start_kernel_size, + dw_kernel_size_mid=parse_ksize(options["k"]), + dw_kernel_size_end=end_kernel_size, + exp_ratio=float(options["e"]), + se_ratio=float(options.get("se", 0.0)), + noskip=skip is False, + ) + ) + elif block_type == "er": + block_args.update( + dict( + exp_kernel_size=parse_ksize(options["k"]), + pw_kernel_size=1, + exp_ratio=float(options["e"]), + se_ratio=float(options.get("se", 0.0)), + noskip=skip is False, + ) + ) + elif block_type in ("mqa", "mha"): + block_args.update( + dict( + num_heads=int(options.get("h", "12")), + key_dim=int(options.get("d", "64")), + use_cpe=bool(int(options.get("cpe", "0"))), + ) + ) + return block_args, num_repeat + + +def decode_arch_def(arch_def): + arch_args = [] + for _, block_strings in enumerate(arch_def): + stack_args = [] + for block_str in block_strings: + ba, rep = decode_block_str(block_str) + stack_args.extend([deepcopy(ba) for _ in range(rep)]) + arch_args.append(stack_args) + return arch_args + + +class MobileNetV5Builder: + """Builds a MobileNetV5 model from a decoded architecture definition. + + This class takes a decoded architecture definition and constructs a list of + network stages, where each stage is a list of blocks. It handles channel + rounding, stride management, and feature extraction points. + + Args: + output_stride: int. The desired output stride of the network. + pad_type: str. The padding type for convolutions. 
+ round_chs_fn: callable. A function to round the number of channels. + se_from_exp: bool. If `True`, SE channel reduction is based on the + expanded channels. + act_layer: str. The default activation function for blocks. + norm_layer: str. The default normalization layer for blocks. + aa_layer: keras.layers.Layer. An optional anti-aliasing layer. + se_layer: keras.layers.Layer. The Squeeze-and-Excitation layer to use. + drop_path_rate: float. The stochastic depth rate for the network. + layer_scale_init_value: float. The initial value for layer scale. + feature_location: str. Where to extract features from, either + `"bottleneck"`, `"expansion"`, or `""`. + data_format: str. The format of the input data, either + `"channels_last"` or `"channels_first"`. + channel_axis: int. The axis representing the channels in the input + tensor. + """ + + def __init__( + self, + output_stride=32, + pad_type="", + round_chs_fn=round_channels, + se_from_exp=False, + act_layer="relu", + norm_layer="batch_norm", + aa_layer=None, + se_layer=None, + drop_path_rate=0.0, + layer_scale_init_value=None, + feature_location="", + data_format=None, + channel_axis=None, + dtype=None, + ): + self.output_stride = output_stride + self.pad_type = pad_type + self.data_format = data_format + self.channel_axis = channel_axis + self.round_chs_fn = round_chs_fn + self.se_from_exp = se_from_exp + self.act_layer = act_layer + self.norm_layer = norm_layer + self.aa_layer = aa_layer + self.se_layer = se_layer + self.drop_path_rate = drop_path_rate + self.layer_scale_init_value = layer_scale_init_value + self.dtype = dtype + if feature_location == "depthwise": + feature_location = "expansion" + self.feature_location = feature_location + assert feature_location in ("bottleneck", "expansion", "") + self.in_chs = None + self.features = [] + + def _make_block(self, ba, block_idx, block_count): + drop_path_rate = self.drop_path_rate * block_idx / block_count + bt = ba.pop("block_type") + ba["in_chs"] = self.in_chs + ba["out_chs"] = self.round_chs_fn(ba["out_chs"]) + s2d = ba.get("s2d", 0) + if s2d > 0: + ba["out_chs"] *= 4 + if "force_in_chs" in ba and ba["force_in_chs"]: + ba["force_in_chs"] = self.round_chs_fn(ba["force_in_chs"]) + ba["pad_type"] = self.pad_type + ba["act_layer"] = ( + ba.get("act_layer") + if ba.get("act_layer") is not None + else self.act_layer + ) + assert ba["act_layer"] is not None + ba["norm_layer"] = self.norm_layer + ba["drop_path_rate"] = drop_path_rate + ba["data_format"] = self.data_format + ba["channel_axis"] = self.channel_axis + ba["dtype"] = self.dtype + if bt in ("ir", "er", "uir", "ds", "dsa"): + se_ratio = ba.pop("se_ratio", None) + if se_ratio and self.se_layer is not None: + if not self.se_from_exp: + se_ratio /= ba.get("exp_ratio", 1.0) + if s2d == 1: + se_ratio /= 4 + ba["se_layer"] = lambda channels: self.se_layer( + filters=channels, + bottleneck_filters=adjust_channels(channels * se_ratio), + squeeze_activation=ba["act_layer"], + excite_activation="sigmoid", + data_format=self.data_format, + dtype=self.dtype, + ) + else: + ba["se_layer"] = None + ba.pop("aa_layer", None) + if bt == "ir": + block = ( + CondConvResidual(**ba) + if ba.get("num_experts", 0) > 0 + else InvertedResidualBlock( + expansion=ba["exp_ratio"], + infilters=ba["in_chs"], + filters=ba["out_chs"], + kernel_size=ba["dw_kernel_size"], + stride=ba["stride"], + squeeze_excite_ratio=ba.pop("se_ratio", None), + activation=ba["act_layer"], + ) + ) + elif bt == "ds" or bt == "dsa": + block = DepthwiseConvBlock( + 
infilters=ba["in_chs"], + filters=ba["out_chs"], + kernel_size=ba["dw_kernel_size"], + stride=ba["stride"], + squeeze_excite_ratio=ba.pop("se_ratio", None), + residual=not ba["noskip"], + dtype=self.dtype, + ) + elif bt == "er": + block = EdgeResidual(**ba) + elif bt == "cn": + block = ConvBnActBlock(**ba) + elif bt == "uir": + block = UniversalInvertedResidual( + **ba, layer_scale_init_value=self.layer_scale_init_value + ) + elif bt == "mqa": + ba.pop("act_layer", None) + block = MobileAttention( + **ba, + use_multi_query=True, + layer_scale_init_value=self.layer_scale_init_value, + ) + elif bt == "mha": + ba.pop("act_layer", None) + block = MobileAttention( + **ba, layer_scale_init_value=self.layer_scale_init_value + ) + else: + raise ValueError(f"Unknown block type ({bt}) while building model.") + self.in_chs = ba["out_chs"] + return block + + def __call__(self, in_chs, model_block_args): + self.in_chs = in_chs + total_block_count = sum([len(x) for x in model_block_args]) + total_block_idx = 0 + current_stride = 2 + current_dilation = 1 + stages = [] + if model_block_args[0][0]["stride"] > 1: + feature_info = dict( + module="conv_stem", + num_chs=in_chs, + stage=0, + reduction=current_stride, + ) + self.features.append(feature_info) + space2depth = 0 + for stack_idx, stack_args in enumerate(model_block_args): + blocks = [] + for block_idx, block_args in enumerate(stack_args): + last_block = block_idx + 1 == len(stack_args) + in_chs_for_current_block = self.in_chs + assert block_args["stride"] in (1, 2) + if block_idx >= 1: + block_args["stride"] = 1 + if not space2depth and block_args.pop("s2d", False): + assert block_args["stride"] == 1 + space2depth = 1 + if space2depth > 0: + if space2depth == 2 and block_args["stride"] == 2: + block_args["stride"] = 1 + block_args["exp_ratio"] /= 4 + space2depth = 0 + else: + block_args["s2d"] = space2depth + next_dilation = current_dilation + if block_args["stride"] > 1: + next_output_stride = current_stride * block_args["stride"] + if next_output_stride > self.output_stride: + next_dilation = current_dilation * block_args["stride"] + block_args["stride"] = 1 + else: + current_stride = next_output_stride + block_args["dilation"] = current_dilation + if next_dilation != current_dilation: + current_dilation = next_dilation + block = self._make_block( + block_args.copy(), total_block_idx, total_block_count + ) + blocks.append(block) + if space2depth == 1: + space2depth = 2 + extract_features = False + if last_block: + next_stack_idx = stack_idx + 1 + extract_features = ( + next_stack_idx >= len(model_block_args) + or model_block_args[next_stack_idx][0]["stride"] > 1 + ) + if extract_features: + num_chs = 0 + module_name = f"blocks.{stack_idx}.{block_idx}" + if self.feature_location == "expansion": + bt = block_args.get("block_type") + if bt in ["ir", "er", "uir"]: + exp_ratio = block_args.get("exp_ratio", 1.0) + num_chs = self.round_chs_fn( + in_chs_for_current_block * exp_ratio + ) + else: + num_chs = in_chs_for_current_block + else: + num_chs = self.in_chs + module_name = f"blocks.{stack_idx}" + + feature_info = dict( + stage=stack_idx + 1, + reduction=current_stride, + num_chs=num_chs, + module=module_name, + ) + self.features.append(feature_info) + total_block_idx += 1 + stages.append(blocks) + return stages diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py new file mode 100644 index 0000000000..5bd9aeb943 --- /dev/null +++ 
b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier.py @@ -0,0 +1,139 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_classifier import ImageClassifier +from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import ( + MobileNetV5Backbone, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_image_classifier_preprocessor import ( # noqa: E501 + MobileNetV5ImageClassifierPreprocessor, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import ConvNormAct +from keras_hub.src.models.mobilenetv5.mobilenetv5_utils import ( + SelectAdaptivePool2d, +) +from keras_hub.src.models.task import Task + + +@keras_hub_export("keras_hub.models.MobileNetV5ImageClassifier") +class MobileNetV5ImageClassifier(ImageClassifier): + """An end-to-end MobileNetV5 model for image classification. + + This model attaches a classification head to a `MobileNetV5Backbone`. + The head consists of a global pooling layer, an optional convolutional + head, a dropout layer, and a final dense classifier layer. + + This model can optionally be configured with a `preprocessor` layer, in + which case it will automatically apply preprocessing to image inputs during + `fit()`, `predict()`, and `evaluate()`. + + Args: + backbone: A `keras_hub.models.MobileNetV5Backbone` instance. + num_classes: int. The number of classes for the classification head. + preprocessor: A `keras_hub.models.ImageClassifierPreprocessor` or + `None`. If `None`, this model will not apply preprocessing. + head_hidden_size: int. The number of channels in the convolutional + head. + global_pool: str. The type of global pooling to use. + drop_rate: float. The dropout rate for the head. + head_dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to + use for the head computations and weights. 
+
+    Example:
+    ```python
+    import keras
+    import keras_hub
+    from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import (
+        decode_arch_def
+    )
+
+    arch_def = [["er_r1_k3_s2_e4_c24"], ["uir_r2_k5_s2_e6_c48"]]
+    block_args = decode_arch_def(arch_def)
+    backbone = keras_hub.models.MobileNetV5Backbone(block_args=block_args)
+    model = keras_hub.models.MobileNetV5ImageClassifier(backbone, 1000)
+    images = keras.ops.ones((1, 224, 224, 3))
+    output = model.predict(images)
+    ```
+    """
+
+    backbone_cls = MobileNetV5Backbone
+    preprocessor_cls = MobileNetV5ImageClassifierPreprocessor
+
+    def __init__(
+        self,
+        backbone,
+        num_classes,
+        preprocessor=None,
+        head_hidden_size=2048,
+        global_pool="avg",
+        drop_rate=0.0,
+        head_dtype=None,
+        **kwargs,
+    ):
+        head_dtype = head_dtype or backbone.dtype_policy
+        data_format = getattr(backbone, "data_format", "channels_last")
+
+        # === Layers ===
+        self.backbone = backbone
+        self.preprocessor = preprocessor
+        if backbone.use_msfa:
+            self.global_pool = SelectAdaptivePool2d(
+                pool_type=global_pool, data_format=data_format, flatten=True
+            )
+            self.conv_head = None
+            self.flatten = None
+        else:
+            self.global_pool = SelectAdaptivePool2d(
+                pool_type=global_pool, data_format=data_format, flatten=False
+            )
+            self.conv_head = ConvNormAct(
+                out_chs=head_hidden_size,
+                kernel_size=1,
+                pad_type="",
+                norm_layer=backbone.norm_layer,
+                act_layer=backbone.act_layer,
+                bias=False,
+                name="conv_head",
+                dtype=head_dtype,
+            )
+            self.flatten = keras.layers.Flatten(dtype=head_dtype)
+        self.dropout = (
+            keras.layers.Dropout(drop_rate, dtype=head_dtype)
+            if drop_rate > 0.0
+            else None
+        )
+        self.classifier = (
+            keras.layers.Dense(num_classes, dtype=head_dtype, name="classifier")
+            if num_classes > 0
+            else keras.layers.Activation("linear", name="identity_classifier")
+        )
+
+        # === Functional Model ===
+        inputs = self.backbone.input
+        x = self.backbone(inputs)
+        x = self.global_pool(x)
+        if self.conv_head is not None:
+            x = self.conv_head(x)
+        if self.flatten is not None:
+            x = self.flatten(x)
+        if self.dropout is not None:
+            x = self.dropout(x)
+        outputs = self.classifier(x)
+        Task.__init__(self, inputs=inputs, outputs=outputs, **kwargs)
+
+        # === Config ===
+        self.num_classes = num_classes
+        self.head_hidden_size = head_hidden_size
+        self.global_pool_type = global_pool
+        self.drop_rate = drop_rate
+
+    def get_config(self):
+        config = Task.get_config(self)
+        config.update(
+            {
+                "num_classes": self.num_classes,
+                "head_hidden_size": self.head_hidden_size,
+                "global_pool": self.global_pool_type,
+                "drop_rate": self.drop_rate,
+            }
+        )
+        return config
diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py
new file mode 100644
index 0000000000..cac991c9a8
--- /dev/null
+++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_preprocessor.py
@@ -0,0 +1,16 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.image_classifier_preprocessor import (
+    ImageClassifierPreprocessor,
+)
+from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import (
+    MobileNetV5Backbone,
+)
+from keras_hub.src.models.mobilenetv5.mobilenetv5_image_converter import (
+    MobileNetV5ImageConverter,
+)
+
+
+@keras_hub_export("keras_hub.models.MobileNetV5ImageClassifierPreprocessor")
+class MobileNetV5ImageClassifierPreprocessor(ImageClassifierPreprocessor):
+    backbone_cls = MobileNetV5Backbone
+    image_converter_cls = MobileNetV5ImageConverter
diff --git
a/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_test.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_test.py new file mode 100644 index 0000000000..26502c29da --- /dev/null +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_classifier_test.py @@ -0,0 +1,66 @@ +import numpy as np +import pytest + +from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import ( + MobileNetV5Backbone, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import decode_arch_def +from keras_hub.src.models.mobilenetv5.mobilenetv5_image_classifier import ( + MobileNetV5ImageClassifier, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_image_classifier_preprocessor import ( # noqa: E501 + MobileNetV5ImageClassifierPreprocessor, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_image_converter import ( + MobileNetV5ImageConverter, +) +from keras_hub.src.tests.test_case import TestCase + + +class MobileNetV5ImageClassifierTest(TestCase): + def setUp(self): + # Setup model. + self.images = np.ones((2, 32, 32, 3), dtype="float32") + self.labels = [0, 9] # num_classes = 10 + arch_def = [ + ["er_r1_k3_s2_e4_c24"], + ["uir_r2_k5_s2_e6_c48"], + ] + block_args = decode_arch_def(arch_def) + self.backbone = MobileNetV5Backbone( + block_args=block_args, + input_shape=(32, 32, 3), + stem_size=16, + use_msfa=False, + ) + self.image_converter = MobileNetV5ImageConverter( + height=32, width=32, scale=1 / 255.0 + ) + self.preprocessor = MobileNetV5ImageClassifierPreprocessor( + self.image_converter + ) + self.init_kwargs = { + "backbone": self.backbone, + "preprocessor": self.preprocessor, + "num_classes": 10, + } + self.train_data = ( + self.images, + self.labels, + ) + + def test_classifier_basics(self): + self.run_task_test( + cls=MobileNetV5ImageClassifier, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape=(2, 10), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MobileNetV5ImageClassifier, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py new file mode 100644 index 0000000000..486e4c1da7 --- /dev/null +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_image_converter.py @@ -0,0 +1,10 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import ( + MobileNetV5Backbone, +) + + +@keras_hub_export("keras_hub.layers.MobileNetV5ImageConverter") +class MobileNetV5ImageConverter(ImageConverter): + backbone_cls = MobileNetV5Backbone diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py new file mode 100644 index 0000000000..b21f08e432 --- /dev/null +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py @@ -0,0 +1,453 @@ +import keras + +from keras_hub.src.models.mobilenet.util import adjust_channels + + +class DropPath(keras.layers.Layer): + """Implements the DropPath layer. + + DropPath is a form of stochastic depth, where connections are randomly + dropped during training. + + Args: + drop_prob: float. The probability of dropping a path. + scale_by_keep: bool. If `True`, scale the output by `1/keep_prob`. 
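+
+    Example:
+
+    A minimal usage sketch (shapes are illustrative); the layer only
+    drops paths when called with `training=True`:
+    ```python
+    import keras
+
+    drop_path = DropPath(drop_prob=0.1)
+    x = keras.ops.ones((2, 8, 8, 32))
+    # Each sample is either zeroed or rescaled by `1 / (1 - 0.1)`.
+    y = drop_path(x, training=True)
+    ```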
+ """ + + def __init__(self, drop_prob=0.0, scale_by_keep=True, dtype=None, **kwargs): + super().__init__(dtype=dtype, **kwargs) + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def call(self, x, training=False): + if self.drop_prob == 0.0 or not training: + return x + keep_prob = 1.0 - self.drop_prob + shape = (keras.ops.shape(x)[0],) + (1,) * (len(x.shape) - 1) + random_tensor = keep_prob + keras.random.uniform( + shape, 0, 1, dtype=x.dtype + ) + random_tensor = keras.ops.floor(random_tensor) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor = random_tensor / keep_prob + return x * random_tensor + + def get_config(self): + config = super().get_config() + config.update( + {"drop_prob": self.drop_prob, "scale_by_keep": self.scale_by_keep} + ) + return config + + +class LayerScale2d(keras.layers.Layer): + """A layer that applies a learnable scaling factor to the input tensor. + + This layer scales the input tensor by a learnable `gamma` parameter. The + scaling is applied channel-wise. + + Args: + dim: int. The number of channels in the input tensor. + init_values: float. The initial value for the `gamma` parameter. + data_format: str. The format of the input data, either + `"channels_last"` or `"channels_first"`. + channel_axis: int. The axis representing the channels in the input + tensor. + """ + + def __init__( + self, + dim, + init_values=1e-5, + data_format=None, + channel_axis=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.dim = dim + self.init_values = init_values + self.data_format = data_format + self.channel_axis = channel_axis + + def build(self, input_shape): + self.gamma = self.add_weight( + shape=(self.dim,), + initializer=keras.initializers.Constant(self.init_values), + trainable=True, + name="gamma", + ) + super().build(input_shape) + + def call(self, x): + if self.data_format == "channels_first": + gamma = keras.ops.reshape(self.gamma, (1, self.dim, 1, 1)) + else: + gamma = keras.ops.reshape(self.gamma, (1, 1, 1, self.dim)) + return x * gamma + + def get_config(self): + config = super().get_config() + config.update( + { + "dim": self.dim, + "init_values": self.init_values, + "data_format": self.data_format, + "channel_axis": self.channel_axis, + } + ) + return config + + +class RmsNorm2d(keras.layers.Layer): + """A layer that applies Root Mean Square Normalization to a 2D input. + + This layer normalizes the input tensor along the channel dimension using + the root mean square of the values, and then scales it by a learnable + `gamma` parameter. + + Args: + dim: int. The number of channels in the input tensor. + eps: float. A small epsilon value to avoid division by zero. + data_format: str. The format of the input data, either + `"channels_last"` or `"channels_first"`. + channel_axis: int. The axis representing the channels in the input + tensor. 
+ """ + + def __init__( + self, + dim, + eps=1e-6, + data_format=None, + channel_axis=None, + gamma_initializer="ones", + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.dim = dim + self.eps = eps + self.data_format = data_format + self.channel_axis = channel_axis + self.gamma_initializer = gamma_initializer + + def build(self, input_shape): + self.gamma = self.add_weight( + shape=(self.dim,), + initializer=self.gamma_initializer, + trainable=True, + name="gamma", + ) + super().build(input_shape) + + def call(self, x): + input_dtype = x.dtype + if self.data_format == "channels_first": + x_permuted = keras.ops.transpose(x, (0, 2, 3, 1)) + else: + x_permuted = x + x_float = keras.ops.cast(x_permuted, "float32") + norm_factor = keras.ops.rsqrt( + keras.ops.mean(keras.ops.square(x_float), axis=-1, keepdims=True) + + self.eps + ) + norm_x_float = x_float * norm_factor + norm_x = keras.ops.cast(norm_x_float, input_dtype) + scaled_x = norm_x * self.gamma + if self.data_format == "channels_first": + output = keras.ops.transpose(scaled_x, (0, 3, 1, 2)) + else: + output = scaled_x + return output + + def get_config(self): + config = super().get_config() + config.update( + { + "dim": self.dim, + "eps": self.eps, + "data_format": self.data_format, + "channel_axis": self.channel_axis, + "gamma_initializer": self.gamma_initializer, + } + ) + return config + + +class ConvNormAct(keras.layers.Layer): + """A layer that combines convolution, normalization, and activation. + + This layer provides a convenient way to create a sequence of a 2D + convolution, a normalization layer, and an activation function. + + Args: + out_chs: int. The number of output channels. + kernel_size: int or tuple. The size of the convolution kernel. + stride: int or tuple. The stride of the convolution. + dilation: int or tuple. The dilation rate of the convolution. + groups: int. The number of groups for a grouped convolution. + bias: bool. If `True`, a bias term is used in the convolution. + pad_type: str. The type of padding to use. `"same"` or `""` for same + padding, otherwise valid padding. + apply_act: bool. If `True`, an activation function is applied. + act_layer: str. The name of the activation function to use. + norm_layer: str. The name of the normalization layer to use. + Supported values are `"batch_norm"` and `"rms_norm"`. + data_format: str. The format of the input data, either + `"channels_last"` or `"channels_first"`. + channel_axis: int. The axis representing the channels in the input + tensor. 
+ """ + + def __init__( + self, + out_chs, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=False, + pad_type="same", + apply_act=True, + act_layer="relu", + norm_layer="batch_norm", + data_format=None, + channel_axis=None, + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.channel_axis = channel_axis + self.data_format = data_format + self.pad = None + self.kernel_initializer = keras.initializers.VarianceScaling( + scale=2.0, mode="fan_out", distribution="untruncated_normal" + ) + self.bias_initializer = "zeros" + padding_mode = "valid" + if pad_type.lower() == "" or pad_type.lower() == "same": + if stride > 1: + self.pad = keras.layers.ZeroPadding2D( + padding=kernel_size // 2, + data_format=self.data_format, + dtype=self.dtype_policy, + ) + else: + padding_mode = "same" + + self.conv = keras.layers.Conv2D( + out_chs, + kernel_size, + strides=stride, + padding=padding_mode, + dilation_rate=dilation, + groups=groups, + use_bias=bias, + data_format=self.data_format, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + dtype=self.dtype_policy, + ) + + if norm_layer == "batch_norm": + self.norm = keras.layers.BatchNormalization( + axis=self.channel_axis, + epsilon=1e-5, + gamma_initializer="ones", + beta_initializer="zeros", + dtype=self.dtype_policy, + ) + elif norm_layer == "rms_norm": + self.norm = RmsNorm2d( + out_chs, + data_format=self.data_format, + channel_axis=self.channel_axis, + gamma_initializer="ones", + dtype=self.dtype_policy, + ) + else: + ln_axis = [1, 2, 3] + if self.data_format == "channels_first": + ln_axis = [2, 3, 1] + self.norm = keras.layers.LayerNormalization( + axis=ln_axis, + dtype=self.dtype_policy, + ) + + self.apply_act = apply_act + if self.apply_act: + if act_layer == "gelu": + self.act = keras.layers.Activation( + lambda x: keras.activations.gelu(x, approximate=True), + dtype=self.dtype_policy, + ) + else: + self.act = keras.layers.Activation( + act_layer, + dtype=self.dtype_policy, + ) + + def build(self, input_shape): + if self.pad: + self.pad.build(input_shape) + conv_input_shape = self.pad.compute_output_shape(input_shape) + else: + conv_input_shape = input_shape + + self.conv.build(conv_input_shape) + conv_output_shape = self.conv.compute_output_shape(conv_input_shape) + self.norm.build(conv_output_shape) + if self.apply_act: + self.act.build(conv_output_shape) + self.built = True + + def call(self, x, training=False): + if self.pad: + x = self.pad(x) + x = self.conv(x) + x = self.norm(x, training=training) + if self.apply_act: + x = self.act(x) + return x + + def compute_output_shape(self, input_shape): + if self.pad: + padded_shape = self.pad.compute_output_shape(input_shape) + return self.conv.compute_output_shape(padded_shape) + else: + return self.conv.compute_output_shape(input_shape) + + +class SEModule(keras.layers.Layer): + """Implements the Squeeze-and-Excitation (SE) module. + + The SE module adaptively recalibrates channel-wise feature responses by + explicitly modeling interdependencies between channels. + + Args: + channels: int. The number of input channels. + rd_ratio: float. The reduction ratio for the bottleneck channels. + rd_channels: int. The number of bottleneck channels. If specified, + `rd_ratio` is ignored. + rd_divisor: int. The divisor for rounding the number of bottleneck + channels. + add_maxpool: bool. If `True`, max pooling is used in addition to + average pooling for the squeeze operation. + bias: bool. 
If `True`, bias terms are used in the fully connected + layers. + act_layer: str. The activation function for the bottleneck layer. + norm_layer: str. The normalization layer to use. + data_format: str. The format of the input data, either + `"channels_last"` or `"channels_first"`. + channel_axis: int. The axis representing the channels in the input + tensor. + gate_layer: str. The gating activation function. + """ + + def __init__( + self, + channels, + rd_ratio=1.0 / 16, + rd_channels=None, + rd_divisor=8, + add_maxpool=False, + bias=True, + act_layer="relu", + norm_layer=None, + data_format=None, + channel_axis=None, + gate_layer="sigmoid", + dtype=None, + **kwargs, + ): + super().__init__(dtype=dtype, **kwargs) + self.channels = channels + self.add_maxpool = add_maxpool + if not rd_channels: + rd_channels = adjust_channels( + channels * rd_ratio, rd_divisor, round_limit=0.0 + ) + self.rd_ratio = rd_ratio + self.rd_channels = rd_channels + self.rd_divisor = rd_divisor + self.bias = bias + self.act_layer_arg = act_layer + self.kernel_initializer = keras.initializers.VarianceScaling( + scale=2.0, mode="fan_out", distribution="untruncated_normal" + ) + self.bias_initializer = "zeros" + self.norm_layer_arg = norm_layer + self.gate_layer_arg = gate_layer + self.data_format = data_format + self.channel_axis = channel_axis + self.mean_axis = [2, 3] if data_format == "channels_first" else [1, 2] + self.fc1 = keras.layers.Conv2D( + rd_channels, + kernel_size=1, + use_bias=bias, + name="fc1", + data_format=self.data_format, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + dtype=self.dtype_policy, + ) + self.bn = ( + keras.layers.BatchNormalization( + axis=channel_axis, dtype=self.dtype_policy + ) + if norm_layer + else (lambda x, training: x) + ) + self.act = keras.layers.Activation(act_layer, dtype=self.dtype_policy) + self.fc2 = keras.layers.Conv2D( + channels, + kernel_size=1, + use_bias=bias, + name="fc2", + data_format=self.data_format, + kernel_initializer=self.kernel_initializer, + bias_initializer=self.bias_initializer, + dtype=self.dtype_policy, + ) + self.gate = keras.layers.Activation(gate_layer, dtype=self.dtype_policy) + + def build(self, input_shape): + self.fc1.build(input_shape) + fc1_output_shape = self.fc1.compute_output_shape(input_shape) + if hasattr(self.bn, "build"): + self.bn.build(fc1_output_shape) + self.act.build(fc1_output_shape) + self.fc2.build(fc1_output_shape) + self.built = True + + def call(self, x, training=False): + x_se = keras.ops.mean(x, axis=self.mean_axis, keepdims=True) + if self.add_maxpool: + x_se = 0.5 * x_se + 0.5 * keras.ops.max( + x, axis=self.mean_axis, keepdims=True + ) + x_se = self.fc1(x_se) + x_se = self.bn(x_se, training=training) + x_se = self.act(x_se) + x_se = self.fc2(x_se) + return x * self.gate(x_se) + + def get_config(self): + config = super().get_config() + config.update( + { + "channels": self.channels, + "rd_ratio": self.rd_ratio, + "rd_channels": self.rd_channels, + "rd_divisor": self.rd_divisor, + "add_maxpool": self.add_maxpool, + "bias": self.bias, + "act_layer": self.act_layer_arg, + "norm_layer": self.norm_layer_arg, + "gate_layer": self.gate_layer_arg, + "data_format": self.data_format, + "channel_axis": self.channel_axis, + } + ) + return config diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py new file mode 100644 index 0000000000..3cc33c6c51 --- /dev/null +++ 
b/keras_hub/src/models/mobilenetv5/mobilenetv5_utils.py
@@ -0,0 +1,146 @@
+import keras
+
+from keras_hub.src.models.mobilenet.util import adjust_channels
+
+
+def num_groups(group_size, channels):
+    if not group_size:
+        return 1
+    else:
+        if channels % group_size != 0:
+            raise ValueError(
+                f"Number of channels ({channels}) must be divisible by "
+                f"group size ({group_size})."
+            )
+        return channels // group_size
+
+
+def parse_ksize(ss):
+    if ss.isdigit():
+        return int(ss)
+    else:
+        return [int(k) for k in ss.split(".")]
+
+
+def round_channels(
+    channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9
+):
+    if not multiplier:
+        return channels
+    return adjust_channels(channels * multiplier, divisor, channel_min)
+
+
+def feature_take_indices(num_stages, indices):
+    if not isinstance(indices, (tuple, list)):
+        indices = (indices,)
+    if any(i < 0 for i in indices):
+        indices = [i if i >= 0 else num_stages + i for i in indices]
+    return indices, max(indices)
+
+
+class SelectAdaptivePool2d(keras.layers.Layer):
+    """A layer that selects and applies a 2D adaptive pooling strategy.
+
+    This layer supports various pooling types like average, max, or a
+    combination of both. It can also flatten the output.
+
+    Args:
+        pool_type: str. The type of pooling to apply. One of `"avg"`, `"max"`,
+            `"avgmax"`, `"catavgmax"`, or `""` (identity).
+        flatten: bool. If `True`, the output is flattened after pooling.
+        data_format: str. The format of the input data, either
+            `"channels_last"` or `"channels_first"`.
+        channel_axis: int. The axis representing the channels in the input
+            tensor.
+    """
+
+    def __init__(
+        self,
+        pool_type="avg",
+        flatten=False,
+        data_format=None,
+        channel_axis=None,
+        dtype=None,
+        **kwargs,
+    ):
+        super().__init__(dtype=dtype, **kwargs)
+        self.pool_type = pool_type.lower()
+        self.flatten = flatten
+        self.data_format = data_format
+        self.channels_axis = channel_axis
+        self.pool = None
+        self.pool_avg = None
+        self.pool_max = None
+        self.pool_cat = None
+        self.flatten_layer = None
+        if self.pool_type not in ("avg", "max", "avgmax", "catavgmax", ""):
+            raise ValueError(f"Invalid pool type: {self.pool_type}")
+
+    def build(self, input_shape):
+        if self.pool_type == "avg":
+            self.pool = keras.layers.GlobalAveragePooling2D(
+                data_format=self.data_format,
+                keepdims=not self.flatten,
+                dtype=self.dtype_policy,
+            )
+        elif self.pool_type == "max":
+            self.pool = keras.layers.GlobalMaxPooling2D(
+                data_format=self.data_format,
+                keepdims=not self.flatten,
+                dtype=self.dtype_policy,
+            )
+        elif self.pool_type in ("avgmax", "catavgmax"):
+            self.pool_avg = keras.layers.GlobalAveragePooling2D(
+                data_format=self.data_format,
+                keepdims=not self.flatten,
+                dtype=self.dtype_policy,
+            )
+            self.pool_max = keras.layers.GlobalMaxPooling2D(
+                data_format=self.data_format,
+                keepdims=not self.flatten,
+                dtype=self.dtype_policy,
+            )
+            if self.pool_type == "catavgmax":
+                axis = 1 if self.data_format == "channels_first" else -1
+                self.pool_cat = keras.layers.Concatenate(
+                    axis=axis, dtype=self.dtype_policy
+                )
+        elif not self.pool_type:
+            self.pool = keras.layers.Identity(dtype=self.dtype_policy)
+            if self.flatten:
+                self.flatten_layer = keras.layers.Flatten(
+                    dtype=self.dtype_policy
+                )
+        super().build(input_shape)
+
+    def call(self, x):
+        if self.pool_type in ("avg", "max"):
+            return self.pool(x)
+        elif self.pool_type == "avgmax":
+            x_avg = self.pool_avg(x)
+            x_max = self.pool_max(x)
+            return 0.5 * (x_avg + x_max)
+        elif self.pool_type == "catavgmax":
+            x_avg = self.pool_avg(x)
+
x_max = self.pool_max(x) + return self.pool_cat([x_avg, x_max]) + elif not self.pool_type: + x = self.pool(x) + if self.flatten_layer: + x = self.flatten_layer(x) + return x + return x + + def feat_mult(self): + return 2 if self.pool_type == "catavgmax" else 1 + + def get_config(self): + config = super().get_config() + config.update( + { + "pool_type": self.pool_type, + "flatten": self.flatten, + "data_format": self.data_format, + } + ) + return config diff --git a/tools/checkpoint_conversion/convert_mobilenetv5_checkpoints.py b/tools/checkpoint_conversion/convert_mobilenetv5_checkpoints.py new file mode 100644 index 0000000000..a7364d5fd5 --- /dev/null +++ b/tools/checkpoint_conversion/convert_mobilenetv5_checkpoints.py @@ -0,0 +1,469 @@ +import os +import shutil +import types + +import keras +import numpy as np +import PIL +import timm +import torch +from absl import app +from absl import flags + +from keras_hub.src.models.mobilenetv5.mobilenetv5_attention import ( + MobileAttention, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import ( + MobileNetV5Backbone, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import EdgeResidual +from keras_hub.src.models.mobilenetv5.mobilenetv5_blocks import ( + UniversalInvertedResidual, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_builder import decode_arch_def +from keras_hub.src.models.mobilenetv5.mobilenetv5_image_classifier import ( + MobileNetV5ImageClassifier, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_image_classifier_preprocessor import ( # noqa: E501 + MobileNetV5ImageClassifierPreprocessor, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_image_converter import ( + MobileNetV5ImageConverter, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import ConvNormAct + +PRESET_MAP = { + "mobilenetv5_300m_enc.gemma3n": { + "arch": "mobilenetv5_300m_enc", + "hf_hub_id": "timm/mobilenetv5_300m.gemma3n", + } +} + +MODEL_CONFIGS = { + "mobilenetv5_300m_enc.gemma3n": { + "backbone": { + "block_args": decode_arch_def( + [ + [ + "er_r1_k3_s2_e4_c128", + "er_r1_k3_s1_e4_c128", + "er_r1_k3_s1_e4_c128", + ], + [ + "uir_r1_a3_k5_s2_e6_c256", + "uir_r1_a5_k0_s1_e4_c256", + "uir_r1_a3_k0_s1_e4_c256", + "uir_r1_a5_k0_s1_e4_c256", + "uir_r1_a3_k0_s1_e4_c256", + ], + [ + "uir_r1_a5_k5_s2_e6_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a5_k0_s1_e4_c640", + "uir_r1_a0_k0_s1_e1_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + "mqa_r1_k3_h12_v2_s1_d64_c640", + "uir_r1_a0_k0_s1_e2_c640", + ], + [ + "uir_r1_a5_k5_s2_e6_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + 
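+                    # Block strings follow the timm arch encoding, e.g.
+                    # "mqa_r1_k3_h16_s1_d96_c1280" decodes to a
+                    # multi-query attention block: repeat 1, kernel 3,
+                    # 16 heads, stride 1, key dim 96, 1280 channels.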
"mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", + ], + ] + ), + "stem_size": 64, + "num_features": 2048, + "norm_layer": "rms_norm", + "act_layer": "gelu", + "use_msfa": True, + "layer_scale_init_value": 1e-5, + }, + "classifier": { + "num_classes": 0, + }, + } +} + + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + "preset", + None, + "Must be a valid `MobileNetV5` preset.", + required=True, +) +flags.DEFINE_string( + "upload_uri", + None, + "Optional Kaggle URI to upload the converted model preset.", + required=False, +) + + +class TimmToKerasConverter: + def __init__(self, timm_model): + self.state_dict = { + k: v.cpu().numpy() for k, v in timm_model.state_dict().items() + } + + def convert(self, keras_model: MobileNetV5ImageClassifier): + print("🔶 Starting weight conversion...") + backbone = keras_model.backbone + self._port_stem(backbone) + self._port_blocks(backbone) + self._port_msfa(backbone) + print("✅ Backbone weights converted.") + + def _port_weights(self, layer, timm_key, transpose_dims=None): + if timm_key not in self.state_dict: + print(f"⚠️ Weight key not found in state_dict: {timm_key}") + return + weights = self.state_dict[timm_key] + if transpose_dims: + weights = weights.transpose(transpose_dims) + + current_weights = layer.get_weights() + if len(current_weights) == 1: + layer.set_weights([weights]) + elif len(current_weights) == 2: + bias_key = timm_key.replace(".weight", ".bias") + if bias_key in self.state_dict: + bias = self.state_dict[bias_key] + layer.set_weights([weights, bias]) + else: + layer.set_weights([weights, current_weights[1]]) + else: + print(f"❓ Unexpected number of weights in layer {layer.name}") + + def _port_bn(self, layer, timm_prefix): + weights = [ + self.state_dict[f"{timm_prefix}.weight"], + self.state_dict[f"{timm_prefix}.bias"], + self.state_dict[f"{timm_prefix}.running_mean"], + self.state_dict[f"{timm_prefix}.running_var"], + ] + layer.set_weights(weights) + + def _port_rms_norm(self, layer, timm_prefix): + layer.set_weights([self.state_dict[f"{timm_prefix}.weight"]]) + + def _port_cna( + self, cna_layer: ConvNormAct, timm_conv_prefix, timm_norm_prefix + ): + if isinstance(cna_layer.conv, keras.layers.DepthwiseConv2D): + self._port_weights( + cna_layer.conv, + f"{timm_conv_prefix}.weight", + transpose_dims=(2, 3, 0, 1), + ) + else: + self._port_weights( + cna_layer.conv, + f"{timm_conv_prefix}.weight", + transpose_dims=(2, 3, 1, 0), + ) + if f"{timm_norm_prefix}.running_mean" in self.state_dict: + self._port_bn(cna_layer.norm, timm_norm_prefix) + else: + 
self._port_rms_norm(cna_layer.norm, timm_norm_prefix) + + def _port_stem(self, backbone): + print(" -> Porting stem...") + stem_layer = backbone.get_layer("conv_stem") + self._port_cna(stem_layer, "conv_stem.conv", "conv_stem.bn") + + def _port_msfa(self, backbone): + print(" -> Porting MSFA...") + try: + msfa_layer = backbone.get_layer("msfa") + ffn = msfa_layer.ffn + self._port_cna( + ffn.pw_exp, "msfa.ffn.pw_exp.conv", "msfa.ffn.pw_exp.bn" + ) + self._port_cna( + ffn.pw_proj, "msfa.ffn.pw_proj.conv", "msfa.ffn.pw_proj.bn" + ) + self._port_rms_norm(msfa_layer.norm, "msfa.norm") + except ValueError: + print(" -> MSFA layer not found, skipping.") + + def _port_blocks(self, backbone): + print(" -> Porting blocks...") + stack_idx = 0 + while True: + try: + stack = backbone.get_layer(f"stack_{stack_idx}") + print(f" -> Stack {stack_idx}") + for block_idx, block in enumerate(stack.layers): + timm_prefix = f"blocks.{stack_idx}.{block_idx}" + if isinstance(block, EdgeResidual): + self._port_cna( + block.conv_exp, + f"{timm_prefix}.conv_exp", + f"{timm_prefix}.bn1", + ) + self._port_cna( + block.conv_pwl, + f"{timm_prefix}.conv_pwl", + f"{timm_prefix}.bn2", + ) + self._port_bn(block.conv_pwl.norm, f"{timm_prefix}.bn2") + elif isinstance(block, UniversalInvertedResidual): + if hasattr(block, "dw_start") and not isinstance( + block.dw_start, types.FunctionType + ): + self._port_cna( + block.dw_start, + f"{timm_prefix}.dw_start.conv", + f"{timm_prefix}.dw_start.bn", + ) + self._port_cna( + block.pw_exp, + f"{timm_prefix}.pw_exp.conv", + f"{timm_prefix}.pw_exp.bn", + ) + if hasattr(block, "dw_mid") and not isinstance( + block.dw_mid, types.FunctionType + ): + self._port_cna( + block.dw_mid, + f"{timm_prefix}.dw_mid.conv", + f"{timm_prefix}.dw_mid.bn", + ) + self._port_cna( + block.pw_proj, + f"{timm_prefix}.pw_proj.conv", + f"{timm_prefix}.pw_proj.bn", + ) + self._port_weights( + block.layer_scale, + f"{timm_prefix}.layer_scale.gamma", + ) + elif isinstance(block, MobileAttention): + self._port_rms_norm(block.norm, f"{timm_prefix}.norm") + self._port_weights( + block.layer_scale, + f"{timm_prefix}.layer_scale.gamma", + ) + attn_prefix = f"{timm_prefix}.attn" + self._port_attn(block.attn, attn_prefix) + + stack_idx += 1 + except ValueError: + break + + def _port_attn(self, attn_layer, attn_prefix): + self._port_weights( + attn_layer.query_layers[-1], + f"{attn_prefix}.query.proj.weight", + (2, 3, 1, 0), + ) + if len(attn_layer.key_layers) > 1: + self._port_weights( + attn_layer.key_layers[1], + f"{attn_prefix}.key.down_conv.weight", + (2, 3, 0, 1), + ) + self._port_bn(attn_layer.key_layers[2], f"{attn_prefix}.key.norm") + self._port_weights( + attn_layer.key_layers[-1], + f"{attn_prefix}.key.proj.weight", + (2, 3, 1, 0), + ) + if len(attn_layer.value_layers) > 1: + self._port_weights( + attn_layer.value_layers[1], + f"{attn_prefix}.value.down_conv.weight", + (2, 3, 0, 1), + ) + self._port_bn( + attn_layer.value_layers[2], f"{attn_prefix}.value.norm" + ) + self._port_weights( + attn_layer.value_layers[-1], + f"{attn_prefix}.value.proj.weight", + (2, 3, 1, 0), + ) + self._port_weights( + attn_layer.output_proj_layers[-2], + f"{attn_prefix}.output.proj.weight", + (2, 3, 1, 0), + ) + + +def validate_output(keras_model, timm_model): + file = keras.utils.get_file( + origin=( + "https://storage.googleapis.com/keras-cv/" + "models/paligemma/cow_beach_1.png" + ) + ) + image = PIL.Image.open(file) + batch = np.array([image]) + + # Preprocess with Timm. 
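+    # Resolve the transform from the pretrained config so the reference
+    # preprocessing matches what the checkpoint was trained with.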
+ data_config = timm.data.resolve_model_data_config(timm_model) + data_config["crop_pct"] = 1.0 # Stop timm from cropping. + transforms = timm.data.create_transform(**data_config, is_training=False) + timm_preprocessed = transforms(image) + timm_preprocessed = keras.ops.transpose(timm_preprocessed, axes=(1, 2, 0)) + timm_preprocessed = keras.ops.expand_dims(timm_preprocessed, 0) + + # Preprocess with Keras. + batch = keras.ops.cast(batch, "float32") + keras_preprocessed = keras_model.preprocessor(batch) + + # Call with Timm. Use the keras preprocessed image so we can keep modeling + # and preprocessing comparisons independent. + timm_batch = keras.ops.transpose(keras_preprocessed, axes=(0, 3, 1, 2)) + timm_batch = torch.from_numpy(np.array(timm_batch)) + timm_outputs = timm_model(timm_batch).detach().numpy() + + # Call with Keras. + keras_outputs = keras_model.predict(batch) + keras_label = np.argmax(keras_outputs[0]) + + # Apply global average pooling to the Timm output to match Keras's output + timm_outputs_pooled = np.mean(timm_outputs, axis=(2, 3)) + + print("🔶 Keras output:", keras_outputs[0, :10]) + print("🔶 TIMM output (pooled):", timm_outputs_pooled[0, :10]) + print("🔶 Keras label:", keras_label) + modeling_diff = np.mean(np.abs(keras_outputs - timm_outputs_pooled)) + print("🔶 Modeling difference:", modeling_diff) + preprocessing_diff = np.mean( + np.abs(np.array(keras_preprocessed) - np.array(timm_preprocessed)) + ) + print("🔶 Preprocessing difference:", preprocessing_diff) + + +def main(_): + preset = FLAGS.preset + if os.path.exists(preset): + shutil.rmtree(preset) + os.makedirs(preset) + timm_config = PRESET_MAP[preset] + timm_arch = timm_config["arch"] + hf_hub_id = timm_config["hf_hub_id"] + print(f"✅ Loading TIMM model: {timm_arch} from {hf_hub_id}") + timm_model = timm.create_model( + timm_arch, + pretrained=True, + pretrained_cfg_overlay=dict(hf_hub_id=hf_hub_id), + ) + timm_model = timm_model.eval() + print("✅ Creating Keras model.") + config = MODEL_CONFIGS[preset] + backbone = MobileNetV5Backbone(**config["backbone"]) + pretrained_cfg = timm_model.pretrained_cfg + image_size = ( + pretrained_cfg["input_size"][1], + pretrained_cfg["input_size"][2], + ) + mean = pretrained_cfg["mean"] + std = pretrained_cfg["std"] + interpolation = pretrained_cfg["interpolation"] + scale = [1 / (255.0 * s) for s in std] if std else 1 / 255.0 + offset = [-m / s for m, s in zip(mean, std)] if mean and std else 0.0 + image_converter = MobileNetV5ImageConverter( + image_size=image_size, + scale=scale, + offset=offset, + interpolation=interpolation, + antialias=True if interpolation == "bicubic" else False, + ) + preprocessor = MobileNetV5ImageClassifierPreprocessor( + image_converter=image_converter + ) + keras_model = MobileNetV5ImageClassifier( + backbone=backbone, + preprocessor=preprocessor, + **config["classifier"], + ) + converter = TimmToKerasConverter(timm_model) + converter.convert(keras_model) + validate_output(keras_model, timm_model) + keras_model.save_to_preset(f"./{preset}") + print(f"🏁 Preset saved to ./{preset}") + upload_uri = FLAGS.upload_uri + if upload_uri: + try: + import keras_hub + + keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}") + print(f"🏁 Preset uploaded to {upload_uri}") + except ImportError: + print("❗ `keras-hub` is not installed. 
Skipping upload.") + except Exception as e: + print(f"❗ An error occurred during upload: {e}") + + +if __name__ == "__main__": + flags.mark_flag_as_required("preset") + app.run(main) From 9df2587c8a9769ebe20137966ac84421aa935d5f Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Sat, 13 Sep 2025 14:28:48 +0400 Subject: [PATCH 2/5] nit: I like the CI green, ig? --- keras_hub/api/layers/__init__.py | 3 +++ keras_hub/api/models/__init__.py | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index f90c214d6b..34360763d2 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -105,6 +105,9 @@ from keras_hub.src.models.mobilenet.mobilenet_image_converter import ( MobileNetImageConverter as MobileNetImageConverter, ) +from keras_hub.src.models.mobilenetv5.mobilenetv5_image_converter import ( + MobileNetV5ImageConverter as MobileNetV5ImageConverter, +) from keras_hub.src.models.moonshine.moonshine_audio_converter import ( MoonshineAudioConverter as MoonshineAudioConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index fe220e2d43..3bc2abeb0e 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -413,6 +413,15 @@ from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( MobileNetImageClassifierPreprocessor as MobileNetImageClassifierPreprocessor, ) +from keras_hub.src.models.mobilenetv5.mobilenetv5_backbone import ( + MobileNetV5Backbone as MobileNetV5Backbone, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_image_classifier import ( + MobileNetV5ImageClassifier as MobileNetV5ImageClassifier, +) +from keras_hub.src.models.mobilenetv5.mobilenetv5_image_classifier_preprocessor import ( + MobileNetV5ImageClassifierPreprocessor as MobileNetV5ImageClassifierPreprocessor, +) from keras_hub.src.models.moonshine.moonshine_audio_to_text import ( MoonshineAudioToText as MoonshineAudioToText, ) From dd234debcb262de5d5bfe838a8ca0aca695ec410 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Wed, 24 Sep 2025 23:14:26 +0530 Subject: [PATCH 3/5] fix: Bring the numerics significantly closer --- .../mobilenetv5/mobilenetv5_attention.py | 49 ++++-- .../models/mobilenetv5/mobilenetv5_builder.py | 5 +- .../convert_mobilenetv5_checkpoints.py | 156 ++++++++++-------- 3 files changed, 132 insertions(+), 78 deletions(-) diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py index f46a188ecb..5f83a6f626 100644 --- a/keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py @@ -91,15 +91,23 @@ def __init__( dtype=self.dtype_policy, ) ) - query_layers.append( - norm_layer( + if norm_layer is RmsNorm2d: + norm = norm_layer( + dim=dim, + channel_axis=self.channel_axis, + data_format=self.data_format, + name="query_norm", + dtype=self.dtype_policy, + ) + else: + norm = norm_layer( axis=self.channel_axis, name="query_norm", gamma_initializer="ones", beta_initializer="zeros", dtype=self.dtype_policy, ) - ) + query_layers.append(norm) query_layers.append( keras.layers.Conv2D( filters=self.num_heads * self.key_dim, @@ -138,15 +146,23 @@ def __init__( dtype=self.dtype_policy, ) ) - key_layers.append( - norm_layer( + if norm_layer is RmsNorm2d: + norm = norm_layer( + dim=dim, + channel_axis=self.channel_axis, + data_format=self.data_format, + name="key_norm", + 
dtype=self.dtype_policy, + ) + else: + norm = norm_layer( axis=self.channel_axis, gamma_initializer="ones", beta_initializer="zeros", name="key_norm", dtype=self.dtype_policy, ) - ) + key_layers.append(norm) key_layers.append( keras.layers.Conv2D( filters=self.key_dim, @@ -186,15 +202,23 @@ def __init__( dtype=self.dtype_policy, ) ) - value_layers.append( - norm_layer( + if norm_layer is RmsNorm2d: + norm = norm_layer( + dim=dim, + channel_axis=self.channel_axis, + data_format=self.data_format, + name="value_norm", + dtype=self.dtype_policy, + ) + else: + norm = norm_layer( axis=self.channel_axis, gamma_initializer="ones", beta_initializer="zeros", name="value_norm", dtype=self.dtype_policy, ) - ) + value_layers.append(norm) value_layers.append( keras.layers.Conv2D( filters=self.value_dim, @@ -512,6 +536,11 @@ def __init__( if num_heads is None: assert in_chs % key_dim == 0 num_heads = in_chs // key_dim + attn_norm_layer = ( + RmsNorm2d + if norm_layer == "rms_norm" + else keras.layers.BatchNormalization + ) if use_multi_query: self.attn = MultiQueryAttention2d( dim=in_chs, @@ -526,7 +555,7 @@ def __init__( padding=pad_type, attn_drop=attn_drop, proj_drop=proj_drop, - norm_layer=keras.layers.BatchNormalization, + norm_layer=attn_norm_layer, use_bias=use_bias, channel_axis=self.channel_axis, data_format=self.data_format, diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py index eb4d7034e0..d7bbe6f163 100644 --- a/keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py @@ -75,10 +75,13 @@ def decode_block_str(block_str): ) ) elif block_type in ("mqa", "mha"): + key_dim_val = int(options.get("d", "64")) block_args.update( dict( num_heads=int(options.get("h", "12")), - key_dim=int(options.get("d", "64")), + key_dim=key_dim_val, + value_dim=key_dim_val, + kv_stride=int(options.get("v", "1")), use_cpe=bool(int(options.get("cpe", "0"))), ) ) diff --git a/tools/checkpoint_conversion/convert_mobilenetv5_checkpoints.py b/tools/checkpoint_conversion/convert_mobilenetv5_checkpoints.py index a7364d5fd5..479eb60ba2 100644 --- a/tools/checkpoint_conversion/convert_mobilenetv5_checkpoints.py +++ b/tools/checkpoint_conversion/convert_mobilenetv5_checkpoints.py @@ -31,6 +31,7 @@ MobileNetV5ImageConverter, ) from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import ConvNormAct +from keras_hub.src.models.mobilenetv5.mobilenetv5_layers import RmsNorm2d PRESET_MAP = { "mobilenetv5_300m_enc.gemma3n": { @@ -44,11 +45,13 @@ "backbone": { "block_args": decode_arch_def( [ + # Stage 0: 128x128 in [ "er_r1_k3_s2_e4_c128", "er_r1_k3_s1_e4_c128", "er_r1_k3_s1_e4_c128", ], + # Stage 1: 256x256 in [ "uir_r1_a3_k5_s2_e6_c256", "uir_r1_a5_k0_s1_e4_c256", @@ -56,6 +59,7 @@ "uir_r1_a5_k0_s1_e4_c256", "uir_r1_a3_k0_s1_e4_c256", ], + # Stage 2: 640x640 in [ "uir_r1_a5_k5_s2_e6_c640", "uir_r1_a5_k0_s1_e4_c640", @@ -95,6 +99,7 @@ "mqa_r1_k3_h12_v2_s1_d64_c640", "uir_r1_a0_k0_s1_e2_c640", ], + # Stage 3: 1280x1280 in [ "uir_r1_a5_k5_s2_e6_c1280", "mqa_r1_k3_h16_s1_d96_c1280", @@ -133,6 +138,8 @@ "uir_r1_a0_k0_s1_e2_c1280", "mqa_r1_k3_h16_s1_d96_c1280", "uir_r1_a0_k0_s1_e2_c1280", + "mqa_r1_k3_h16_s1_d96_c1280", + "uir_r1_a0_k0_s1_e2_c1280", ], ] ), @@ -202,16 +209,18 @@ def _port_weights(self, layer, timm_key, transpose_dims=None): print(f"❓ Unexpected number of weights in layer {layer.name}") def _port_bn(self, layer, timm_prefix): - weights = [ - self.state_dict[f"{timm_prefix}.weight"], 
- self.state_dict[f"{timm_prefix}.bias"], - self.state_dict[f"{timm_prefix}.running_mean"], - self.state_dict[f"{timm_prefix}.running_var"], + keys = [ + f"{timm_prefix}.weight", + f"{timm_prefix}.bias", + f"{timm_prefix}.running_mean", + f"{timm_prefix}.running_var", ] + weights = [self.state_dict[key] for key in keys] layer.set_weights(weights) def _port_rms_norm(self, layer, timm_prefix): - layer.set_weights([self.state_dict[f"{timm_prefix}.weight"]]) + key = f"{timm_prefix}.weight" + layer.set_weights([self.state_dict[key]]) def _port_cna( self, cna_layer: ConvNormAct, timm_conv_prefix, timm_norm_prefix @@ -253,70 +262,75 @@ def _port_msfa(self, backbone): except ValueError: print(" -> MSFA layer not found, skipping.") - def _port_blocks(self, backbone): + def _port_blocks(self, backbone: MobileNetV5Backbone): print(" -> Porting blocks...") - stack_idx = 0 - while True: - try: - stack = backbone.get_layer(f"stack_{stack_idx}") - print(f" -> Stack {stack_idx}") - for block_idx, block in enumerate(stack.layers): - timm_prefix = f"blocks.{stack_idx}.{block_idx}" - if isinstance(block, EdgeResidual): - self._port_cna( - block.conv_exp, - f"{timm_prefix}.conv_exp", - f"{timm_prefix}.bn1", - ) - self._port_cna( - block.conv_pwl, - f"{timm_prefix}.conv_pwl", - f"{timm_prefix}.bn2", - ) - self._port_bn(block.conv_pwl.norm, f"{timm_prefix}.bn2") - elif isinstance(block, UniversalInvertedResidual): - if hasattr(block, "dw_start") and not isinstance( - block.dw_start, types.FunctionType - ): - self._port_cna( - block.dw_start, - f"{timm_prefix}.dw_start.conv", - f"{timm_prefix}.dw_start.bn", - ) + block_layers = [ + layer + for layer in backbone.layers + if isinstance( + layer, + (EdgeResidual, UniversalInvertedResidual, MobileAttention), + ) + ] + block_counter = 0 + for stack_idx, stack_args in enumerate(backbone.block_args): + print(f" -> Stack {stack_idx}") + for block_idx_in_stage in range(len(stack_args)): + block = block_layers[block_counter] + timm_prefix = f"blocks.{stack_idx}.{block_idx_in_stage}" + if isinstance(block, EdgeResidual): + self._port_cna( + block.conv_exp, + f"{timm_prefix}.conv_exp", + f"{timm_prefix}.bn1", + ) + self._port_cna( + block.conv_pwl, + f"{timm_prefix}.conv_pwl", + f"{timm_prefix}.bn2", + ) + elif isinstance(block, UniversalInvertedResidual): + if hasattr(block, "dw_start") and not isinstance( + block.dw_start, types.FunctionType + ): self._port_cna( - block.pw_exp, - f"{timm_prefix}.pw_exp.conv", - f"{timm_prefix}.pw_exp.bn", + block.dw_start, + f"{timm_prefix}.dw_start.conv", + f"{timm_prefix}.dw_start.bn", ) - if hasattr(block, "dw_mid") and not isinstance( - block.dw_mid, types.FunctionType - ): - self._port_cna( - block.dw_mid, - f"{timm_prefix}.dw_mid.conv", - f"{timm_prefix}.dw_mid.bn", - ) + self._port_cna( + block.pw_exp, + f"{timm_prefix}.pw_exp.conv", + f"{timm_prefix}.pw_exp.bn", + ) + if hasattr(block, "dw_mid") and not isinstance( + block.dw_mid, types.FunctionType + ): self._port_cna( - block.pw_proj, - f"{timm_prefix}.pw_proj.conv", - f"{timm_prefix}.pw_proj.bn", + block.dw_mid, + f"{timm_prefix}.dw_mid.conv", + f"{timm_prefix}.dw_mid.bn", ) - self._port_weights( - block.layer_scale, - f"{timm_prefix}.layer_scale.gamma", + self._port_cna( + block.pw_proj, + f"{timm_prefix}.pw_proj.conv", + f"{timm_prefix}.pw_proj.bn", + ) + gamma_key = f"{timm_prefix}.layer_scale.gamma" + if gamma_key in self.state_dict: + block.layer_scale.set_weights( + [self.state_dict[gamma_key]] ) - elif isinstance(block, MobileAttention): - self._port_rms_norm(block.norm, 
f"{timm_prefix}.norm") - self._port_weights( - block.layer_scale, - f"{timm_prefix}.layer_scale.gamma", + elif isinstance(block, MobileAttention): + self._port_rms_norm(block.norm, f"{timm_prefix}.norm") + gamma_key = f"{timm_prefix}.layer_scale.gamma" + if gamma_key in self.state_dict: + block.layer_scale.set_weights( + [self.state_dict[gamma_key]] ) - attn_prefix = f"{timm_prefix}.attn" - self._port_attn(block.attn, attn_prefix) - - stack_idx += 1 - except ValueError: - break + attn_prefix = f"{timm_prefix}.attn" + self._port_attn(block.attn, attn_prefix) + block_counter += 1 def _port_attn(self, attn_layer, attn_prefix): self._port_weights( @@ -330,7 +344,11 @@ def _port_attn(self, attn_layer, attn_prefix): f"{attn_prefix}.key.down_conv.weight", (2, 3, 0, 1), ) - self._port_bn(attn_layer.key_layers[2], f"{attn_prefix}.key.norm") + key_norm_layer = attn_layer.key_layers[2] + if isinstance(key_norm_layer, RmsNorm2d): + self._port_rms_norm(key_norm_layer, f"{attn_prefix}.key.norm") + else: + self._port_bn(key_norm_layer, f"{attn_prefix}.key.norm") self._port_weights( attn_layer.key_layers[-1], f"{attn_prefix}.key.proj.weight", @@ -342,9 +360,13 @@ def _port_attn(self, attn_layer, attn_prefix): f"{attn_prefix}.value.down_conv.weight", (2, 3, 0, 1), ) - self._port_bn( - attn_layer.value_layers[2], f"{attn_prefix}.value.norm" - ) + value_norm_layer = attn_layer.value_layers[2] + if isinstance(value_norm_layer, RmsNorm2d): + self._port_rms_norm( + value_norm_layer, f"{attn_prefix}.value.norm" + ) + else: + self._port_bn(value_norm_layer, f"{attn_prefix}.value.norm") self._port_weights( attn_layer.value_layers[-1], f"{attn_prefix}.value.proj.weight", From decb33723b64a61dc388739102f5af5f40f160c4 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Thu, 25 Sep 2025 12:22:03 +0530 Subject: [PATCH 4/5] =?UTF-8?q?=E2=9C=85=20Yayy,=201e-5=20is=20insanely=20?= =?UTF-8?q?good!?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../mobilenetv5/mobilenetv5_attention.py | 26 ++----- .../models/mobilenetv5/mobilenetv5_blocks.py | 69 ++++++++++++------- .../models/mobilenetv5/mobilenetv5_layers.py | 30 ++------ .../convert_mobilenetv5_checkpoints.py | 12 ++-- 4 files changed, 56 insertions(+), 81 deletions(-) diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py index 5f83a6f626..ea0e6706fc 100644 --- a/keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py @@ -73,14 +73,14 @@ def __init__( self.kv_stride = kv_stride self.has_query_strides = any([s > 1 for s in self.query_strides]) self.scale = self.key_dim**-0.5 - self.keras_padding = "same" if padding == "" else "valid" + self.padding = "same" if padding == "" else "valid" self.conv_kernel_initializer = keras.initializers.VarianceScaling( scale=2.0, mode="fan_out", distribution="untruncated_normal" ) self.bias_initializer = "zeros" query_layers = [] if self.has_query_strides: - pool_padding = "valid" if self.keras_padding == "valid" else "same" + pool_padding = "valid" if self.padding == "valid" else "same" query_layers.append( keras.layers.AveragePooling2D( pool_size=self.query_strides, @@ -123,21 +123,12 @@ def __init__( self.query_layers = query_layers key_layers = [] if kv_stride > 1: - if self.keras_padding == "same": - key_layers.append( - keras.layers.ZeroPadding2D( - padding=dw_kernel_size // 2, - data_format=self.data_format, - name="key_down_pad", - 
dtype=self.dtype_policy, - ) - ) key_layers.append( keras.layers.DepthwiseConv2D( kernel_size=dw_kernel_size, strides=kv_stride, dilation_rate=dilation, - padding="valid", + padding=self.padding, data_format=self.data_format, name="key_down_conv", depthwise_initializer=self.conv_kernel_initializer, @@ -179,21 +170,12 @@ def __init__( self.key_layers = key_layers value_layers = [] if kv_stride > 1: - if self.keras_padding == "same": - value_layers.append( - keras.layers.ZeroPadding2D( - padding=dw_kernel_size // 2, - data_format=self.data_format, - name="value_down_pad", - dtype=self.dtype_policy, - ) - ) value_layers.append( keras.layers.DepthwiseConv2D( kernel_size=dw_kernel_size, strides=kv_stride, dilation_rate=dilation, - padding="valid", + padding=self.padding, data_format=self.data_format, name="value_down_conv", depthwise_initializer=self.conv_kernel_initializer, diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py index 72d76c72d6..f026bb0564 100644 --- a/keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py @@ -69,7 +69,7 @@ def __init__( self.out_chs = out_chs self.data_format = data_format self.channel_axis = channel_axis - keras_pad_type = "same" if pad_type == "" else "valid" + pad_type = "same" if pad_type == "" else "valid" use_bias = norm_layer == "rms_norm" if dw_kernel_size_start: @@ -79,7 +79,7 @@ def __init__( stride=stride if not dw_kernel_size_mid else 1, dilation=dilation, groups=in_chs, - pad_type=keras_pad_type, + pad_type=pad_type, apply_act=False, act_layer=act_layer, norm_layer=norm_layer, @@ -95,7 +95,7 @@ def __init__( self.pw_exp = ConvNormAct( mid_chs, 1, - pad_type=keras_pad_type, + pad_type=pad_type, act_layer=act_layer, norm_layer=norm_layer, bias=use_bias, @@ -111,7 +111,7 @@ def __init__( stride=stride, dilation=dilation, groups=mid_chs, - pad_type=keras_pad_type, + pad_type=pad_type, act_layer=act_layer, norm_layer=norm_layer, bias=use_bias, @@ -137,7 +137,7 @@ def __init__( self.pw_proj = ConvNormAct( out_chs, 1, - pad_type=keras_pad_type, + pad_type=pad_type, apply_act=False, act_layer=act_layer, norm_layer=norm_layer, @@ -156,7 +156,7 @@ def __init__( else 1, dilation=dilation, groups=out_chs, - pad_type=keras_pad_type, + pad_type=pad_type, apply_act=False, act_layer=act_layer, norm_layer=norm_layer, @@ -289,7 +289,7 @@ def __init__( self.out_chs = out_chs self.data_format = data_format self.channel_axis = channel_axis - keras_pad_type = "same" if pad_type == "" else "valid" + pad_type = "same" if pad_type == "" else "valid" if force_in_chs > 0: mid_chs = adjust_channels(force_in_chs * exp_ratio) else: @@ -302,7 +302,7 @@ def __init__( stride=stride, dilation=dilation, groups=groups, - pad_type=keras_pad_type, + pad_type=pad_type, norm_layer=norm_layer, act_layer=act_layer, bias=use_bias, @@ -326,7 +326,7 @@ def __init__( self.conv_pwl = ConvNormAct( out_chs, pw_kernel_size, - pad_type=keras_pad_type, + pad_type=pad_type, apply_act=False, norm_layer=norm_layer, act_layer=act_layer, @@ -420,7 +420,7 @@ def __init__( ) self.bias_initializer = "zeros" mid_chs = adjust_channels(in_chs * exp_ratio) - keras_pad_type = "same" if pad_type == "" else "valid" + pad_type = "same" if pad_type == "" else "valid" self.routing_fn = keras.layers.Dense( self.num_experts, dtype=self.dtype_policy, @@ -434,7 +434,7 @@ def __init__( keras.layers.Conv2D( filters=mid_chs, kernel_size=exp_kernel_size, - padding=keras_pad_type, + 
padding=pad_type, use_bias=True, data_format=self.data_format, name=f"conv_pw_expert_{i}", @@ -448,7 +448,7 @@ def __init__( keras.layers.DepthwiseConv2D( kernel_size=dw_kernel_size, strides=stride, - padding=keras_pad_type, + padding=pad_type, dilation_rate=dilation, use_bias=True, data_format=self.data_format, @@ -463,7 +463,7 @@ def __init__( keras.layers.Conv2D( filters=out_chs, kernel_size=pw_kernel_size, - padding=keras_pad_type, + padding=pad_type, use_bias=True, data_format=self.data_format, name=f"conv_pwl_expert_{i}", @@ -682,19 +682,36 @@ def call(self, inputs, training=False): resized_inputs, axis=self.channel_axis ) img = self.ffn(channel_cat_imgs, training=training) - if self.data_format == "channels_first": - img_transposed = keras.ops.transpose(img, (0, 2, 3, 1)) - else: - img_transposed = img - img_resized = keras.ops.image.resize( - img_transposed, - size=self.output_resolution, - interpolation="bilinear", - ) - if self.data_format == "channels_first": - img = keras.ops.transpose(img_resized, (0, 3, 1, 2)) - else: - img = img_resized + if ( + high_resolution[0] != self.output_resolution[0] + or high_resolution[1] != self.output_resolution[1] + ): + h_in, w_in = high_resolution + h_out, w_out = self.output_resolution + if h_in % h_out == 0 and w_in % w_out == 0: + h_stride = h_in // h_out + w_stride = w_in // w_out + img = keras.ops.nn.average_pool( + img, + pool_size=(h_stride, w_stride), + strides=(h_stride, w_stride), + padding="valid", + data_format=self.data_format, + ) + else: + if self.data_format == "channels_first": + img_transposed = keras.ops.transpose(img, (0, 2, 3, 1)) + else: + img_transposed = img + img_resized = keras.ops.image.resize( + img_transposed, + size=self.output_resolution, + interpolation="bilinear", + ) + if self.data_format == "channels_first": + img = keras.ops.transpose(img_resized, (0, 3, 1, 2)) + else: + img = img_resized img = self.norm(img, training=training) return img diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py index b21f08e432..18fef3f285 100644 --- a/keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py @@ -220,21 +220,13 @@ def __init__( super().__init__(dtype=dtype, **kwargs) self.channel_axis = channel_axis self.data_format = data_format - self.pad = None self.kernel_initializer = keras.initializers.VarianceScaling( scale=2.0, mode="fan_out", distribution="untruncated_normal" ) self.bias_initializer = "zeros" padding_mode = "valid" if pad_type.lower() == "" or pad_type.lower() == "same": - if stride > 1: - self.pad = keras.layers.ZeroPadding2D( - padding=kernel_size // 2, - data_format=self.data_format, - dtype=self.dtype_policy, - ) - else: - padding_mode = "same" + padding_mode = "same" self.conv = keras.layers.Conv2D( out_chs, @@ -279,7 +271,7 @@ def __init__( if self.apply_act: if act_layer == "gelu": self.act = keras.layers.Activation( - lambda x: keras.activations.gelu(x, approximate=True), + lambda x: keras.activations.gelu(x, approximate=False), dtype=self.dtype_policy, ) else: @@ -289,22 +281,14 @@ def __init__( ) def build(self, input_shape): - if self.pad: - self.pad.build(input_shape) - conv_input_shape = self.pad.compute_output_shape(input_shape) - else: - conv_input_shape = input_shape - - self.conv.build(conv_input_shape) - conv_output_shape = self.conv.compute_output_shape(conv_input_shape) + self.conv.build(input_shape) + conv_output_shape = 
self.conv.compute_output_shape(input_shape) self.norm.build(conv_output_shape) if self.apply_act: self.act.build(conv_output_shape) self.built = True def call(self, x, training=False): - if self.pad: - x = self.pad(x) x = self.conv(x) x = self.norm(x, training=training) if self.apply_act: @@ -312,11 +296,7 @@ def call(self, x, training=False): return x def compute_output_shape(self, input_shape): - if self.pad: - padded_shape = self.pad.compute_output_shape(input_shape) - return self.conv.compute_output_shape(padded_shape) - else: - return self.conv.compute_output_shape(input_shape) + return self.conv.compute_output_shape(input_shape) class SEModule(keras.layers.Layer): diff --git a/tools/checkpoint_conversion/convert_mobilenetv5_checkpoints.py b/tools/checkpoint_conversion/convert_mobilenetv5_checkpoints.py index 479eb60ba2..cbef6a24cf 100644 --- a/tools/checkpoint_conversion/convert_mobilenetv5_checkpoints.py +++ b/tools/checkpoint_conversion/convert_mobilenetv5_checkpoints.py @@ -340,11 +340,11 @@ def _port_attn(self, attn_layer, attn_prefix): ) if len(attn_layer.key_layers) > 1: self._port_weights( - attn_layer.key_layers[1], + attn_layer.key_layers[0], f"{attn_prefix}.key.down_conv.weight", (2, 3, 0, 1), ) - key_norm_layer = attn_layer.key_layers[2] + key_norm_layer = attn_layer.key_layers[1] if isinstance(key_norm_layer, RmsNorm2d): self._port_rms_norm(key_norm_layer, f"{attn_prefix}.key.norm") else: @@ -356,11 +356,11 @@ def _port_attn(self, attn_layer, attn_prefix): ) if len(attn_layer.value_layers) > 1: self._port_weights( - attn_layer.value_layers[1], + attn_layer.value_layers[0], f"{attn_prefix}.value.down_conv.weight", (2, 3, 0, 1), ) - value_norm_layer = attn_layer.value_layers[2] + value_norm_layer = attn_layer.value_layers[1] if isinstance(value_norm_layer, RmsNorm2d): self._port_rms_norm( value_norm_layer, f"{attn_prefix}.value.norm" @@ -419,10 +419,6 @@ def validate_output(keras_model, timm_model): print("🔶 Keras label:", keras_label) modeling_diff = np.mean(np.abs(keras_outputs - timm_outputs_pooled)) print("🔶 Modeling difference:", modeling_diff) - preprocessing_diff = np.mean( - np.abs(np.array(keras_preprocessed) - np.array(timm_preprocessed)) - ) - print("🔶 Preprocessing difference:", preprocessing_diff) def main(_): From 376e7f9300ef1b30e04895df977e3004e21849f3 Mon Sep 17 00:00:00 2001 From: harshaljanjani Date: Tue, 30 Sep 2025 16:17:24 +0530 Subject: [PATCH 5/5] chore: Address Gemini reviews --- .../mobilenetv5/mobilenetv5_attention.py | 149 +++++++++++-- .../models/mobilenetv5/mobilenetv5_blocks.py | 199 ++++++++++++++++-- .../models/mobilenetv5/mobilenetv5_builder.py | 18 +- .../models/mobilenetv5/mobilenetv5_layers.py | 33 ++- 4 files changed, 360 insertions(+), 39 deletions(-) diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py index ea0e6706fc..a51dc507da 100644 --- a/keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_attention.py @@ -59,10 +59,23 @@ def __init__( **kwargs, ): super().__init__(dtype=dtype, **kwargs) - self.data_format = data_format - self.channel_axis = channel_axis - dim_out = dim_out or dim + self.dim = dim + self.dim_out_arg = dim_out self.num_heads = num_heads + self.key_dim_arg = key_dim + self.value_dim_arg = value_dim + self.query_strides_arg = query_strides + self.kv_stride = kv_stride + self.dw_kernel_size = dw_kernel_size + self.dilation = dilation + self.padding_arg = padding + self.attn_drop_rate 
= attn_drop + self.proj_drop_rate = proj_drop + self.norm_layer = norm_layer + self.use_bias = use_bias + self.channel_axis = channel_axis + self.data_format = data_format + self.dim_out = dim_out or dim self.key_dim = key_dim or dim // num_heads self.value_dim = value_dim or dim // num_heads self.query_strides = ( @@ -70,7 +83,6 @@ def __init__( if isinstance(query_strides, (list, tuple)) else (query_strides, query_strides) ) - self.kv_stride = kv_stride self.has_query_strides = any([s > 1 for s in self.query_strides]) self.scale = self.key_dim**-0.5 self.padding = "same" if padding == "" else "valid" @@ -215,7 +227,7 @@ def __init__( ) ) self.value_layers = value_layers - self.attn_drop = keras.layers.Dropout( + self.attn_drop_layer = keras.layers.Dropout( attn_drop, dtype=self.dtype_policy ) output_layers = [] @@ -231,7 +243,7 @@ def __init__( ) output_layers.append( keras.layers.Conv2D( - filters=dim_out, + filters=self.dim_out, kernel_size=1, use_bias=use_bias, data_format=self.data_format, @@ -283,7 +295,7 @@ def call(self, x, training=False): q = q * self.scale attn = keras.ops.matmul(q, keras.ops.transpose(k, (0, 1, 3, 2))) attn = keras.ops.softmax(attn, axis=-1) - attn = self.attn_drop(attn, training=training) + attn = self.attn_drop_layer(attn, training=training) o = keras.ops.matmul(attn, v) o = keras.ops.transpose(o, (0, 2, 1, 3)) feat_dim = self.num_heads * self.value_dim @@ -298,6 +310,39 @@ def call(self, x, training=False): x_out = layer(x_out) return x_out + def get_config(self): + config = super().get_config() + config.update( + { + "dim": self.dim, + "dim_out": self.dim_out_arg, + "num_heads": self.num_heads, + "key_dim": self.key_dim_arg, + "value_dim": self.value_dim_arg, + "query_strides": self.query_strides_arg, + "kv_stride": self.kv_stride, + "dw_kernel_size": self.dw_kernel_size, + "dilation": self.dilation, + "padding": self.padding_arg, + "attn_drop": self.attn_drop_rate, + "proj_drop": self.proj_drop_rate, + "norm_layer": keras.saving.serialize_keras_object( + self.norm_layer + ), + "use_bias": self.use_bias, + "channel_axis": self.channel_axis, + "data_format": self.data_format, + } + ) + return config + + @classmethod + def from_config(cls, config): + config["norm_layer"] = keras.saving.deserialize_keras_object( + config["norm_layer"] + ) + return cls(**config) + class Attention2d(keras.layers.Layer): """Implements 2D Multi-Head Attention. 
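Note: the `get_config()`/`from_config()` pair added to `MultiQueryAttention2d`
above returns the raw constructor arguments and round-trips the class-valued
`norm_layer` through `keras.saving.serialize_keras_object`. A minimal sketch of
the round trip this enables (argument values are illustrative, not taken from
any preset in this PR):

    from keras_hub.src.models.mobilenetv5.mobilenetv5_attention import (
        MultiQueryAttention2d,
    )

    # Build a layer, dump its config, and rebuild it. The class-valued
    # `norm_layer` travels through the config as a serialized class
    # reference, so `from_config` deserializes it back into a class before
    # calling the constructor.
    layer = MultiQueryAttention2d(
        dim=64,
        num_heads=4,
        kv_stride=2,
        data_format="channels_last",
        channel_axis=-1,
    )
    config = layer.get_config()
    clone = MultiQueryAttention2d.from_config(config)
    assert clone.kv_stride == 2
    assert clone.dim == 64
    # clone.norm_layer should again be the BatchNormalization class after
    # deserialization.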
@@ -333,13 +378,15 @@ def __init__( **kwargs, ): super().__init__(dtype=dtype, **kwargs) - self.data_format = data_format - self.channel_axis = channel_axis - dim_out = dim_out or dim self.dim = dim - self.dim_out = dim_out + self.dim_out_arg = dim_out self.num_heads = num_heads self.bias = bias + self.attn_drop_rate = attn_drop + self.proj_drop_rate = proj_drop + self.channel_axis = channel_axis + self.data_format = data_format + self.dim_out = dim_out or dim self.head_dim = dim // num_heads self.conv_kernel_initializer = keras.initializers.VarianceScaling( scale=2.0, mode="fan_out", distribution="untruncated_normal" @@ -355,11 +402,11 @@ def __init__( kernel_initializer=self.conv_kernel_initializer, bias_initializer=self.bias_initializer, ) - self.attn_drop = keras.layers.Dropout( + self.attn_drop_layer = keras.layers.Dropout( attn_drop, dtype=self.dtype_policy ) self.proj = keras.layers.Conv2D( - dim_out, + self.dim_out, kernel_size=1, use_bias=bias, data_format=self.data_format, @@ -368,7 +415,7 @@ def __init__( kernel_initializer=self.conv_kernel_initializer, bias_initializer=self.bias_initializer, ) - self.proj_drop = keras.layers.Dropout( + self.proj_drop_layer = keras.layers.Dropout( proj_drop, dtype=self.dtype_policy ) @@ -394,7 +441,7 @@ def call(self, x, attn_mask=None, training=False): if attn_mask is not None: attn = attn + attn_mask attn = keras.ops.softmax(attn, axis=-1) - attn = self.attn_drop(attn, training=training) + attn = self.attn_drop_layer(attn, training=training) x = keras.ops.matmul(attn, v) x = keras.ops.transpose(x, (0, 1, 3, 2)) if self.data_format == "channels_first": @@ -402,9 +449,25 @@ def call(self, x, attn_mask=None, training=False): else: x = keras.ops.reshape(x, (B, H, W, -1)) x = self.proj(x) - x = self.proj_drop(x, training=training) + x = self.proj_drop_layer(x, training=training) return x + def get_config(self): + config = super().get_config() + config.update( + { + "dim": self.dim, + "dim_out": self.dim_out_arg, + "num_heads": self.num_heads, + "bias": self.bias, + "attn_drop": self.attn_drop_rate, + "proj_drop": self.proj_drop_rate, + "channel_axis": self.channel_axis, + "data_format": self.data_format, + } + ) + return config + class MobileAttention(keras.layers.Layer): """MobileNetV5 attention block. 
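Note: renaming the dropout sublayers (`attn_drop` -> `attn_drop_layer`,
`proj_drop` -> `proj_drop_layer`) frees the constructor names to hold the float
rates that `get_config()` must report. A toy stand-in showing the collision the
rename avoids (not the PR's class):

    import keras

    class Bad(keras.layers.Layer):
        def __init__(self, attn_drop=0.0, **kwargs):
            super().__init__(**kwargs)
            # The layer attribute shadows the constructor argument.
            self.attn_drop = keras.layers.Dropout(attn_drop)

        def get_config(self):
            # Returns a Dropout layer where a float rate is expected, so
            # `Bad.from_config(Bad(...).get_config())` breaks.
            return {**super().get_config(), "attn_drop": self.attn_drop}

    config = Bad(attn_drop=0.1).get_config()
    assert isinstance(config["attn_drop"], keras.layers.Dropout)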
@@ -474,8 +537,29 @@ def __init__( **kwargs, ): super().__init__(dtype=dtype, **kwargs) - self.data_format = data_format + self.in_chs = in_chs + self.out_chs = out_chs + self.stride = stride + self.dw_kernel_size = dw_kernel_size + self.dilation = dilation + self.pad_type = pad_type + self.num_heads = num_heads + self.key_dim = key_dim + self.value_dim = value_dim + self.use_multi_query = use_multi_query + self.query_strides = query_strides + self.kv_stride = kv_stride + self.cpe_dw_kernel_size = cpe_dw_kernel_size + self.noskip = noskip + self.norm_layer_name = norm_layer + self.drop_path_rate = drop_path_rate + self.attn_drop_rate = attn_drop + self.proj_drop_rate = proj_drop + self.layer_scale_init_value = layer_scale_init_value + self.use_bias = use_bias + self.use_cpe = use_cpe self.channel_axis = channel_axis + self.data_format = data_format self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip self.conv_kernel_initializer = keras.initializers.VarianceScaling( scale=2.0, mode="fan_out", distribution="untruncated_normal" @@ -585,3 +669,34 @@ def call(self, x, training=False): return self.drop_path(x_scaled, training=training) + shortcut else: return x_scaled + + def get_config(self): + config = super().get_config() + config.update( + { + "in_chs": self.in_chs, + "out_chs": self.out_chs, + "stride": self.stride, + "dw_kernel_size": self.dw_kernel_size, + "dilation": self.dilation, + "pad_type": self.pad_type, + "num_heads": self.num_heads, + "key_dim": self.key_dim, + "value_dim": self.value_dim, + "use_multi_query": self.use_multi_query, + "query_strides": self.query_strides, + "kv_stride": self.kv_stride, + "cpe_dw_kernel_size": self.cpe_dw_kernel_size, + "noskip": self.noskip, + "norm_layer": self.norm_layer_name, + "drop_path_rate": self.drop_path_rate, + "attn_drop": self.attn_drop_rate, + "proj_drop": self.proj_drop_rate, + "layer_scale_init_value": self.layer_scale_init_value, + "use_bias": self.use_bias, + "use_cpe": self.use_cpe, + "channel_axis": self.channel_axis, + "data_format": self.data_format, + } + ) + return config diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py index f026bb0564..daa667d5cd 100644 --- a/keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_blocks.py @@ -64,12 +64,25 @@ def __init__( **kwargs, ): super().__init__(dtype=dtype, **kwargs) - self.has_skip = (in_chs == out_chs and stride == 1) and not noskip self.in_chs = in_chs self.out_chs = out_chs + self.dw_kernel_size_start = dw_kernel_size_start + self.dw_kernel_size_mid = dw_kernel_size_mid + self.dw_kernel_size_end = dw_kernel_size_end + self.stride = stride + self.dilation = dilation + self.pad_type = pad_type + self.noskip = noskip + self.exp_ratio = exp_ratio + self.act_layer = act_layer + self.norm_layer = norm_layer + self.se_layer = se_layer + self.drop_path_rate = drop_path_rate + self.layer_scale_init_value = layer_scale_init_value self.data_format = data_format self.channel_axis = channel_axis - pad_type = "same" if pad_type == "" else "valid" + self.has_skip = (in_chs == out_chs and stride == 1) and not noskip + pad_type_internal = "same" if pad_type == "" else "valid" use_bias = norm_layer == "rms_norm" if dw_kernel_size_start: @@ -79,7 +92,7 @@ def __init__( stride=stride if not dw_kernel_size_mid else 1, dilation=dilation, groups=in_chs, - pad_type=pad_type, + pad_type=pad_type_internal, apply_act=False, act_layer=act_layer, 
norm_layer=norm_layer, @@ -95,7 +108,7 @@ def __init__( self.pw_exp = ConvNormAct( mid_chs, 1, - pad_type=pad_type, + pad_type=pad_type_internal, act_layer=act_layer, norm_layer=norm_layer, bias=use_bias, @@ -111,7 +124,7 @@ def __init__( stride=stride, dilation=dilation, groups=mid_chs, - pad_type=pad_type, + pad_type=pad_type_internal, act_layer=act_layer, norm_layer=norm_layer, bias=use_bias, @@ -137,7 +150,7 @@ def __init__( self.pw_proj = ConvNormAct( out_chs, 1, - pad_type=pad_type, + pad_type=pad_type_internal, apply_act=False, act_layer=act_layer, norm_layer=norm_layer, @@ -156,7 +169,7 @@ def __init__( else 1, dilation=dilation, groups=out_chs, - pad_type=pad_type, + pad_type=pad_type_internal, apply_act=False, act_layer=act_layer, norm_layer=norm_layer, @@ -231,6 +244,38 @@ def compute_output_shape(self, input_shape): current_shape = self.dw_end.compute_output_shape(current_shape) return current_shape + def get_config(self): + config = super().get_config() + config.update( + { + "in_chs": self.in_chs, + "out_chs": self.out_chs, + "dw_kernel_size_start": self.dw_kernel_size_start, + "dw_kernel_size_mid": self.dw_kernel_size_mid, + "dw_kernel_size_end": self.dw_kernel_size_end, + "stride": self.stride, + "dilation": self.dilation, + "pad_type": self.pad_type, + "noskip": self.noskip, + "exp_ratio": self.exp_ratio, + "act_layer": self.act_layer, + "norm_layer": self.norm_layer, + "se_layer": keras.saving.serialize_keras_object(self.se_layer), + "drop_path_rate": self.drop_path_rate, + "layer_scale_init_value": self.layer_scale_init_value, + "data_format": self.data_format, + "channel_axis": self.channel_axis, + } + ) + return config + + @classmethod + def from_config(cls, config): + config["se_layer"] = keras.saving.deserialize_keras_object( + config.pop("se_layer") + ) + return cls(**config) + class EdgeResidual(keras.layers.Layer): """Edge Residual block. 
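Note: `pad_type` is now stored verbatim on the block while a derived
`pad_type_internal` drives the conv layers; the earlier revision rebound the
`pad_type` argument itself, which would have leaked the resolved value into
`get_config()`. The mapping these blocks apply, following the upstream
convention where an empty string means default (same) padding:

    def resolve_pad_type(pad_type):
        # "" (the upstream default) maps to Keras "same"; any non-empty
        # value, including "same" itself, resolves to "valid" here, matching
        # the expression used in these blocks.
        return "same" if pad_type == "" else "valid"

    assert resolve_pad_type("") == "same"
    assert resolve_pad_type("valid") == "valid"

With this split, `from_config(block.get_config())` reconstructs the block from
the original argument rather than the resolved one.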
@@ -284,12 +329,25 @@ def __init__( **kwargs, ): super().__init__(dtype=dtype, **kwargs) - self.has_skip = (in_chs == out_chs and stride == 1) and not noskip self.in_chs = in_chs self.out_chs = out_chs + self.exp_kernel_size = exp_kernel_size + self.stride = stride + self.dilation = dilation + self.group_size = group_size + self.pad_type = pad_type + self.force_in_chs = force_in_chs + self.noskip = noskip + self.exp_ratio = exp_ratio + self.pw_kernel_size = pw_kernel_size + self.act_layer = act_layer + self.norm_layer = norm_layer + self.se_layer = se_layer + self.drop_path_rate = drop_path_rate self.data_format = data_format self.channel_axis = channel_axis - pad_type = "same" if pad_type == "" else "valid" + self.has_skip = (in_chs == out_chs and stride == 1) and not noskip + pad_type_internal = "same" if pad_type == "" else "valid" if force_in_chs > 0: mid_chs = adjust_channels(force_in_chs * exp_ratio) else: @@ -302,7 +360,7 @@ def __init__( stride=stride, dilation=dilation, groups=groups, - pad_type=pad_type, + pad_type=pad_type_internal, norm_layer=norm_layer, act_layer=act_layer, bias=use_bias, @@ -326,7 +384,7 @@ def __init__( self.conv_pwl = ConvNormAct( out_chs, pw_kernel_size, - pad_type=pad_type, + pad_type=pad_type_internal, apply_act=False, norm_layer=norm_layer, act_layer=act_layer, @@ -358,6 +416,38 @@ def call(self, x, training=False): x = self.drop_path(x, training=training) + shortcut return x + def get_config(self): + config = super().get_config() + config.update( + { + "in_chs": self.in_chs, + "out_chs": self.out_chs, + "exp_kernel_size": self.exp_kernel_size, + "stride": self.stride, + "dilation": self.dilation, + "group_size": self.group_size, + "pad_type": self.pad_type, + "force_in_chs": self.force_in_chs, + "noskip": self.noskip, + "exp_ratio": self.exp_ratio, + "pw_kernel_size": self.pw_kernel_size, + "act_layer": self.act_layer, + "norm_layer": self.norm_layer, + "se_layer": keras.saving.serialize_keras_object(self.se_layer), + "drop_path_rate": self.drop_path_rate, + "data_format": self.data_format, + "channel_axis": self.channel_axis, + } + ) + return config + + @classmethod + def from_config(cls, config): + config["se_layer"] = keras.saving.deserialize_keras_object( + config.pop("se_layer") + ) + return cls(**config) + class CondConvResidual(keras.layers.Layer): """Conditionally Parameterized Convolutional Residual block. 
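Note: every residual block in this file guards its identity shortcut with the
same rule, visible in the `has_skip` assignments above: the shortcut exists
only when the block preserves both channel count and spatial size, and is not
explicitly disabled. As a pure function:

    def has_skip(in_chs, out_chs, stride, noskip=False):
        # Identity shortcuts require matching shapes end to end.
        return (in_chs == out_chs and stride == 1) and not noskip

    assert has_skip(32, 32, 1)
    assert not has_skip(32, 64, 1)  # channel change breaks the identity
    assert not has_skip(32, 32, 2)  # downsampling breaks the identity
    assert not has_skip(32, 32, 1, noskip=True)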
@@ -408,10 +498,23 @@ def __init__( **kwargs, ): super().__init__(dtype=dtype, **kwargs) - self.has_skip = (in_chs == out_chs and stride == 1) and not noskip + self.in_chs = in_chs + self.out_chs = out_chs + self.dw_kernel_size = dw_kernel_size + self.stride = stride + self.dilation = dilation + self.pad_type = pad_type + self.noskip = noskip + self.exp_ratio = exp_ratio + self.exp_kernel_size = exp_kernel_size + self.pw_kernel_size = pw_kernel_size + self.act_layer = act_layer + self.se_layer = se_layer self.num_experts = num_experts + self.drop_path_rate = drop_path_rate self.data_format = data_format self.channel_axis = channel_axis + self.has_skip = (in_chs == out_chs and stride == 1) and not noskip self.conv_kernel_initializer = keras.initializers.VarianceScaling( scale=2.0, mode="fan_out", distribution="untruncated_normal" ) @@ -420,7 +523,7 @@ def __init__( ) self.bias_initializer = "zeros" mid_chs = adjust_channels(in_chs * exp_ratio) - pad_type = "same" if pad_type == "" else "valid" + pad_type_internal = "same" if pad_type == "" else "valid" self.routing_fn = keras.layers.Dense( self.num_experts, dtype=self.dtype_policy, @@ -434,7 +537,7 @@ def __init__( keras.layers.Conv2D( filters=mid_chs, kernel_size=exp_kernel_size, - padding=pad_type, + padding=pad_type_internal, use_bias=True, data_format=self.data_format, name=f"conv_pw_expert_{i}", @@ -448,7 +551,7 @@ def __init__( keras.layers.DepthwiseConv2D( kernel_size=dw_kernel_size, strides=stride, - padding=pad_type, + padding=pad_type_internal, dilation_rate=dilation, use_bias=True, data_format=self.data_format, @@ -463,7 +566,7 @@ def __init__( keras.layers.Conv2D( filters=out_chs, kernel_size=pw_kernel_size, - padding=pad_type, + padding=pad_type_internal, use_bias=True, data_format=self.data_format, name=f"conv_pwl_expert_{i}", @@ -562,6 +665,37 @@ def call(self, x, training=False): x = self.drop_path(x, training=training) + shortcut return x + def get_config(self): + config = super().get_config() + config.update( + { + "in_chs": self.in_chs, + "out_chs": self.out_chs, + "dw_kernel_size": self.dw_kernel_size, + "stride": self.stride, + "dilation": self.dilation, + "pad_type": self.pad_type, + "noskip": self.noskip, + "exp_ratio": self.exp_ratio, + "exp_kernel_size": self.exp_kernel_size, + "pw_kernel_size": self.pw_kernel_size, + "act_layer": self.act_layer, + "se_layer": keras.saving.serialize_keras_object(self.se_layer), + "num_experts": self.num_experts, + "drop_path_rate": self.drop_path_rate, + "data_format": self.data_format, + "channel_axis": self.channel_axis, + } + ) + return config + + @classmethod + def from_config(cls, config): + config["se_layer"] = keras.saving.deserialize_keras_object( + config.pop("se_layer") + ) + return cls(**config) + class MobileNetV5MultiScaleFusionAdapter(keras.layers.Layer): """Multi-Scale Fusion Adapter for MobileNetV5. 
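Note: `CondConvResidual` keeps one conv per expert (`conv_pw_expert_{i}`,
`conv_dw_expert_{i}`, `conv_pwl_expert_{i}`) plus a `routing_fn` Dense head
that scores the experts per example. A sketch of the mixture-of-experts idea,
assuming sigmoid routing as in timm's CondConv; it mixes expert outputs, which
for a linear conv is equivalent by linearity to mixing the kernels themselves:

    import keras
    from keras import ops

    num_experts, batch = 4, 2
    pooled = keras.random.normal((batch, 32))  # globally pooled features
    routing_fn = keras.layers.Dense(num_experts)
    # Per-example expert weights in [0, 1].
    weights = ops.sigmoid(routing_fn(pooled))  # (batch, num_experts)
    # Stand-in for the outputs of the per-expert convs.
    expert_outputs = keras.random.normal((num_experts, batch, 8, 8, 16))
    mixed = ops.einsum("ebhwc,be->bhwc", expert_outputs, weights)
    assert mixed.shape == (batch, 8, 8, 16)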
@@ -606,15 +740,23 @@ def __init__( **kwargs, ): super().__init__(dtype=dtype, **kwargs) - self.in_channels = sum(in_chs) - self.out_channels = out_chs + self.in_chs = in_chs + self.out_chs = out_chs + self.output_resolution_arg = output_resolution + self.expansion_ratio = expansion_ratio + self.interpolation_mode = interpolation_mode + self.layer_scale_init_value = layer_scale_init_value + self.noskip = noskip + self.act_layer = act_layer + self.norm_layer_name = norm_layer self.data_format = data_format self.channel_axis = channel_axis + self.in_channels = sum(in_chs) + self.out_channels = out_chs if isinstance(output_resolution, int): self.output_resolution = (output_resolution, output_resolution) else: self.output_resolution = output_resolution - self.interpolation_mode = interpolation_mode self.ffn = UniversalInvertedResidual( in_chs=self.in_channels, out_chs=self.out_channels, @@ -731,3 +873,22 @@ def compute_output_shape(self, input_shape): self.output_resolution[1], self.out_channels, ) + + def get_config(self): + config = super().get_config() + config.update( + { + "in_chs": self.in_chs, + "out_chs": self.out_chs, + "output_resolution": self.output_resolution_arg, + "expansion_ratio": self.expansion_ratio, + "interpolation_mode": self.interpolation_mode, + "layer_scale_init_value": self.layer_scale_init_value, + "noskip": self.noskip, + "act_layer": self.act_layer, + "norm_layer": self.norm_layer_name, + "data_format": self.data_format, + "channel_axis": self.channel_axis, + } + ) + return config diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py index d7bbe6f163..aa39ec653b 100644 --- a/keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_builder.py @@ -51,7 +51,16 @@ def decode_block_str(block_str): act_layer=act_layer, ) - if block_type == "uir": + if block_type == "ir": + block_args.update( + dict( + dw_kernel_size=parse_ksize(options["k"]), + exp_ratio=float(options["e"]), + se_ratio=float(options.get("se", 0.0)), + noskip=skip is False, + ) + ) + elif block_type == "uir": start_kernel_size = parse_ksize(options.get("a", "0")) end_kernel_size = parse_ksize(options.get("p", "0")) block_args.update( @@ -204,6 +213,12 @@ def _make_block(self, ba, block_idx, block_count): ba["se_layer"] = None ba.pop("aa_layer", None) if bt == "ir": + padding = 0 + if ba["pad_type"].lower() in ("", "same"): + kernel_size = ba["dw_kernel_size"] + if isinstance(kernel_size, (list, tuple)): + kernel_size = kernel_size[0] + padding = (kernel_size - 1) // 2 block = ( CondConvResidual(**ba) if ba.get("num_experts", 0) > 0 @@ -213,6 +228,7 @@ def _make_block(self, ba, block_idx, block_count): filters=ba["out_chs"], kernel_size=ba["dw_kernel_size"], stride=ba["stride"], + padding=padding, squeeze_excite_ratio=ba.pop("se_ratio", None), activation=ba["act_layer"], ) diff --git a/keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py b/keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py index 18fef3f285..bcf11df23d 100644 --- a/keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py +++ b/keras_hub/src/models/mobilenetv5/mobilenetv5_layers.py @@ -218,8 +218,18 @@ def __init__( **kwargs, ): super().__init__(dtype=dtype, **kwargs) - self.channel_axis = channel_axis + self.out_chs = out_chs + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.groups = groups + self.bias = bias + self.pad_type = pad_type + self.apply_act = apply_act + 
self.act_layer = act_layer + self.norm_layer = norm_layer self.data_format = data_format + self.channel_axis = channel_axis self.kernel_initializer = keras.initializers.VarianceScaling( scale=2.0, mode="fan_out", distribution="untruncated_normal" ) @@ -267,7 +277,6 @@ def __init__( dtype=self.dtype_policy, ) - self.apply_act = apply_act if self.apply_act: if act_layer == "gelu": self.act = keras.layers.Activation( @@ -298,6 +307,26 @@ def call(self, x, training=False): def compute_output_shape(self, input_shape): return self.conv.compute_output_shape(input_shape) + def get_config(self): + config = super().get_config() + config.update( + { + "out_chs": self.out_chs, + "kernel_size": self.kernel_size, + "stride": self.stride, + "dilation": self.dilation, + "groups": self.groups, + "bias": self.bias, + "pad_type": self.pad_type, + "apply_act": self.apply_act, + "act_layer": self.act_layer, + "norm_layer": self.norm_layer, + "data_format": self.data_format, + "channel_axis": self.channel_axis, + } + ) + return config + class SEModule(keras.layers.Layer): """Implements the Squeeze-and-Excitation (SE) module.