From 7d4ae11fb5557f07470703aebf2f62e8d17816ae Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 28 Aug 2025 23:10:24 +0800 Subject: [PATCH 1/9] Add `DepthAnythingBackbone`. --- keras_hub/api/layers/__init__.py | 3 + keras_hub/api/models/__init__.py | 15 + .../src/models/depth_anything/__init__.py | 9 + .../depth_anything/depth_anything_backbone.py | 152 +++++ .../depth_anything_backbone_test.py | 74 +++ .../depth_anything_depth_estimator.py | 70 ++ ...h_anything_depth_estimator_preprocessor.py | 16 + .../depth_anything_depth_estimator_test.py | 99 +++ .../depth_anything_image_converter.py | 10 + .../depth_anything/depth_anything_layers.py | 596 ++++++++++++++++++ .../depth_anything/depth_anything_loss.py | 85 +++ .../depth_anything/depth_anything_presets.py | 4 + .../src/models/depth_anything/interpolate.py | 62 ++ keras_hub/src/models/depth_estimator.py | 238 +++++++ .../models/depth_estimator_preprocessor.py | 78 +++ .../src/models/dinov2/dinov2_backbone.py | 18 +- keras_hub/src/models/dinov2/dinov2_layers.py | 11 +- .../src/utils/transformers/convert_dinov2.py | 1 + 18 files changed, 1535 insertions(+), 6 deletions(-) create mode 100644 keras_hub/src/models/depth_anything/__init__.py create mode 100644 keras_hub/src/models/depth_anything/depth_anything_backbone.py create mode 100644 keras_hub/src/models/depth_anything/depth_anything_backbone_test.py create mode 100644 keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py create mode 100644 keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py create mode 100644 keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py create mode 100644 keras_hub/src/models/depth_anything/depth_anything_image_converter.py create mode 100644 keras_hub/src/models/depth_anything/depth_anything_layers.py create mode 100644 keras_hub/src/models/depth_anything/depth_anything_loss.py create mode 100644 keras_hub/src/models/depth_anything/depth_anything_presets.py create mode 100644 keras_hub/src/models/depth_anything/interpolate.py create mode 100644 keras_hub/src/models/depth_estimator.py create mode 100644 keras_hub/src/models/depth_estimator_preprocessor.py diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py index f90c214d6b..4550cf8689 100644 --- a/keras_hub/api/layers/__init__.py +++ b/keras_hub/api/layers/__init__.py @@ -87,6 +87,9 @@ from keras_hub.src.models.densenet.densenet_image_converter import ( DenseNetImageConverter as DenseNetImageConverter, ) +from keras_hub.src.models.depth_anything.depth_anything_image_converter import ( + DepthAnythingImageConverter as DepthAnythingImageConverter, +) from keras_hub.src.models.dinov2.dinov2_image_converter import ( DINOV2ImageConverter as DINOV2ImageConverter, ) diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index fe220e2d43..62b4a2cc98 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -166,6 +166,21 @@ from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import ( DenseNetImageClassifierPreprocessor as DenseNetImageClassifierPreprocessor, ) +from keras_hub.src.models.depth_anything.depth_anything_backbone import ( + DepthAnythingBackbone as DepthAnythingBackbone, +) +from keras_hub.src.models.depth_anything.depth_anything_depth_estimator import ( + DepthAnythingDepthEstimator as DepthAnythingDepthEstimator, +) +from 
keras_hub.src.models.depth_anything.depth_anything_depth_estimator_preprocessor import ( + DepthAnythingDepthEstimatorPreprocessor as DepthAnythingDepthEstimatorPreprocessor, +) +from keras_hub.src.models.depth_estimator import ( + DepthEstimator as DepthEstimator, +) +from keras_hub.src.models.depth_estimator_preprocessor import ( + DepthEstimatorPreprocessor as DepthEstimatorPreprocessor, +) from keras_hub.src.models.dinov2.dinov2_backbone import ( DINOV2Backbone as DINOV2Backbone, ) diff --git a/keras_hub/src/models/depth_anything/__init__.py b/keras_hub/src/models/depth_anything/__init__.py new file mode 100644 index 0000000000..cde1e4c7f5 --- /dev/null +++ b/keras_hub/src/models/depth_anything/__init__.py @@ -0,0 +1,9 @@ +from keras_hub.src.models.depth_anything.depth_anything_backbone import ( + DepthAnythingBackbone, +) +from keras_hub.src.models.depth_anything.depth_anything_presets import ( + backbone_presets, +) +from keras_hub.src.utils.preset_utils import register_presets + +register_presets(backbone_presets, DepthAnythingBackbone) diff --git a/keras_hub/src/models/depth_anything/depth_anything_backbone.py b/keras_hub/src/models/depth_anything/depth_anything_backbone.py new file mode 100644 index 0000000000..6198544375 --- /dev/null +++ b/keras_hub/src/models/depth_anything/depth_anything_backbone.py @@ -0,0 +1,152 @@ +import keras +from keras import layers + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.depth_anything.depth_anything_layers import ( + DepthAnythingDepthEstimationHead, +) +from keras_hub.src.models.depth_anything.depth_anything_layers import ( + DepthAnythingNeck, +) +from keras_hub.src.models.dinov2 import DINOV2Backbone +from keras_hub.src.utils.keras_utils import standardize_data_format + + +@keras_hub_export("keras_hub.models.DepthAnythingBackbone") +class DepthAnythingBackbone(Backbone): + def __init__( + self, + image_encoder, + patch_size, + backbone_hidden_dim, + reassemble_factors, + neck_hidden_dims, + fusion_hidden_dim, + head_hidden_dim, + head_in_index, + image_shape=(224, 224, 3), + feature_keys=None, + data_format=None, + dtype=None, + **kwargs, + ): + if not isinstance(image_encoder, DINOV2Backbone): + raise ValueError( + "`image_encoder` must be a `DINOV2Backbone`. " + f"Received image_encoder={image_encoder} " + f"(of type {type(image_encoder)})." + ) + if feature_keys is not None: + feature_keys = [str(key) for key in feature_keys] + for key in feature_keys: + if key not in image_encoder.pyramid_outputs: + raise ValueError( + "All `feature_keys` must be in " + "`image_encoder.pyramid_outputs`. " + f"Received feature_keys={feature_keys}, but " + "`image_encoder.pyramid_outputs` contains " + f"{list(image_encoder.pyramid_outputs.keys())}." 
+ ) + data_format = standardize_data_format(data_format) + if data_format == "channels_last": + image_size = (image_shape[0], image_shape[1]) + else: + image_size = (image_shape[1], image_shape[2]) + + # === Layers === + if feature_keys is None: + pyramid_outputs = image_encoder.pyramid_outputs + else: + pyramid_outputs = { + key: value + for key, value in image_encoder.pyramid_outputs.items() + if key in feature_keys + } + self.feature_extractor = keras.Model( + inputs=image_encoder.inputs, + outputs=pyramid_outputs, + ) + self.feature_extractor.dtype_policy = image_encoder.dtype_policy + self.neck = DepthAnythingNeck( + patch_size=patch_size, + image_size=image_size, + backbone_hidden_dim=backbone_hidden_dim, + neck_hidden_dims=neck_hidden_dims, + reassemble_factors=reassemble_factors, + fusion_hidden_dim=fusion_hidden_dim, + num_cls_tokens=1, + num_register_tokens=image_encoder.num_register_tokens, + data_format=data_format, + dtype=dtype, + name="neck", + ) + self.head = DepthAnythingDepthEstimationHead( + patch_size=patch_size, + patch_height=image_size[0] // patch_size, + patch_width=image_size[1] // patch_size, + fusion_hidden_dim=fusion_hidden_dim, + head_hidden_dim=head_hidden_dim, + head_in_index=head_in_index, + data_format=data_format, + dtype=dtype, + name="head", + ) + + # === Functional Model === + image_input = layers.Input(shape=image_shape, name="images") + features = self.feature_extractor(image_input) + features = self.neck(list(features.values())) + depth_output = self.head(features) + super().__init__( + inputs=image_input, + outputs=depth_output, + dtype=dtype, + **kwargs, + ) + + # === Config === + self.image_encoder = image_encoder + self.patch_size = patch_size + self.backbone_hidden_dim = backbone_hidden_dim + self.reassemble_factors = reassemble_factors + self.neck_hidden_dims = neck_hidden_dims + self.fusion_hidden_dim = fusion_hidden_dim + self.head_hidden_dim = head_hidden_dim + self.head_in_index = head_in_index + self.image_shape = image_shape + self.feature_keys = feature_keys + + def get_config(self): + config = super().get_config() + config.update( + { + "image_encoder": layers.serialize(self.image_encoder), + "patch_size": self.patch_size, + "backbone_hidden_dim": self.backbone_hidden_dim, + "reassemble_factors": self.reassemble_factors, + "neck_hidden_dims": self.neck_hidden_dims, + "fusion_hidden_dim": self.fusion_hidden_dim, + "head_hidden_dim": self.head_hidden_dim, + "head_in_index": self.head_in_index, + "image_shape": self.image_shape, + "feature_keys": self.feature_keys, + } + ) + return config + + @classmethod + def from_config(cls, config, custom_objects=None): + config = config.copy() + + # Propagate `dtype` to `image_encoder` if needed. + if "dtype" in config and config["dtype"] is not None: + dtype_config = config["dtype"] + if "dtype" not in config["image_encoder"]["config"]: + config["image_encoder"]["config"]["dtype"] = dtype_config + + # We expect submodels to be instantiated. 
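+ # `layers.deserialize` rebuilds the nested config into a concrete `DINOV2Backbone` instance before `cls(**config)` is called.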
+ config["image_encoder"] = layers.deserialize( + config["image_encoder"], custom_objects=custom_objects + ) + return cls(**config) diff --git a/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py b/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py new file mode 100644 index 0000000000..21d53c94de --- /dev/null +++ b/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py @@ -0,0 +1,74 @@ +import pytest +from keras import ops + +from keras_hub.src.models.depth_anything.depth_anything_backbone import ( + DepthAnythingBackbone, +) +from keras_hub.src.models.dinov2.dinov2_backbone import DINOV2Backbone +from keras_hub.src.tests.test_case import TestCase + + +class DepthAnythingBackboneTest(TestCase): + def setUp(self): + image_encoder = DINOV2Backbone( + 14, + 4, + 16, + 2, + 16 * 4, + 1.0, + 0, + image_shape=(70, 70, 3), + apply_layernorm=True, + ) + self.init_kwargs = { + "image_encoder": image_encoder, + "patch_size": image_encoder.patch_size, + "backbone_hidden_dim": image_encoder.hidden_dim, + "reassemble_factors": [4, 2, 1, 0.5], + "neck_hidden_dims": [16, 32, 64, 128], + "fusion_hidden_dim": 128, + "head_hidden_dim": 16, + "head_in_index": -1, + "image_shape": (70, 70, 3), + "feature_keys": ["Stage1", "Stage2", "Stage3", "Stage4"], + } + self.input_data = ops.ones((2, 70, 70, 3)) + + def test_backbone_basics(self): + self.run_backbone_test( + cls=DepthAnythingBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 70, 70, 1), + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DepthAnythingBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) + + @pytest.mark.kaggle_key_required + @pytest.mark.extra_large + def test_smallest_preset(self): + self.skipTest("Presets are not uploaded yet.") + self.run_preset_test( + cls=DepthAnythingBackbone, + preset="depth_anything_v2_small", + input_data=self.input_data, + expected_output_shape=(2, 70, 70, 1), + ) + + @pytest.mark.kaggle_key_required + @pytest.mark.extra_large + def test_all_presets(self): + self.skipTest("Presets are not uploaded yet.") + for preset in DepthAnythingBackbone.presets: + self.run_preset_test( + cls=DepthAnythingBackbone, + preset=preset, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py b/keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py new file mode 100644 index 0000000000..104218597a --- /dev/null +++ b/keras_hub/src/models/depth_anything/depth_anything_depth_estimator.py @@ -0,0 +1,70 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.depth_anything.depth_anything_backbone import ( + DepthAnythingBackbone, +) +from keras_hub.src.models.depth_anything.depth_anything_depth_estimator_preprocessor import ( # noqa: E501 + DepthAnythingDepthEstimatorPreprocessor, +) +from keras_hub.src.models.depth_anything.depth_anything_loss import ( + DepthAnythingLoss, +) +from keras_hub.src.models.depth_estimator import DepthEstimator + + +@keras_hub_export("keras_hub.models.DepthAnythingDepthEstimator") +class DepthAnythingDepthEstimator(DepthEstimator): + backbone_cls = DepthAnythingBackbone + preprocessor_cls = DepthAnythingDepthEstimatorPreprocessor + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `DepthEstimator` task for training. 
+ + The `DepthEstimator` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a `DepthAnythingLoss` loss will be + applied for the depth estimation task. See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a dict of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.RootMeanSquaredError` will be applied to + track the accuracy of the model during training. See + `keras.Model.compile` and `keras.metrics` for more info on + possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. + """ + if optimizer == "auto": + optimizer = keras.optimizers.AdamW(5e-5) + if loss == "auto": + loss = { + "depths": DepthAnythingLoss( + min_depth=self.min_depth, max_depth=self.max_depth + ) + } + if metrics == "auto": + metrics = {"depths": keras.metrics.RootMeanSquaredError()} + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) diff --git a/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py b/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py new file mode 100644 index 0000000000..bb36599678 --- /dev/null +++ b/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_preprocessor.py @@ -0,0 +1,16 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.depth_anything.depth_anything_backbone import ( + DepthAnythingBackbone, +) +from keras_hub.src.models.depth_anything.depth_anything_image_converter import ( + DepthAnythingImageConverter, +) +from keras_hub.src.models.depth_estimator_preprocessor import ( + DepthEstimatorPreprocessor, +) + + +@keras_hub_export("keras_hub.models.DepthAnythingDepthEstimatorPreprocessor") +class DepthAnythingDepthEstimatorPreprocessor(DepthEstimatorPreprocessor): + backbone_cls = DepthAnythingBackbone + image_converter_cls = DepthAnythingImageConverter diff --git a/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py b/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py new file mode 100644 index 0000000000..1c8dcb1dd7 --- /dev/null +++ b/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py @@ -0,0 +1,99 @@ +import pytest +from keras import ops + +from keras_hub.src.models.depth_anything.depth_anything_backbone import ( + DepthAnythingBackbone, +) +from keras_hub.src.models.depth_anything.depth_anything_depth_estimator import ( + DepthAnythingDepthEstimator, +) +from keras_hub.src.models.depth_anything.depth_anything_depth_estimator_preprocessor import ( # noqa: E501 + DepthAnythingDepthEstimatorPreprocessor, +) +from keras_hub.src.models.depth_anything.depth_anything_image_converter import ( + DepthAnythingImageConverter, +) +from keras_hub.src.models.dinov2.dinov2_backbone import DINOV2Backbone +from keras_hub.src.tests.test_case import TestCase + + +class 
DepthAnythingDepthEstimatorTest(TestCase): + def setUp(self): + image_encoder = DINOV2Backbone( + 14, + 4, + 16, + 2, + 16 * 4, + 1.0, + 0, + image_shape=(70, 70, 3), + apply_layernorm=True, + ) + self.images = ops.ones((2, 70, 70, 3)) + self.depths = ops.zeros((2, 70, 70, 1)) + self.image_converter = DepthAnythingImageConverter(image_size=(70, 70)) + self.preprocessor = DepthAnythingDepthEstimatorPreprocessor( + self.image_converter + ) + self.backbone = DepthAnythingBackbone( + image_encoder=image_encoder, + patch_size=image_encoder.patch_size, + backbone_hidden_dim=image_encoder.hidden_dim, + reassemble_factors=[4, 2, 1, 0.5], + neck_hidden_dims=[16, 32, 64, 128], + fusion_hidden_dim=128, + head_hidden_dim=16, + head_in_index=-1, + image_shape=(70, 70, 3), + feature_keys=["Stage1", "Stage2", "Stage3", "Stage4"], + ) + self.init_kwargs = { + "backbone": self.backbone, + "depth_estimation_type": "metric", + "max_depth": 10.0, + "preprocessor": self.preprocessor, + } + self.train_data = (self.images, self.depths) + + def test_depth_estimator_basics(self): + self.run_task_test( + cls=DepthAnythingDepthEstimator, + init_kwargs=self.init_kwargs, + train_data=self.train_data, + expected_output_shape={"depths": (2, 70, 70, 1)}, + ) + + @pytest.mark.large + def test_smallest_preset(self): + self.skipTest("Presets are not uploaded yet.") + image_batch = self.load_test_image(target_size=518)[None, ...] / 255.0 + self.run_preset_test( + cls=DepthAnythingDepthEstimator, + preset="depth_anything_v2_small", + input_data=image_batch, + expected_output_shape={"depths": (1, 518, 518, 1)}, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=DepthAnythingDepthEstimator, + init_kwargs=self.init_kwargs, + input_data=self.images, + ) + + @pytest.mark.extra_large + def test_all_presets(self): + images = ops.ones((2, 518, 518, 3)) + for preset in DepthAnythingDepthEstimator.presets: + self.run_preset_test( + cls=DepthAnythingDepthEstimator, + preset=preset, + init_kwargs={ + "depth_estimation_type": "relative", + "max_depth": None, + }, + input_data=images, + expected_output_shape={"depths": (2, 518, 518, 1)}, + ) diff --git a/keras_hub/src/models/depth_anything/depth_anything_image_converter.py b/keras_hub/src/models/depth_anything/depth_anything_image_converter.py new file mode 100644 index 0000000000..0f4efea18b --- /dev/null +++ b/keras_hub/src/models/depth_anything/depth_anything_image_converter.py @@ -0,0 +1,10 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.preprocessing.image_converter import ImageConverter +from keras_hub.src.models.depth_anything.depth_anything_backbone import ( + DepthAnythingBackbone, +) + + +@keras_hub_export("keras_hub.layers.DepthAnythingImageConverter") +class DepthAnythingImageConverter(ImageConverter): + backbone_cls = DepthAnythingBackbone diff --git a/keras_hub/src/models/depth_anything/depth_anything_layers.py b/keras_hub/src/models/depth_anything/depth_anything_layers.py new file mode 100644 index 0000000000..cec861153a --- /dev/null +++ b/keras_hub/src/models/depth_anything/depth_anything_layers.py @@ -0,0 +1,596 @@ +from keras import layers +from keras import ops + +from keras_hub.src.models.depth_anything.interpolate import interpolate +from keras_hub.src.utils.keras_utils import standardize_data_format + + +class DepthAnythingTokenToImage(layers.Layer): + def __init__( + self, + hidden_dim, + patch_height, + patch_width, + num_cls_tokens=1, + num_register_tokens=0, + data_format=None, + 
**kwargs, + ): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.patch_height = int(patch_height) + self.patch_width = int(patch_width) + self.num_cls_tokens = int(num_cls_tokens) + self.num_register_tokens = int(num_register_tokens) + self.data_format = standardize_data_format(data_format) + # Always use channels_last for reshaping first. + self.target_shape = ( + self.patch_height, + self.patch_width, + self.hidden_dim, + ) + + def call(self, inputs): + # Remove the cls token. + x = inputs[:, self.num_cls_tokens + self.num_register_tokens :, ...] + + x = ops.reshape(x, (ops.shape(x)[0],) + self.target_shape) + if self.data_format == "channels_first": + x = ops.transpose(x, (0, 3, 1, 2)) + return x + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "patch_height": self.patch_height, + "patch_width": self.patch_width, + "num_cls_tokens": self.num_cls_tokens, + "num_register_tokens": self.num_register_tokens, + } + ) + return config + + def compute_output_shape(self, input_shape): + output_shape = [input_shape[0], *self.target_shape] + if self.data_format == "channels_first": + output_shape = [ + output_shape[0], + output_shape[3], + output_shape[1], + output_shape[2], + ] + return output_shape + + +class DepthAnythingReassembleLayer(layers.Layer): + def __init__(self, hidden_dim, factor, data_format=None, **kwargs): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.factor = float(factor) + self.data_format = standardize_data_format(data_format) + + self.projection = layers.Conv2D( + filters=self.hidden_dim, + kernel_size=1, + data_format=self.data_format, + use_bias=True, + dtype=self.dtype_policy, + name="projection", + ) + if self.factor > 1: + self.padding = layers.Identity( + dtype=self.dtype_policy, name="padding" + ) + self.resize = layers.Conv2DTranspose( + filters=self.hidden_dim, + kernel_size=int(self.factor), + strides=int(self.factor), + data_format=self.data_format, + use_bias=True, + dtype=self.dtype_policy, + name="resize", + ) + elif self.factor == 1: + self.padding = layers.Identity( + dtype=self.dtype_policy, name="padding" + ) + self.resize = layers.Identity( + dtype=self.dtype_policy, name="resize" + ) + elif self.factor < 1: + self.padding = layers.ZeroPadding2D( + padding=(1, 1), + data_format=self.data_format, + dtype=self.dtype_policy, + name="padding", + ) + self.resize = layers.Conv2D( + filters=self.hidden_dim, + kernel_size=3, + strides=int(1 / self.factor), + data_format=self.data_format, + use_bias=True, + dtype=self.dtype_policy, + name="resize", + ) + + def build(self, inputs_shape): + self.projection.build(inputs_shape) + inputs_shape = self.projection.compute_output_shape(inputs_shape) + self.padding.build(inputs_shape) + inputs_shape = self.padding.compute_output_shape(inputs_shape) + self.resize.build(inputs_shape) + + def call(self, inputs, training=None): + x = self.projection(inputs, training=training) + x = self.padding(x, training=training) + return self.resize(x, training=training) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "factor": self.factor, + } + ) + return config + + def compute_output_shape(self, input_shape): + output_shape = list(input_shape) + if self.data_format == "channels_first": + output_shape[1] = self.hidden_dim + output_shape[2] = int(output_shape[2] * self.factor) + output_shape[3] = int(output_shape[3] * self.factor) + else: + output_shape[1] = 
int(output_shape[1] * self.factor) + output_shape[2] = int(output_shape[2] * self.factor) + output_shape[3] = self.hidden_dim + return output_shape + + +class DepthAnythingPreActResidualLayer(layers.Layer): + def __init__(self, hidden_dim, data_format=None, **kwargs): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.data_format = standardize_data_format(data_format) + + self.activation1 = layers.ReLU( + dtype=self.dtype_policy, name="activation1" + ) + self.padding1 = layers.ZeroPadding2D( + padding=(1, 1), + data_format=self.data_format, + dtype=self.dtype_policy, + name="padding1", + ) + self.convolution1 = layers.Conv2D( + filters=self.hidden_dim, + kernel_size=3, + strides=1, + data_format=self.data_format, + use_bias=True, + dtype=self.dtype_policy, + name="convolution1", + ) + self.activation2 = layers.ReLU( + dtype=self.dtype_policy, name="activation2" + ) + self.padding2 = layers.ZeroPadding2D( + padding=(1, 1), + data_format=self.data_format, + dtype=self.dtype_policy, + name="padding2", + ) + self.convolution2 = layers.Conv2D( + filters=self.hidden_dim, + kernel_size=3, + strides=1, + data_format=self.data_format, + use_bias=True, + dtype=self.dtype_policy, + name="convolution2", + ) + + def build(self, inputs_shape): + self.activation1.build(inputs_shape) + self.padding1.build(inputs_shape) + inputs_shape = self.padding1.compute_output_shape(inputs_shape) + self.convolution1.build(inputs_shape) + inputs_shape = self.convolution1.compute_output_shape(inputs_shape) + self.activation2.build(inputs_shape) + self.padding2.build(inputs_shape) + inputs_shape = self.padding2.compute_output_shape(inputs_shape) + self.convolution2.build(inputs_shape) + + def call(self, inputs, training=None): + residual = inputs + x = self.activation1(inputs, training=training) + x = self.padding1(x, training=training) + x = self.convolution1(x, training=training) + x = self.activation2(x, training=training) + x = self.padding2(x, training=training) + x = self.convolution2(x, training=training) + return ops.add(x, residual) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape + + +class DepthAnythingFeatureFusionLayer(layers.Layer): + def __init__(self, hidden_dim, size, data_format=None, **kwargs): + super().__init__(**kwargs) + self.hidden_dim = int(hidden_dim) + self.size = tuple(int(s) for s in size) + self.data_format = standardize_data_format(data_format) + + self.residual_layer1 = DepthAnythingPreActResidualLayer( + hidden_dim=self.hidden_dim, + data_format=self.data_format, + dtype=self.dtype_policy, + name="residual_layer1", + ) + self.residual_layer2 = DepthAnythingPreActResidualLayer( + hidden_dim=self.hidden_dim, + data_format=self.data_format, + dtype=self.dtype_policy, + name="residual_layer2", + ) + self.projection = layers.Conv2D( + filters=self.hidden_dim, + kernel_size=1, + data_format=self.data_format, + use_bias=True, + dtype=self.dtype_policy, + name="projection", + ) + + def build(self, inputs_shape): + self.residual_layer1.build(inputs_shape) + self.residual_layer2.build(inputs_shape) + inputs_shape = list(inputs_shape) + if self.data_format == "channels_last": + inputs_shape[1] = self.size[0] + inputs_shape[2] = self.size[1] + else: + inputs_shape[2] = self.size[0] + inputs_shape[3] = self.size[1] + self.projection.build(inputs_shape) + + def call(self, inputs, residual=None, training=None): + if residual is 
not None: + inputs = ops.add( + inputs, self.residual_layer1(residual, training=training) + ) + + x = self.residual_layer2(inputs, training=training) + x = interpolate(x, size=self.size, data_format=self.data_format) + return self.projection(x, training=training) + + def get_config(self): + config = super().get_config() + config.update( + { + "hidden_dim": self.hidden_dim, + "size": self.size, + } + ) + return config + + def compute_output_shape(self, input_shape): + input_shape = self.residual_layer2.compute_output_shape(input_shape) + input_shape = list(input_shape) + if self.data_format == "channels_last": + input_shape[1] = self.size[0] + input_shape[2] = self.size[1] + else: + input_shape[2] = self.size[0] + input_shape[3] = self.size[1] + return self.projection.compute_output_shape(input_shape) + + +class DepthAnythingNeck(layers.Layer): + def __init__( + self, + patch_size, + image_size, + backbone_hidden_dim, + neck_hidden_dims, + reassemble_factors, + fusion_hidden_dim, + num_cls_tokens=1, + num_register_tokens=0, + data_format=None, + **kwargs, + ): + super().__init__(**kwargs) + self.patch_size = int(patch_size) + self.image_size = (int(image_size[0]), int(image_size[1])) + self.backbone_hidden_dim = int(backbone_hidden_dim) + self.neck_hidden_dims = tuple(int(d) for d in neck_hidden_dims) + self.reassemble_factors = tuple(float(f) for f in reassemble_factors) + self.fusion_hidden_dim = int(fusion_hidden_dim) + self.num_cls_tokens = int(num_cls_tokens) + self.num_register_tokens = int(num_register_tokens) + self.data_format = standardize_data_format(data_format) + if len(self.neck_hidden_dims) != len(self.reassemble_factors): + raise ValueError( + "`DepthAnythingNeck` expects the length of `neck_hidden_dims` " + "and `reassemble_factors` to be the same. " + f"Received: neck_hidden_dims={neck_hidden_dims}, " + f"reassemble_factors={reassemble_factors}" + ) + + # Calculate the patch sizes for token to image layers. + patch_height = self.image_size[0] // self.patch_size + patch_width = self.image_size[1] // self.patch_size + # Calculate the sizes for fusion layers. 
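+ # The fusion sizes are the reassembled feature sizes in coarse-to-fine order shifted by one:
+ # each fusion layer upsamples to the size of the next finer reassembled map, and the last
+ # fusion layer doubles the finest size once more.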
+ fusion_sizes = [ + (int(patch_height * factor), int(patch_width * factor)) + for factor in reversed(self.reassemble_factors[:-1]) + ] + fusion_sizes = fusion_sizes + [ + (fusion_sizes[-1][0] * 2, fusion_sizes[-1][1] * 2) + ] + + self.token_to_images = [ + DepthAnythingTokenToImage( + hidden_dim=backbone_hidden_dim, + patch_height=patch_height, + patch_width=patch_width, + num_cls_tokens=num_cls_tokens, + num_register_tokens=num_register_tokens, + data_format=self.data_format, + dtype=self.dtype_policy, + name=f"token_to_images_{i}", + ) + for i in range(len(self.neck_hidden_dims)) + ] + self.reassemble_stage = [ + DepthAnythingReassembleLayer( + hidden_dim=hidden_dim, + factor=factor, + data_format=self.data_format, + dtype=self.dtype_policy, + name=f"reassemble_stage_{i}", + ) + for i, (hidden_dim, factor) in enumerate( + zip(self.neck_hidden_dims, self.reassemble_factors) + ) + ] + self.paddings = [ + layers.ZeroPadding2D( + padding=(1, 1), + data_format=self.data_format, + dtype=self.dtype_policy, + name=f"paddings_{i}", + ) + for i in range(len(self.neck_hidden_dims)) + ] + self.convs = [ + layers.Conv2D( + filters=self.fusion_hidden_dim, + kernel_size=3, + data_format=self.data_format, + use_bias=False, + dtype=self.dtype_policy, + name=f"convs_{i}", + ) + for i in range(len(self.neck_hidden_dims)) + ] + self.fusion_stage = [ + DepthAnythingFeatureFusionLayer( + hidden_dim=self.fusion_hidden_dim, + size=size, + data_format=self.data_format, + dtype=self.dtype_policy, + name=f"fusion_stage_{i}", + ) + for i, size in enumerate(fusion_sizes) + ] + + def build(self, inputs_shape): + outputs_shape = [] + # Reassemble stage. + for i, shape in enumerate(inputs_shape): + self.token_to_images[i].build(shape) + shape = self.token_to_images[i].compute_output_shape(shape) + self.reassemble_stage[i].build(shape) + shape = self.reassemble_stage[i].compute_output_shape(shape) + outputs_shape.append(shape) + # Convs. + for i, shape in enumerate(outputs_shape): + self.convs[i].build(shape) + shape = self.convs[i].compute_output_shape(shape) + outputs_shape[i] = shape + # Fusion stage. + for i, shape in enumerate(reversed(outputs_shape)): + self.fusion_stage[i].build(shape) + + def call(self, inputs, training=None): + # Reassemble stage. + xs = [ + self.reassemble_stage[i]( + self.token_to_images[i](x), training=training + ) + for i, x in enumerate(inputs) + ] + # Convs. + xs = [ + self.convs[i](self.paddings[i](x), training=training) + for i, x in enumerate(xs) + ] + # Fusion stage. 
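+ # Fuse from coarsest to finest: the lowest-resolution feature seeds the chain, and each
+ # subsequent step adds the next finer feature as a residual before upsampling.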
+ fused_xs = [] + fused_x = None + for i, x in enumerate(reversed(xs)): + if fused_x is None: + fused_x = self.fusion_stage[i]( + x, residual=None, training=training + ) + else: + fused_x = self.fusion_stage[i]( + fused_x, residual=x, training=training + ) + fused_xs.append(fused_x) + return fused_xs + + def get_config(self): + config = super().get_config() + config.update( + { + "patch_size": self.patch_size, + "image_size": self.image_size, + "backbone_hidden_dim": self.backbone_hidden_dim, + "neck_hidden_dims": self.neck_hidden_dims, + "reassemble_factors": self.reassemble_factors, + "fusion_hidden_dim": self.fusion_hidden_dim, + "num_cls_tokens": self.num_cls_tokens, + "num_register_tokens": self.num_register_tokens, + } + ) + return config + + +class DepthAnythingDepthEstimationHead(layers.Layer): + def __init__( + self, + patch_size, + patch_height, + patch_width, + fusion_hidden_dim, + head_hidden_dim, + head_in_index, + data_format=None, + **kwargs, + ): + super().__init__(**kwargs) + self.patch_size = int(patch_size) + self.patch_height = int(patch_height) + self.patch_width = int(patch_width) + self.fusion_hidden_dim = int(fusion_hidden_dim) + self.head_hidden_dim = int(head_hidden_dim) + self.head_in_index = int(head_in_index) + self.data_format = standardize_data_format(data_format) + + # Calculate the interpolate size. + self.interpolate_size = ( + int(self.patch_height * self.patch_size), + int(self.patch_width * self.patch_size), + ) + + self.padding1 = layers.ZeroPadding2D( + padding=(1, 1), + data_format=self.data_format, + dtype=self.dtype_policy, + name="padding1", + ) + self.conv1 = layers.Conv2D( + filters=self.fusion_hidden_dim // 2, + kernel_size=3, + data_format=self.data_format, + use_bias=True, + dtype=self.dtype_policy, + name="conv1", + ) + self.padding2 = layers.ZeroPadding2D( + padding=(1, 1), + data_format=self.data_format, + dtype=self.dtype_policy, + name="padding2", + ) + self.conv2 = layers.Conv2D( + filters=self.head_hidden_dim, + kernel_size=3, + data_format=self.data_format, + use_bias=True, + dtype=self.dtype_policy, + name="conv2", + ) + self.activation1 = layers.ReLU( + dtype=self.dtype_policy, name="activation1" + ) + self.conv3 = layers.Conv2D( + filters=1, + kernel_size=1, + data_format=self.data_format, + use_bias=True, + dtype=self.dtype_policy, + name="conv3", + ) + + def build(self, inputs_shape): + inputs_shape = inputs_shape[self.head_in_index] + self.padding1.build(inputs_shape) + inputs_shape = self.padding1.compute_output_shape(inputs_shape) + self.conv1.build(inputs_shape) + inputs_shape = self.conv1.compute_output_shape(inputs_shape) + inputs_shape = list(inputs_shape) + if self.data_format == "channels_last": + inputs_shape[1] = self.interpolate_size[0] + inputs_shape[2] = self.interpolate_size[1] + else: + inputs_shape[2] = self.interpolate_size[0] + inputs_shape[3] = self.interpolate_size[1] + self.padding2.build(inputs_shape) + inputs_shape = self.padding2.compute_output_shape(inputs_shape) + self.conv2.build(inputs_shape) + inputs_shape = self.conv2.compute_output_shape(inputs_shape) + self.activation1.build(inputs_shape) + self.conv3.build(inputs_shape) + inputs_shape = self.conv3.compute_output_shape(inputs_shape) + + def call(self, inputs, training=None): + x = inputs[self.head_in_index] + x = self.padding1(x, training=training) + x = self.conv1(x, training=training) + x = interpolate( + x, size=self.interpolate_size, data_format=self.data_format + ) + x = self.padding2(x, training=training) + x = self.conv2(x, 
training=training) + x = self.activation1(x, training=training) + return self.conv3(x, training=training) + + def get_config(self): + config = super().get_config() + config.update( + { + "patch_size": self.patch_size, + "patch_height": self.patch_height, + "patch_width": self.patch_width, + "fusion_hidden_dim": self.fusion_hidden_dim, + "head_hidden_dim": self.head_hidden_dim, + "head_in_index": self.head_in_index, + } + ) + return config + + def compute_output_shape(self, input_shape): + input_shape = input_shape[self.head_in_index] + if self.data_format == "channels_last": + output_shape = [ + input_shape[0], + int(self.patch_height * self.patch_size), + int(self.patch_width * self.patch_size), + 1, + ] + else: + output_shape = [ + input_shape[0], + 1, + int(self.patch_height * self.patch_size), + int(self.patch_width * self.patch_size), + ] + return output_shape diff --git a/keras_hub/src/models/depth_anything/depth_anything_loss.py b/keras_hub/src/models/depth_anything/depth_anything_loss.py new file mode 100644 index 0000000000..e6e4f8a4c1 --- /dev/null +++ b/keras_hub/src/models/depth_anything/depth_anything_loss.py @@ -0,0 +1,85 @@ +from keras import ops +from keras.src.losses.losses import LossFunctionWrapper + + +class DepthAnythingLoss(LossFunctionWrapper): + """Computes the DepthAnything loss between `y_true` & `y_pred`. + + This loss is the Scale-Invariant Logarithmic (SiLog) loss, which is + widely used for depth estimation tasks. + + See: [Depth Map Prediction from a Single Image using a Multi-Scale Deep Network](https://arxiv.org/abs/1406.2283) + + Args: + lambd: The weighting factor in the scale-invariant log loss formula. + Defaults to `0.5`. + min_depth: Minimum depth value used to filter `y_pred` and `y_true`. + Defaults to `0.0`. + max_depth: Maximum depth value used to filter `y_pred` and `y_true`. + Defaults to `1.0`. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. `"sum"` sums the loss, + `"sum_over_batch_size"` and `"mean"` sum the loss and divide by the + sample size, and `"mean_with_sample_weight"` sums the loss and + divides by the sum of the sample weights. `"none"` and `None` + perform no aggregation. Defaults to `"sum_over_batch_size"`. + name: Optional name for the instance. + dtype: The dtype of the loss's computations. Defaults to `None`, which + means using `keras.backend.floatx()`. `keras.backend.floatx()` is a + `"float32"` unless set to different value + (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is + provided, then the `compute_dtype` will be utilized. + """ + + def __init__( + self, + lambd=0.5, + min_depth=0.0, + max_depth=1.0, + reduction="sum_over_batch_size", + name="depth_anything_loss", + dtype=None, + ): + if max_depth is None: + max_depth = 1.0 + super().__init__( + silog, + name=name, + reduction=reduction, + dtype=dtype, + lambd=lambd, + min_depth=min_depth, + max_depth=max_depth, + ) + + +def silog(y_true, y_pred, lambd=0.5, min_depth=0.001, max_depth=20.0): + y_pred = ops.convert_to_tensor(y_pred) + y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) + + # Apply the valid mask. 
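+ # SiLog: with d = log(y_true) - log(y_pred) over pixels where min_depth <= y_true <= max_depth,
+ # the loss is sqrt(mean(d ** 2) - lambd * mean(d) ** 2).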
+ valid_mask = ops.logical_and( + ops.greater_equal(y_true, min_depth), + ops.less_equal(y_true, max_depth), + ) + y_true = ops.multiply(y_true, valid_mask) + y_pred = ops.multiply(y_pred, valid_mask) + + diff_log = ops.where( + valid_mask, + ops.subtract(ops.log(y_true), ops.log(y_pred)), + ops.zeros_like(y_true), + ) + + divisor = ops.sum(ops.cast(valid_mask, y_true.dtype), axis=(1, 2, 3)) + mean_power2_diff_log = ops.divide_no_nan( + ops.sum(ops.power(diff_log, 2), axis=(1, 2, 3)), divisor + ) + power2_mean_diff_log = ops.power( + ops.divide_no_nan(ops.sum(diff_log, axis=(1, 2, 3)), divisor), 2 + ) + return ops.sqrt( + mean_power2_diff_log - ops.multiply(lambd, power2_mean_diff_log) + ) diff --git a/keras_hub/src/models/depth_anything/depth_anything_presets.py b/keras_hub/src/models/depth_anything/depth_anything_presets.py new file mode 100644 index 0000000000..90757d819c --- /dev/null +++ b/keras_hub/src/models/depth_anything/depth_anything_presets.py @@ -0,0 +1,4 @@ +"""DepthAnything model preset configurations.""" + +# Metadata for loading pretrained model weights. +backbone_presets = {} diff --git a/keras_hub/src/models/depth_anything/interpolate.py b/keras_hub/src/models/depth_anything/interpolate.py new file mode 100644 index 0000000000..80adf00d1b --- /dev/null +++ b/keras_hub/src/models/depth_anything/interpolate.py @@ -0,0 +1,62 @@ +from keras import backend +from keras import ops + +from keras_hub.src.utils.keras_utils import standardize_data_format + + +def interpolate(x, size, data_format=None): + """Performs a backend-agnostic version of Torch's `F.interpolate`. + + Args: + x: A 4D image tensor. + size: A tuple of 2 integers, `(height, width)`. + data_format: One of `channels_last` or `channels_first`. + """ + data_format = standardize_data_format(data_format) + if backend.backend() == "jax": + import jax + + if data_format == "channels_first": + x = ops.transpose(x, (0, 2, 3, 1)) + scale = ops.convert_to_tensor( + [ + (size[0] - 1.0) / (x.shape[1] - 1.0), + (size[1] - 1.0) / (x.shape[2] - 1.0), + ] + ) + translation = -(scale / 2.0 - 0.5) + x = jax.image.scale_and_translate( + x, + (x.shape[0], *size, x.shape[-1]), + method="bilinear", + scale=scale, + spatial_dims=(1, 2), + translation=translation, + antialias=False, + ) + if data_format == "channels_first": + x = ops.transpose(x, (0, 3, 1, 2)) + elif backend.backend() == "tensorflow": + import tensorflow as tf + + if data_format == "channels_first": + x = ops.transpose(x, (0, 2, 3, 1)) + x = tf.compat.v1.image.resize( + x, + size=size, + method="bilinear", + align_corners=True, + ) + if data_format == "channels_first": + x = ops.transpose(x, (0, 3, 1, 2)) + elif backend.backend() == "torch": + import torch.nn.functional as F + + if data_format == "channels_last": + x = ops.transpose(x, (0, 3, 1, 2)) + x = F.interpolate(x, size=size, mode="bilinear", align_corners=True) + if data_format == "channels_last": + x = ops.transpose(x, (0, 2, 3, 1)) + else: + raise NotImplementedError(f"Unsupported backend: {backend.backend()}") + return x diff --git a/keras_hub/src/models/depth_estimator.py b/keras_hub/src/models/depth_estimator.py new file mode 100644 index 0000000000..2237b09c2b --- /dev/null +++ b/keras_hub/src/models/depth_estimator.py @@ -0,0 +1,238 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.task import Task + + +class Multiplier(keras.layers.Layer): + def __init__(self, multiplier=None, **kwargs): + super().__init__(**kwargs) + self.multiplier = float(multiplier) 
if multiplier is not None else None + + def call(self, inputs): + if self.multiplier is not None: + inputs = keras.ops.multiply(inputs, self.multiplier) + return inputs + + def get_config(self): + config = super().get_config() + config.update( + { + "multiplier": self.multiplier, + } + ) + return config + + +@keras_hub_export("keras_hub.models.DepthEstimator") +class DepthEstimator(Task): + """Base class for all depth estimation tasks. + + `DepthEstimator` tasks wrap a `keras_hub.models.Backbone` and + a `keras_hub.models.Preprocessor` to create a model that can be used for + depth estimation. + + To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` + labels where `x` is a RGB image and `y` is a depth map. All `DepthEstimator` + tasks include a `from_preset()` constructor which can be used to load a + pre-trained config and weights. + + Args: + backbone: A `keras_hub.models.Backbone` instance or a `keras.Model`. + preprocessor: `None`, a `keras_hub.models.Preprocessor` instance, + a `keras.Layer` instance, or a callable. If `None` no preprocessing + will be applied to the inputs. + depth_estimation_type: `"relative"` or `"metric"`. The type of depth map + to use. `"relative"` depth maps are up-to-scale, while `"metric"` + depth maps have metric meaning (e.g. in meters). Defaults to + `"relative"`. + min_depth: An optional float. The minimum depth value. This value can + be used to filter out invalid depth values during training. + max_depth: An optional float. The maximum depth value. This value can + be used to filter out invalid depth values during training. Also, + when `depth_estimation_type="metric"`, the model's output will be + scaled to the range `[0, max_depth]`. + + Examples: + + Call `predict()` to run inference. + ```python + # Load preset and train + images = np.random.randint(0, 256, size=(2, 224, 224, 3)) + depth_estimator = keras_hub.models.DepthEstimator.from_preset( + "depth_anything_v2_small" + ) + depth_estimator.predict(images) + ``` + + Call `fit()` on a single batch. + ```python + # Load preset and train + images = np.random.randint(0, 256, size=(2, 224, 224, 3)) + depths = np.random.uniform(0, 10, size=(2, 224, 224)) + depth_estimator = keras_hub.models.DepthEstimator.from_preset( + "depth_anything_v2_small", + depth_estimation_type="metric", + max_depth=10.0, + ) + depth_estimator.fit(x=images, y=depths, batch_size=2) + ``` + + Call `fit()` with custom loss, optimizer and backbone. + ```python + depth_estimator = keras_hub.models.DepthEstimator.from_preset( + "depth_anything_v2_small", + depth_estimation_type="metric", + max_depth=10.0, + ) + depth_estimator.compile( + loss=keras.losses.MeanSquaredError(), + optimizer=keras.optimizers.Adam(5e-5), + ) + depth_estimator.backbone.trainable = False + depth_estimator.fit(x=images, y=depths, batch_size=2) + ``` + + Custom backbone. 
+ ```python + images = np.random.randint(0, 256, size=(2, 224, 224, 3)) + depths = np.random.uniform(0, 10, size=(2, 224, 224)) + image_encoder = keras_hub.models.DINOV2Backbone.from_preset("dinov2_small") + backbone = keras_hub.models.DepthAnythingBackbone( + image_encoder=image_encoder, + patch_size=image_encoder.patch_size, + backbone_hidden_dim=image_encoder.hidden_dim, + reassemble_factors=[4, 2, 1, 0.5], + neck_hidden_dims=[48, 96, 192, 384], + fusion_hidden_dim=64, + head_hidden_dim=32, + head_in_index=-1, + ) + depth_estimator = keras_hub.models.DepthEstimator( + backbone=backbone, + depth_estimation_type="metric", + max_depth=10.0, + ) + depth_estimator.fit(x=images, y=depths, batch_size=2) + ``` + """ + + def __init__( + self, + backbone, + depth_estimation_type, + min_depth=None, + max_depth=None, + preprocessor=None, + **kwargs, + ): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + if depth_estimation_type == "relative": + self.output_activation = keras.layers.ReLU( + dtype=backbone.dtype_policy, + name="output_activation", + ) + elif depth_estimation_type == "metric": + self.output_activation = keras.layers.Activation( + activation="sigmoid", + dtype=backbone.dtype_policy, + name="output_activation", + ) + else: + raise ValueError( + "`depth_estimation_type` should be either `'relative'` or " + "`'metric'`. " + f"Received: depth_estimation_type={depth_estimation_type}." + ) + if max_depth is not None and depth_estimation_type != "metric": + raise ValueError( + "`max_depth` should only be set when " + "`depth_estimation_type='metric'`. " + f"Received: depth_estimation_type={depth_estimation_type}, " + f"max_depth={max_depth}." + ) + self.multiplier = Multiplier( + multiplier=max_depth, dtype=backbone.dtype_policy, name="multiplier" + ) + self.depths = keras.layers.Identity( + dtype=backbone.dtype_policy, name="depths" + ) + + # === Config === + self.depth_estimation_type = depth_estimation_type + self.min_depth = float(min_depth) if min_depth is not None else None + self.max_depth = float(max_depth) if max_depth is not None else None + + # === Functional Model === + inputs = self.backbone.input + depths = self.backbone(inputs) + depths = self.output_activation(depths) + depths = self.multiplier(depths) + depths = self.depths(depths) + outputs = {"depths": depths} + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def get_config(self): + # Backbone serialized in `super` + config = super().get_config() + config.update( + { + "depth_estimation_type": self.depth_estimation_type, + "min_depth": self.min_depth, + "max_depth": self.max_depth, + } + ) + return config + + def compile( + self, + optimizer="auto", + loss="auto", + *, + metrics="auto", + **kwargs, + ): + """Configures the `DepthEstimator` task for training. + + The `DepthEstimator` task extends the default compilation signature of + `keras.Model.compile` with defaults for `optimizer`, `loss`, and + `metrics`. To override these defaults, pass any value + to these arguments during compilation. + + Args: + optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` + instance. Defaults to `"auto"`, which uses the default optimizer + for the given model and task. See `keras.Model.compile` and + `keras.optimizers` for more info on possible `optimizer` values. + loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. + Defaults to `"auto"`, where a `keras.losses.MeanSquaredError` + loss will be applied for the depth estimation task. 
See + `keras.Model.compile` and `keras.losses` for more info on + possible `loss` values. + metrics: `"auto"`, or a dict of metrics to be evaluated by + the model during training and testing. Defaults to `"auto"`, + where a `keras.metrics.RootMeanSquaredError` will be applied to + track the accuracy of the model during training. See + `keras.Model.compile` and `keras.metrics` for more info on + possible `metrics` values. + **kwargs: See `keras.Model.compile` for a full list of arguments + supported by the compile method. + """ + if optimizer == "auto": + optimizer = keras.optimizers.AdamW(5e-5) + if loss == "auto": + loss = {"depths": keras.losses.MeanSquaredError()} + if metrics == "auto": + metrics = {"depths": keras.metrics.RootMeanSquaredError()} + super().compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) diff --git a/keras_hub/src/models/depth_estimator_preprocessor.py b/keras_hub/src/models/depth_estimator_preprocessor.py new file mode 100644 index 0000000000..63506e5550 --- /dev/null +++ b/keras_hub/src/models/depth_estimator_preprocessor.py @@ -0,0 +1,78 @@ +import keras + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.preprocessor import Preprocessor +from keras_hub.src.utils.tensor_utils import preprocessing_function + + +@keras_hub_export("keras_hub.models.DepthEstimatorPreprocessor") +class DepthEstimatorPreprocessor(Preprocessor): + """Base class for depth estimation preprocessing layers. + + `DepthEstimatorPreprocessor` tasks wraps a + `keras_hub.layers.ImageConverter` to create a preprocessing layer for + depth estimation tasks. It is intended to be paired with a + `keras_hub.models.DepthEstimator` task. + + All `DepthEstimatorPreprocessor` take inputs three inputs, `x`, `y`, and + `sample_weight`. `x`, the first input, should always be included. It can + be a image or batch of images. See examples below. `y` and `sample_weight` + are optional inputs that will be passed through unaltered. Usually, `y` will + be the depths, and `sample_weight` will not be provided. + + The layer will output either `x`, an `(x, y)` tuple if depths were provided, + or an `(x, y, sample_weight)` tuple if depths and sample weight were + provided. `x` will be the input images after all model preprocessing has + been applied. + + All `DepthEstimatorPreprocessor` tasks include a `from_preset()` + constructor which can be used to load a pre-trained config. + You can call the `from_preset()` constructor directly on this base class, in + which case the correct class for your model will be automatically + instantiated. + + Examples. + ```python + preprocessor = keras_hub.models.DepthEstimatorPreprocessor.from_preset( + "depth_anything_v2_small", + ) + + # Resize a single image for DepthAnythingV2 Small. + x = np.random.randint(0, 256, (512, 512, 3)) + x = preprocessor(x) + + # Resize a labeled image. + x = np.random.randint(0, 256, (512, 512, 3)) + y = np.random.uniform(0, 10, size=(512, 512)) + x, y = preprocessor(x, y) + + # Resize a batch of labeled images. + x = [ + np.random.randint(0, 256, (512, 512, 3)), + np.zeros((512, 512, 3)), + ] + y = [ + np.random.uniform(0, 10, size=(512, 512)), + np.random.uniform(0, 10, size=(512, 512)), + ] + x, y = preprocessor(x, y) + + # Use a `tf.data.Dataset`. 
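+ # `x` and `y` can be any of the image/depth pairs above; the preprocessor is mapped over the
+ # batched dataset.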
+ ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(2) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + def __init__( + self, + image_converter=None, + **kwargs, + ): + super().__init__(**kwargs) + self.image_converter = image_converter + + @preprocessing_function + def call(self, x, y=None, sample_weight=None): + if self.image_converter: + x = self.image_converter(x) + return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) diff --git a/keras_hub/src/models/dinov2/dinov2_backbone.py b/keras_hub/src/models/dinov2/dinov2_backbone.py index a11d00cef0..de59741d82 100644 --- a/keras_hub/src/models/dinov2/dinov2_backbone.py +++ b/keras_hub/src/models/dinov2/dinov2_backbone.py @@ -1,14 +1,14 @@ from keras import layers from keras_hub.src.api_export import keras_hub_export -from keras_hub.src.models.backbone import Backbone from keras_hub.src.models.dinov2.dinov2_layers import DINOV2Embedding from keras_hub.src.models.dinov2.dinov2_layers import DINOV2Encoder +from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone from keras_hub.src.utils.keras_utils import standardize_data_format @keras_hub_export("keras_hub.models.DINOV2Backbone") -class DINOV2Backbone(Backbone): +class DINOV2Backbone(FeaturePyramidBackbone): """DINOV2 core network with hyperparameters. DINOV2 offers a powerful, generalist visual backbone learned entirely from @@ -50,6 +50,8 @@ class DINOV2Backbone(Backbone): embeddings to the actual input shape. Defaults to `(518, 518)`. antialias_in_interpolation: bool. Whether to use antialiasing in the interpolation of the position embeddings. Defaults to `False`. + apply_layernorm: bool. Whether to apply layer normalization to the + outputs of each stage in the feature pyramid. Defaults to `False`. data_format: `None` or str. If specified, either `"channels_last"` or `"channels_first"`. The ordering of the dimensions in the inputs. 
`"channels_last"` corresponds to inputs with shape @@ -114,6 +116,7 @@ def __init__( image_shape=(224, 224, 3), position_embedding_shape=(518, 518, 3), antialias_in_interpolation=False, + apply_layernorm=False, data_format=None, dtype=None, name=None, @@ -176,10 +179,16 @@ def __init__( ) # === Functional Model === + pyramid_outputs = {} image_input = layers.Input(shape=image_shape, name="images") x = self.embeddings(image_input) - x = self.encoder(x) + pyramid_outputs["Stem"] = x + x, encoder_pyramid_outputs = self.encoder(x) + pyramid_outputs.update(encoder_pyramid_outputs) x = self.layernorm(x) + if apply_layernorm: + for key in pyramid_outputs: + pyramid_outputs[key] = self.layernorm(pyramid_outputs[key]) outputs = x super().__init__( inputs={"images": image_input}, @@ -204,6 +213,8 @@ def __init__( self.image_shape = image_shape self.position_embedding_shape = position_embedding_shape self.antialias_in_interpolation = bool(antialias_in_interpolation) + self.apply_layernorm = apply_layernorm + self.pyramid_outputs = pyramid_outputs def get_config(self): config = super().get_config() @@ -223,6 +234,7 @@ def get_config(self): "image_shape": self.image_shape, "position_embedding_shape": self.position_embedding_shape, "antialias_in_interpolation": self.antialias_in_interpolation, + "apply_layernorm": self.apply_layernorm, } ) return config diff --git a/keras_hub/src/models/dinov2/dinov2_layers.py b/keras_hub/src/models/dinov2/dinov2_layers.py index 4e1ec2362a..d3eb97ce01 100644 --- a/keras_hub/src/models/dinov2/dinov2_layers.py +++ b/keras_hub/src/models/dinov2/dinov2_layers.py @@ -861,10 +861,12 @@ def build(self, input_shape): input_shape = layer.compute_output_shape(input_shape) def call(self, inputs, training=None): + pyramid_outputs = {} x = inputs - for layer in self.layers: + for layer_index, layer in enumerate(self.layers, start=1): x = layer(x, training=training) - return x + pyramid_outputs[f"Stage{str(layer_index)}"] = x + return x, pyramid_outputs def get_config(self): config = super().get_config() @@ -883,4 +885,7 @@ def get_config(self): return config def compute_output_shape(self, input_shape): - return input_shape + pyramid_outputs = {} + for layer_index in range(1, len(self.layers) + 1): + pyramid_outputs[f"Stage{str(layer_index)}"] = input_shape + return input_shape, pyramid_outputs diff --git a/keras_hub/src/utils/transformers/convert_dinov2.py b/keras_hub/src/utils/transformers/convert_dinov2.py index 9e1722f264..dc9d9d6893 100644 --- a/keras_hub/src/utils/transformers/convert_dinov2.py +++ b/keras_hub/src/utils/transformers/convert_dinov2.py @@ -29,6 +29,7 @@ def convert_backbone_config(transformers_config): "image_shape": (image_size, image_size, 3), "position_embedding_shape": (image_size, image_size), "antialias_in_interpolation": antialias_in_interpolation, + "apply_layernorm": transformers_config.get("apply_layernorm", False), } From 64c69030899245d84cbe54919a1b32cc27c8eace Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Sun, 31 Aug 2025 22:23:55 +0800 Subject: [PATCH 2/9] Add DepthAnythingV2 conversion script. 
--- .../convert_depth_anything_checkpoints.py | 263 ++++++++++++++++++ 1 file changed, 263 insertions(+) create mode 100644 tools/checkpoint_conversion/convert_depth_anything_checkpoints.py diff --git a/tools/checkpoint_conversion/convert_depth_anything_checkpoints.py b/tools/checkpoint_conversion/convert_depth_anything_checkpoints.py new file mode 100644 index 0000000000..8b6776da29 --- /dev/null +++ b/tools/checkpoint_conversion/convert_depth_anything_checkpoints.py @@ -0,0 +1,263 @@ +"""Convert DepthAnything checkpoints. + +export KAGGLE_USERNAME=xxx +export KAGGLE_KEY=xxx + +python tools/checkpoint_conversion/convert_depthanything_checkpoints.py \ + --preset depth_anything_v2_small --upload_uri kaggle://kerashub/depth_anything/keras/depth_anything_v2_small +python tools/checkpoint_conversion/convert_depthanything_checkpoints.py \ + --preset depth_anything_v2_base --upload_uri kaggle://kerashub/depth_anything/keras/depth_anything_v2_base +python tools/checkpoint_conversion/convert_depthanything_checkpoints.py \ + --preset depth_anything_v2_large --upload_uri kaggle://kerashub/depth_anything/keras/depth_anything_v2_large +""" + +import os +import shutil + +import keras +import numpy as np +import torch +from absl import app +from absl import flags +from PIL import Image +from transformers import AutoImageProcessor +from transformers import DepthAnythingForDepthEstimation + +import keras_hub +from keras_hub.src.models.depth_anything.depth_anything_backbone import ( + DepthAnythingBackbone, +) +from keras_hub.src.models.depth_anything.depth_anything_image_converter import ( + DepthAnythingImageConverter, +) +from keras_hub.src.models.dinov2.dinov2_backbone import DINOV2Backbone +from keras_hub.src.utils.transformers import convert_dinov2 +from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader + +FLAGS = flags.FLAGS + +PRESET_MAP = { + "depth_anything_v2_small": "depth-anything/Depth-Anything-V2-Small-hf", + "depth_anything_v2_base": "depth-anything/Depth-Anything-V2-Base-hf", + "depth_anything_v2_large": "depth-anything/Depth-Anything-V2-Large-hf", +} + +flags.DEFINE_string( + "preset", + None, + f"Must be one of {','.join(PRESET_MAP.keys())}", + required=True, +) +flags.DEFINE_string( + "upload_uri", + None, + 'Could be "kaggle://keras/{variant}/keras/{preset}"', + required=False, +) + + +def convert_model(hf_model, dtype=None): + dinov2_config = convert_dinov2.convert_backbone_config( + hf_model.config.backbone_config.to_dict() + ) + image_encoder = DINOV2Backbone(**dinov2_config) + model_config = hf_model.config.to_dict() + image_shape = dinov2_config["image_shape"] + # In KerasHub, the stage names are capitalized. + feature_keys = model_config["backbone_config"]["out_features"] + feature_keys = [key.replace("stage", "Stage") for key in feature_keys] + assert model_config["depth_estimation_type"] == "relative" + assert model_config["max_depth"] in (None, 1.0) + return DepthAnythingBackbone( + image_encoder, + image_encoder.patch_size, + image_encoder.hidden_dim, + reassemble_factors=model_config["reassemble_factors"], + neck_hidden_dims=model_config["neck_hidden_sizes"], + fusion_hidden_dim=model_config["fusion_hidden_size"], + head_hidden_dim=model_config["head_hidden_size"], + head_in_index=model_config["head_in_index"], + image_shape=image_shape, + feature_keys=feature_keys, + dtype=dtype, + ) + + +def convert_weights(hf_preset, keras_hub_model, hf_model): + # Convert weights of DINOV2 backbone. 
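+ # The image encoder reuses the existing DINOV2 converter, reading its weights straight from the
+ # HF Hub safetensors; the DepthAnything neck and head weights are ported from the torch
+ # `state_dict` below.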
+ with SafetensorLoader(f"hf://{hf_preset}") as loader: + convert_dinov2.convert_weights( + keras_hub_model.image_encoder, loader, None + ) + + # Get `state_dict` from `hf_model`. + state_dict = hf_model.state_dict() + + # Helper functions. + def port_weights(keras_variable, weight_key, hook_fn=None): + torch_tensor = state_dict[weight_key].cpu().numpy() + if hook_fn: + torch_tensor = hook_fn(torch_tensor, list(keras_variable.shape)) + keras_variable.assign(torch_tensor) + + def port_conv2d(keras_variable, weight_key): + port_weights( + keras_variable.kernel, + f"{weight_key}.weight", + lambda x, s: np.transpose(x, (2, 3, 1, 0)), + ) + if keras_variable.use_bias: + port_weights(keras_variable.bias, f"{weight_key}.bias") + + assert isinstance(keras_hub_model, DepthAnythingBackbone) + + # Convert neck weights. + for i in range(len(keras_hub_model.reassemble_factors)): + # Reassemble stage. + port_conv2d( + keras_hub_model.neck.reassemble_stage[i].projection, + f"neck.reassemble_stage.layers.{i}.projection", + ) + if keras_hub_model.neck.reassemble_stage[i].factor != 1: + port_conv2d( + keras_hub_model.neck.reassemble_stage[i].resize, + f"neck.reassemble_stage.layers.{i}.resize", + ) + # Convs. + port_conv2d(keras_hub_model.neck.convs[i], f"neck.convs.{i}") + # Fusion stage. + port_conv2d( + keras_hub_model.neck.fusion_stage[i].projection, + f"neck.fusion_stage.layers.{i}.projection", + ) + port_conv2d( + keras_hub_model.neck.fusion_stage[i].residual_layer1.convolution1, + f"neck.fusion_stage.layers.{i}.residual_layer1.convolution1", + ) + port_conv2d( + keras_hub_model.neck.fusion_stage[i].residual_layer1.convolution2, + f"neck.fusion_stage.layers.{i}.residual_layer1.convolution2", + ) + port_conv2d( + keras_hub_model.neck.fusion_stage[i].residual_layer2.convolution1, + f"neck.fusion_stage.layers.{i}.residual_layer2.convolution1", + ) + port_conv2d( + keras_hub_model.neck.fusion_stage[i].residual_layer2.convolution2, + f"neck.fusion_stage.layers.{i}.residual_layer2.convolution2", + ) + + # Convert head weights. + port_conv2d(keras_hub_model.head.conv1, "head.conv1") + port_conv2d(keras_hub_model.head.conv2, "head.conv2") + port_conv2d(keras_hub_model.head.conv3, "head.conv3") + + +def convert_image_converter(hf_image_processor): + config = hf_image_processor.to_dict() + image_size = (config["size"]["height"], config["size"]["width"]) + std = config["image_std"] + mean = config["image_mean"] + return DepthAnythingImageConverter( + image_size=image_size, + scale=[1.0 / 255.0 / s for s in std], + offset=[-m / s for m, s in zip(mean, std)], + interpolation="bicubic", # DINOV2 defaults to bicubic resampling. + ) + + +def validate_output( + keras_model, keras_image_converter, hf_model, hf_image_processor +): + config = hf_image_processor.to_dict() + image_size = (config["size"]["height"], config["size"]["width"]) + file = keras.utils.get_file( + origin=("http://images.cocodataset.org/val2017/000000039769.jpg") + ) + image = Image.open(file) + image = image.resize(image_size) + + # Preprocess with hf. + hf_inputs = hf_image_processor(images=image, return_tensors="pt") + hf_preprocessed = hf_inputs["pixel_values"].detach().cpu().numpy() + + # Preprocess with keras. + images = np.expand_dims(np.array(image).astype("float32"), axis=0) + images = keras_image_converter(images) + keras_preprocessed = keras.ops.convert_to_numpy(images) + + # Call with hf. Use the keras preprocessed image so we can keep modeling + # and preprocessing comparisons independent. 
+ hf_inputs["pixel_values"] = torch.from_numpy( + keras.ops.convert_to_numpy( + keras.ops.transpose(keras_preprocessed, (0, 3, 1, 2)) + ) + ) + hf_outputs = hf_model(**hf_inputs) + hf_depths = hf_outputs.predicted_depth.detach().cpu().numpy() + + # Call with keras. + keras_depths = keras_model.predict(images, verbose=0) + # Defaults to "relative" depth estimation. + keras_depths = keras.ops.nn.relu(keras_depths) + keras_depths = keras.ops.convert_to_numpy( + keras.ops.squeeze(keras_depths, axis=-1) + ) + + print("🔶 Keras output:", keras_depths[0]) + print("🔶 HF output:", hf_depths[0]) + modeling_diff = np.mean(np.abs(keras_depths - hf_depths)) + print("🔶 Modeling difference:", modeling_diff) + preprocessing_diff = np.mean( + np.abs(keras_preprocessed - np.transpose(hf_preprocessed, (0, 2, 3, 1))) + ) + print("🔶 Preprocessing difference:", preprocessing_diff) + + +def main(_): + if FLAGS.preset not in PRESET_MAP.keys(): + raise ValueError( + f"Invalid preset {FLAGS.preset}. Must be one " + f"of {','.join(PRESET_MAP.keys())}" + ) + preset = FLAGS.preset + hf_preset = PRESET_MAP[preset] + if os.path.exists(preset): + shutil.rmtree(preset) + os.makedirs(preset) + + print(f"🏃 Coverting {preset}") + + # Load huggingface model. + hf_model = DepthAnythingForDepthEstimation.from_pretrained(hf_preset) + hf_image_converter = AutoImageProcessor.from_pretrained(hf_preset) + hf_model.eval() + + keras_model = convert_model(hf_model) + keras_model.summary() + keras_image_converter = convert_image_converter(hf_image_converter) + print("✅ KerasHub model loaded.") + + convert_weights(hf_preset, keras_model, hf_model) + print("✅ Weights converted.") + + validate_output( + keras_model, + keras_image_converter, + hf_model, + hf_image_converter, + ) + print("✅ Output validated.") + + keras_model.save_to_preset(f"./{preset}") + keras_image_converter.save_to_preset(f"./{preset}") + print(f"🏁 Preset saved to ./{preset}.") + + upload_uri = FLAGS.upload_uri + if upload_uri: + keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}") + print(f"🏁 Preset uploaded to {upload_uri}") + + +if __name__ == "__main__": + app.run(main) From 19664861a521390292d1af28244d699c28fca3ff Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 1 Sep 2025 10:22:53 +0800 Subject: [PATCH 3/9] Update docstrings. Fix loss `None` object bug. --- .../depth_anything/depth_anything_backbone.py | 114 +++++++++++++--- .../depth_anything_backbone_test.py | 3 - .../depth_anything_depth_estimator_test.py | 3 - .../depth_anything/depth_anything_layers.py | 129 ++++++++++++++++++ .../depth_anything/depth_anything_loss.py | 28 ++-- keras_hub/src/models/depth_estimator.py | 15 +- .../convert_depth_anything_checkpoints.py | 4 - 7 files changed, 250 insertions(+), 46 deletions(-) diff --git a/keras_hub/src/models/depth_anything/depth_anything_backbone.py b/keras_hub/src/models/depth_anything/depth_anything_backbone.py index 6198544375..9206075ab1 100644 --- a/keras_hub/src/models/depth_anything/depth_anything_backbone.py +++ b/keras_hub/src/models/depth_anything/depth_anything_backbone.py @@ -15,17 +15,94 @@ @keras_hub_export("keras_hub.models.DepthAnythingBackbone") class DepthAnythingBackbone(Backbone): + """DepthAnything core network with hyperparameters. + + DepthAnything offers a powerful monocular depth estimation as described in + [Depth Anything V2](https://arxiv.org/abs/2406.09414). 
+ + The default constructor gives a fully customizable, randomly initialized + DepthAnything model with any number of layers, heads, and embedding + dimensions by providing the DINOV2 as the `image_encoder`. To load preset + architectures and weights, use the `from_preset` constructor. + + Args: + image_encoder: The DINOV2 image encoder for encoding the input images. + reassemble_factors: List of float. The reassemble factor for each + feature map from the image encoder. The length of the list must be + equal to the number of feature maps from the image encoder. + neck_hidden_dims: int. The size of the neck hidden state. + fusion_hidden_dim: int. The size of the fusion hidden state. + head_hidden_dim: int. The size of the neck hidden state. + head_in_index: int. The index to select the feature from the neck + features as the input to the head. + feature_keys: List of string. The keys to select the feature maps from + the image encoder. If `None`, all feature maps from the image + encoder will be used. Defaults to `None`. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use + for the models computations and weights. Note that some + computations, such as softmax and layer normalization will always + be done a float32 precision regardless of dtype. + + Example: + ```python + # Pretrained DepthAnything model. + input_data = { + "images": np.ones(shape=(1, 518, 518, 3), dtype="float32"), + } + model = keras_hub.models.DINOV2Backbone.from_preset( + "depth_anything_v2_small" + ) + model(input_data) + + # Pretrained DepthAnything model with custom image shape. + input_data = { + "images": np.ones(shape=(1, 224, 224, 3), dtype="float32"), + } + model = keras_hub.models.DINOV2Backbone.from_preset( + "depth_anything_v2_small", image_shape=(224, 224, 3) + ) + model(input_data) + + # Randomly initialized DepthAnything model with custom config. + image_encoder = keras_hub.models.DINOV2Backbone( + patch_size=14, + num_layers=4, + hidden_dim=32, + num_heads=2, + intermediate_dim=128, + image_shape=(224, 224, 3), + position_embedding_shape=(518, 518), + ) + model = keras_hub.models.DepthAnythingBackbone( + image_encoder=image_encoder, + reassemble_factors=[4, 2, 1, 0.5], + neck_hidden_dims=[16, 32, 64, 128], + fusion_hidden_dim=128, + head_hidden_dim=16, + head_in_index=-1, + feature_keys=["Stage1", "Stage2", "Stage3", "Stage4"], + ) + model(input_data) + ``` + """ + def __init__( self, image_encoder, - patch_size, - backbone_hidden_dim, reassemble_factors, neck_hidden_dims, fusion_hidden_dim, head_hidden_dim, head_in_index, - image_shape=(224, 224, 3), feature_keys=None, data_format=None, dtype=None, @@ -48,21 +125,30 @@ def __init__( "`image_encoder.pyramid_outputs` contains " f"{list(image_encoder.pyramid_outputs.keys())}." ) + else: + feature_keys = list(image_encoder.pyramid_outputs.keys()) + if len(reassemble_factors) != len(feature_keys): + raise ValueError( + "The length of `reassemble_factors` must be equal to the " + "length of `feature_keys`. 
" + f"Received len(reassemble_factors)={len(reassemble_factors)}, " + f"len(feature_keys)={len(feature_keys)}." + ) data_format = standardize_data_format(data_format) + patch_size = image_encoder.patch_size + backbone_hidden_dim = image_encoder.hidden_dim + image_shape = image_encoder.image_shape if data_format == "channels_last": image_size = (image_shape[0], image_shape[1]) else: image_size = (image_shape[1], image_shape[2]) # === Layers === - if feature_keys is None: - pyramid_outputs = image_encoder.pyramid_outputs - else: - pyramid_outputs = { - key: value - for key, value in image_encoder.pyramid_outputs.items() - if key in feature_keys - } + pyramid_outputs = { + key: value + for key, value in image_encoder.pyramid_outputs.items() + if key in feature_keys + } self.feature_extractor = keras.Model( inputs=image_encoder.inputs, outputs=pyramid_outputs, @@ -107,14 +193,11 @@ def __init__( # === Config === self.image_encoder = image_encoder - self.patch_size = patch_size - self.backbone_hidden_dim = backbone_hidden_dim self.reassemble_factors = reassemble_factors self.neck_hidden_dims = neck_hidden_dims self.fusion_hidden_dim = fusion_hidden_dim self.head_hidden_dim = head_hidden_dim self.head_in_index = head_in_index - self.image_shape = image_shape self.feature_keys = feature_keys def get_config(self): @@ -122,14 +205,11 @@ def get_config(self): config.update( { "image_encoder": layers.serialize(self.image_encoder), - "patch_size": self.patch_size, - "backbone_hidden_dim": self.backbone_hidden_dim, "reassemble_factors": self.reassemble_factors, "neck_hidden_dims": self.neck_hidden_dims, "fusion_hidden_dim": self.fusion_hidden_dim, "head_hidden_dim": self.head_hidden_dim, "head_in_index": self.head_in_index, - "image_shape": self.image_shape, "feature_keys": self.feature_keys, } ) diff --git a/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py b/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py index 21d53c94de..21b024762d 100644 --- a/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py +++ b/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py @@ -23,14 +23,11 @@ def setUp(self): ) self.init_kwargs = { "image_encoder": image_encoder, - "patch_size": image_encoder.patch_size, - "backbone_hidden_dim": image_encoder.hidden_dim, "reassemble_factors": [4, 2, 1, 0.5], "neck_hidden_dims": [16, 32, 64, 128], "fusion_hidden_dim": 128, "head_hidden_dim": 16, "head_in_index": -1, - "image_shape": (70, 70, 3), "feature_keys": ["Stage1", "Stage2", "Stage3", "Stage4"], } self.input_data = ops.ones((2, 70, 70, 3)) diff --git a/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py b/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py index 1c8dcb1dd7..48d50554e9 100644 --- a/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py +++ b/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py @@ -38,14 +38,11 @@ def setUp(self): ) self.backbone = DepthAnythingBackbone( image_encoder=image_encoder, - patch_size=image_encoder.patch_size, - backbone_hidden_dim=image_encoder.hidden_dim, reassemble_factors=[4, 2, 1, 0.5], neck_hidden_dims=[16, 32, 64, 128], fusion_hidden_dim=128, head_hidden_dim=16, head_in_index=-1, - image_shape=(70, 70, 3), feature_keys=["Stage1", "Stage2", "Stage3", "Stage4"], ) self.init_kwargs = { diff --git a/keras_hub/src/models/depth_anything/depth_anything_layers.py b/keras_hub/src/models/depth_anything/depth_anything_layers.py index 
cec861153a..4b5bb19f12 100644 --- a/keras_hub/src/models/depth_anything/depth_anything_layers.py +++ b/keras_hub/src/models/depth_anything/depth_anything_layers.py @@ -6,6 +6,29 @@ class DepthAnythingTokenToImage(layers.Layer): + """A layer that converts tokens into images. + + Args: + hidden_dim: int. The number of units in the hidden layers. + patch_height: int. The height of each patch. + patch_width: int. The width of each patch. + num_cls_tokens: int. The number of class tokens at the beginning of + the sequence. Defaults to `1`. + num_register_tokens: int. The number of register tokens after the + class tokens. Defaults to `0`. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + def __init__( self, hidden_dim, @@ -65,6 +88,26 @@ def compute_output_shape(self, input_shape): class DepthAnythingReassembleLayer(layers.Layer): + """A layer that resizes the input images. + + Args: + hidden_dim: int. The number of units in the hidden layers. + factor: float. The resizing factor. If `factor > 1`, the layer upsamples + the input. If `factor < 1`, the layer downsamples the input. If + `factor == 1`, the layer only applies a linear projection. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + def __init__(self, hidden_dim, factor, data_format=None, **kwargs): super().__init__(**kwargs) self.hidden_dim = int(hidden_dim) @@ -152,6 +195,23 @@ def compute_output_shape(self, input_shape): class DepthAnythingPreActResidualLayer(layers.Layer): + """A ReLU + Conv2D layer. + + Args: + hidden_dim: int. The number of units in the hidden layers. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. 
+ """ + def __init__(self, hidden_dim, data_format=None, **kwargs): super().__init__(**kwargs) self.hidden_dim = int(hidden_dim) @@ -229,6 +289,24 @@ def compute_output_shape(self, input_shape): class DepthAnythingFeatureFusionLayer(layers.Layer): + """A layer that fuses the incoming features. + + Args: + hidden_dim: int. The number of units in the hidden layers. + size: tuple of int. The target size of the output feature map. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + def __init__(self, hidden_dim, size, data_format=None, **kwargs): super().__init__(**kwargs) self.hidden_dim = int(hidden_dim) @@ -301,6 +379,33 @@ def compute_output_shape(self, input_shape): class DepthAnythingNeck(layers.Layer): + """A DepthAnything neck layer. + + Args: + patch_size: int. The size of one side of each patch. + image_size: tuple of ints. The (height, width) of the input images. + backbone_hidden_dim: int. The number of units in the backbone layers. + neck_hidden_dims: List of int. The number of units in each neck layer. + reassemble_factors: List of float. The resizing factor in each neck + layer. + fusion_hidden_dim: int. The number of units in the fusion layers. + num_cls_tokens: int. The number of class tokens at the beginning of + the sequence. Defaults to `1`. + num_register_tokens: int. The number of register tokens after the + class tokens. Defaults to `0`. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + def __init__( self, patch_size, @@ -464,6 +569,30 @@ def get_config(self): class DepthAnythingDepthEstimationHead(layers.Layer): + """A DepthAnything neck layer. + + Args: + patch_size: int. The size of one side of each patch. + patch_height: int. The height of each patch. + patch_width: int. The width of each patch. + hidden_dim: int. The number of units in the hidden layers. + fusion_hidden_dim: int. The number of units in the fusion layers. + head_hidden_dim: int. The number of units in the head layers. + head_in_index: int. The index of the feature map to be used as input + to the head. + data_format: `None` or str. If specified, either `"channels_last"` or + `"channels_first"`. The ordering of the dimensions in the + inputs. `"channels_last"` corresponds to inputs with shape + `(batch_size, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch_size, channels, height, width)`. 
It defaults to the + `image_data_format` value found in your Keras config file at + `~/.keras/keras.json`. If you never set it, then it will be + `"channels_last"`. + **kwargs: other keyword arguments passed to `keras.layers.Layer`, + including `name`, `dtype` etc. + """ + def __init__( self, patch_size, diff --git a/keras_hub/src/models/depth_anything/depth_anything_loss.py b/keras_hub/src/models/depth_anything/depth_anything_loss.py index e6e4f8a4c1..64092294ba 100644 --- a/keras_hub/src/models/depth_anything/depth_anything_loss.py +++ b/keras_hub/src/models/depth_anything/depth_anything_loss.py @@ -1,3 +1,4 @@ +import keras from keras import ops from keras.src.losses.losses import LossFunctionWrapper @@ -14,9 +15,9 @@ class DepthAnythingLoss(LossFunctionWrapper): lambd: The weighting factor in the scale-invariant log loss formula. Defaults to `0.5`. min_depth: Minimum depth value used to filter `y_pred` and `y_true`. - Defaults to `0.0`. - max_depth: Maximum depth value used to filter `y_pred` and `y_true`. - Defaults to `1.0`. + Defaults to `keras.config.epsilon()`. + max_depth: Optional maximum depth value used to filter `y_pred` and + `y_true`. If not specified, there will be no upper bound. reduction: Type of reduction to apply to the loss. In almost all cases this should be `"sum_over_batch_size"`. Supported options are `"sum"`, `"sum_over_batch_size"`, `"mean"`, @@ -36,14 +37,12 @@ class DepthAnythingLoss(LossFunctionWrapper): def __init__( self, lambd=0.5, - min_depth=0.0, - max_depth=1.0, + min_depth=keras.config.epsilon(), + max_depth=None, reduction="sum_over_batch_size", name="depth_anything_loss", dtype=None, ): - if max_depth is None: - max_depth = 1.0 super().__init__( silog, name=name, @@ -55,15 +54,20 @@ def __init__( ) -def silog(y_true, y_pred, lambd=0.5, min_depth=0.001, max_depth=20.0): +def silog( + y_true, y_pred, lambd=0.5, min_depth=keras.config.epsilon(), max_depth=None +): y_pred = ops.convert_to_tensor(y_pred) y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype) # Apply the valid mask. - valid_mask = ops.logical_and( - ops.greater_equal(y_true, min_depth), - ops.less_equal(y_true, max_depth), - ) + if max_depth is None: + valid_mask = ops.greater_equal(y_true, min_depth) + else: + valid_mask = ops.logical_and( + ops.greater_equal(y_true, min_depth), + ops.less_equal(y_true, max_depth), + ) y_true = ops.multiply(y_true, valid_mask) y_pred = ops.multiply(y_pred, valid_mask) diff --git a/keras_hub/src/models/depth_estimator.py b/keras_hub/src/models/depth_estimator.py index 2237b09c2b..69e89cb2f3 100644 --- a/keras_hub/src/models/depth_estimator.py +++ b/keras_hub/src/models/depth_estimator.py @@ -46,12 +46,13 @@ class DepthEstimator(Task): to use. `"relative"` depth maps are up-to-scale, while `"metric"` depth maps have metric meaning (e.g. in meters). Defaults to `"relative"`. - min_depth: An optional float. The minimum depth value. This value can - be used to filter out invalid depth values during training. - max_depth: An optional float. The maximum depth value. This value can - be used to filter out invalid depth values during training. Also, - when `depth_estimation_type="metric"`, the model's output will be - scaled to the range `[0, max_depth]`. + min_depth: An float representing the minimum depth value. This value can + be used to filter out invalid depth values during training. Defaults + to `keras.config.epsilon()`. + max_depth: An optional float representing the maximum depth value. 
This + value can be used to filter out invalid depth values during + training. When `depth_estimation_type="metric"`, the model's output + will be scaled to the range `[0, max_depth]`. Examples: @@ -121,7 +122,7 @@ def __init__( self, backbone, depth_estimation_type, - min_depth=None, + min_depth=keras.config.epsilon(), max_depth=None, preprocessor=None, **kwargs, diff --git a/tools/checkpoint_conversion/convert_depth_anything_checkpoints.py b/tools/checkpoint_conversion/convert_depth_anything_checkpoints.py index 8b6776da29..845e06a7e3 100644 --- a/tools/checkpoint_conversion/convert_depth_anything_checkpoints.py +++ b/tools/checkpoint_conversion/convert_depth_anything_checkpoints.py @@ -62,7 +62,6 @@ def convert_model(hf_model, dtype=None): ) image_encoder = DINOV2Backbone(**dinov2_config) model_config = hf_model.config.to_dict() - image_shape = dinov2_config["image_shape"] # In KerasHub, the stage names are capitalized. feature_keys = model_config["backbone_config"]["out_features"] feature_keys = [key.replace("stage", "Stage") for key in feature_keys] @@ -70,14 +69,11 @@ def convert_model(hf_model, dtype=None): assert model_config["max_depth"] in (None, 1.0) return DepthAnythingBackbone( image_encoder, - image_encoder.patch_size, - image_encoder.hidden_dim, reassemble_factors=model_config["reassemble_factors"], neck_hidden_dims=model_config["neck_hidden_sizes"], fusion_hidden_dim=model_config["fusion_hidden_size"], head_hidden_dim=model_config["head_hidden_size"], head_in_index=model_config["head_in_index"], - image_shape=image_shape, feature_keys=feature_keys, dtype=dtype, ) From cb13be9989bab576de640e978671b8b3eed74e33 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 1 Sep 2025 14:06:32 +0800 Subject: [PATCH 4/9] Fix DINOV2 test. --- .../src/models/dinov2/dinov2_backbone.py | 14 +++++++++++++ .../src/models/dinov2/dinov2_backbone_test.py | 20 +++++++++++++------ keras_hub/src/tests/test_case.py | 5 +++-- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/keras_hub/src/models/dinov2/dinov2_backbone.py b/keras_hub/src/models/dinov2/dinov2_backbone.py index de59741d82..52143030b5 100644 --- a/keras_hub/src/models/dinov2/dinov2_backbone.py +++ b/keras_hub/src/models/dinov2/dinov2_backbone.py @@ -19,6 +19,10 @@ class DINOV2Backbone(FeaturePyramidBackbone): DINOV2 model with any number of layers, heads, and embedding dimensions. To load preset architectures and weights, use the `from_preset` constructor. + Note that this backbone is a Feature Pyramid Backbone that can output + intermediate feature maps from different stages of the model. See the + example below for how to access these feature pyramid outputs. + Note that this backbone supports interpolation of the position embeddings to the input image shape. This is useful when the input image shape is different from the shape used to train the position embeddings. The @@ -97,6 +101,16 @@ class DINOV2Backbone(FeaturePyramidBackbone): position_embedding_shape=(518, 518), ) model(input_data) + + # Accessing feature pyramid outputs. 
+ backbone = keras_hub.models.DINOV2Backbone.from_preset( + "dinov2_base", image_shape=(224, 224, 3) + ) + model = keras.Model( + inputs=backbone.inputs, + outputs=backbone.pyramid_outputs, + ) + features = model(input_data) ``` """ diff --git a/keras_hub/src/models/dinov2/dinov2_backbone_test.py b/keras_hub/src/models/dinov2/dinov2_backbone_test.py index 05fe7c3241..198612d993 100644 --- a/keras_hub/src/models/dinov2/dinov2_backbone_test.py +++ b/keras_hub/src/models/dinov2/dinov2_backbone_test.py @@ -19,10 +19,11 @@ def setUp(self): "layer_scale_init_value": 1.0, "num_register_tokens": 0, "use_swiglu_ffn": False, - "image_shape": (64, 64, 3), + "image_shape": (70, 70, 3), + "name": "dinov2_backbone", } self.input_data = { - "images": ops.ones((2, 64, 64, 3)), + "images": ops.ones((2, 70, 70, 3)), } def test_backbone_basics(self): @@ -30,12 +31,15 @@ def test_backbone_basics(self): image_size = self.init_kwargs["image_shape"][0] hidden_dim = self.init_kwargs["hidden_dim"] sequence_length = (image_size // patch_size) ** 2 + 1 - self.run_backbone_test( + self.run_vision_backbone_test( cls=DINOV2Backbone, init_kwargs=self.init_kwargs, input_data=self.input_data, expected_output_shape=(2, sequence_length, hidden_dim), + expected_pyramid_output_keys=["Stem", "Stage1", "Stage2"], + expected_pyramid_image_sizes=[(sequence_length, hidden_dim)] * 3, run_quantization_check=False, + run_data_format_check=False, ) @pytest.mark.large @@ -108,10 +112,11 @@ def setUp(self): "layer_scale_init_value": 1.0, "num_register_tokens": 4, "use_swiglu_ffn": True, - "image_shape": (64, 64, 3), + "image_shape": (70, 70, 3), + "name": "dinov2_backbone", } self.input_data = { - "images": ops.ones((2, 64, 64, 3)), + "images": ops.ones((2, 70, 70, 3)), } def test_backbone_basics(self): @@ -122,12 +127,15 @@ def test_backbone_basics(self): sequence_length = ( (image_size // patch_size) ** 2 + 1 + num_register_tokens ) - self.run_backbone_test( + self.run_vision_backbone_test( cls=DINOV2Backbone, init_kwargs=self.init_kwargs, input_data=self.input_data, expected_output_shape=(2, sequence_length, hidden_dim), + expected_pyramid_output_keys=["Stem", "Stage1", "Stage2"], + expected_pyramid_image_sizes=[(sequence_length, hidden_dim)] * 3, run_quantization_check=False, + run_data_format_check=False, ) @pytest.mark.large diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py index f70ab78840..633f32cd5b 100644 --- a/keras_hub/src/tests/test_case.py +++ b/keras_hub/src/tests/test_case.py @@ -538,10 +538,11 @@ def run_vision_backbone_test( self.assertIsInstance(output_data, dict) self.assertEqual( - list(output_data.keys()), list(backbone.pyramid_outputs.keys()) + sorted(output_data.keys()), + sorted(backbone.pyramid_outputs.keys()), ) self.assertEqual( - list(output_data.keys()), expected_pyramid_output_keys + sorted(output_data.keys()), sorted(expected_pyramid_output_keys) ) # check height and width of each level. for i, (k, v) in enumerate(output_data.items()): From 979d56740c7b3f4e640ac7297edcf33d4fed6342 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Mon, 1 Sep 2025 15:12:30 +0800 Subject: [PATCH 5/9] Fix test. 
--- .../src/models/depth_anything/depth_anything_backbone_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py b/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py index 21b024762d..6770b952ca 100644 --- a/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py +++ b/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py @@ -20,6 +20,7 @@ def setUp(self): 0, image_shape=(70, 70, 3), apply_layernorm=True, + name="image_encoder", ) self.init_kwargs = { "image_encoder": image_encoder, From d3f84b049e96c4ac776c1e4a898b4b4d5e4b2ae6 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 2 Sep 2025 22:13:42 +0800 Subject: [PATCH 6/9] Use numpy as the inputs. --- .../depth_anything_backbone_test.py | 4 ++-- .../depth_anything_depth_estimator_test.py | 8 ++++---- keras_hub/src/models/dinov2/dinov2_layers.py | 17 +++++++++++------ 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py b/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py index 6770b952ca..f6cebb9d96 100644 --- a/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py +++ b/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py @@ -1,5 +1,5 @@ +import numpy as np import pytest -from keras import ops from keras_hub.src.models.depth_anything.depth_anything_backbone import ( DepthAnythingBackbone, @@ -31,7 +31,7 @@ def setUp(self): "head_in_index": -1, "feature_keys": ["Stage1", "Stage2", "Stage3", "Stage4"], } - self.input_data = ops.ones((2, 70, 70, 3)) + self.input_data = np.ones((2, 70, 70, 3), dtype="float32") def test_backbone_basics(self): self.run_backbone_test( diff --git a/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py b/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py index 48d50554e9..552e8140f0 100644 --- a/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py +++ b/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py @@ -1,5 +1,5 @@ +import numpy as np import pytest -from keras import ops from keras_hub.src.models.depth_anything.depth_anything_backbone import ( DepthAnythingBackbone, @@ -30,8 +30,8 @@ def setUp(self): image_shape=(70, 70, 3), apply_layernorm=True, ) - self.images = ops.ones((2, 70, 70, 3)) - self.depths = ops.zeros((2, 70, 70, 1)) + self.images = np.ones((2, 70, 70, 3), dtype="float32") + self.depths = np.zeros((2, 70, 70, 1), dtype="float32") self.image_converter = DepthAnythingImageConverter(image_size=(70, 70)) self.preprocessor = DepthAnythingDepthEstimatorPreprocessor( self.image_converter @@ -82,7 +82,7 @@ def test_saved_model(self): @pytest.mark.extra_large def test_all_presets(self): - images = ops.ones((2, 518, 518, 3)) + images = np.ones((2, 518, 518, 3), dtype="float32") for preset in DepthAnythingDepthEstimator.presets: self.run_preset_test( cls=DepthAnythingDepthEstimator, diff --git a/keras_hub/src/models/dinov2/dinov2_layers.py b/keras_hub/src/models/dinov2/dinov2_layers.py index d3eb97ce01..1cedbe3c7c 100644 --- a/keras_hub/src/models/dinov2/dinov2_layers.py +++ b/keras_hub/src/models/dinov2/dinov2_layers.py @@ -1,3 +1,4 @@ +import keras from keras import backend from keras import config from keras import initializers @@ -276,20 +277,24 @@ def get_config(self): ) return config - def compute_output_shape(self, 
input_shape): - output_shape = [input_shape[0], None, self.hidden_dim] + def compute_output_shape(self, inputs_shape): + output_shape = [inputs_shape[0], None, self.hidden_dim] if self.data_format == "channels_last": - if input_shape[1] is not None and input_shape[2] is not None: - patch_num = input_shape[1] // self.patch_size + if inputs_shape[1] is not None and inputs_shape[2] is not None: + patch_num = inputs_shape[1] // self.patch_size # 1 is for cls token. output_shape[1] = 1 + self.num_register_tokens + patch_num**2 else: - if input_shape[2] is not None and input_shape[3] is not None: - patch_num = input_shape[2] // self.patch_size + if inputs_shape[2] is not None and inputs_shape[3] is not None: + patch_num = inputs_shape[2] // self.patch_size # 1 is for cls token. output_shape[1] = 1 + self.num_register_tokens + patch_num**2 return output_shape + def compute_output_spec(self, inputs): + output_shape = self.compute_output_shape(inputs.shape) + return keras.KerasTensor(output_shape, dtype=self.compute_dtype) + @staticmethod def _interpolate_position_embeddings( position_embeddings, From 7adaa9471ad13ec4e5940716e284a7a84a761c85 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Fri, 5 Sep 2025 11:01:32 +0800 Subject: [PATCH 7/9] Rename the key of the pyramid outputs in DINOV2. --- .../src/models/depth_anything/depth_anything_backbone_test.py | 2 +- .../depth_anything/depth_anything_depth_estimator_test.py | 2 +- keras_hub/src/models/dinov2/dinov2_backbone.py | 2 +- keras_hub/src/models/dinov2/dinov2_backbone_test.py | 4 ++-- keras_hub/src/models/dinov2/dinov2_layers.py | 4 ++-- .../convert_depth_anything_checkpoints.py | 1 - 6 files changed, 7 insertions(+), 8 deletions(-) diff --git a/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py b/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py index f6cebb9d96..0d0f5e3ce1 100644 --- a/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py +++ b/keras_hub/src/models/depth_anything/depth_anything_backbone_test.py @@ -29,7 +29,7 @@ def setUp(self): "fusion_hidden_dim": 128, "head_hidden_dim": 16, "head_in_index": -1, - "feature_keys": ["Stage1", "Stage2", "Stage3", "Stage4"], + "feature_keys": ["stage1", "stage2", "stage3", "stage4"], } self.input_data = np.ones((2, 70, 70, 3), dtype="float32") diff --git a/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py b/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py index 552e8140f0..1c69756f3f 100644 --- a/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py +++ b/keras_hub/src/models/depth_anything/depth_anything_depth_estimator_test.py @@ -43,7 +43,7 @@ def setUp(self): fusion_hidden_dim=128, head_hidden_dim=16, head_in_index=-1, - feature_keys=["Stage1", "Stage2", "Stage3", "Stage4"], + feature_keys=["stage1", "stage2", "stage3", "stage4"], ) self.init_kwargs = { "backbone": self.backbone, diff --git a/keras_hub/src/models/dinov2/dinov2_backbone.py b/keras_hub/src/models/dinov2/dinov2_backbone.py index 52143030b5..bff2aee16d 100644 --- a/keras_hub/src/models/dinov2/dinov2_backbone.py +++ b/keras_hub/src/models/dinov2/dinov2_backbone.py @@ -196,7 +196,7 @@ def __init__( pyramid_outputs = {} image_input = layers.Input(shape=image_shape, name="images") x = self.embeddings(image_input) - pyramid_outputs["Stem"] = x + pyramid_outputs["stem"] = x x, encoder_pyramid_outputs = self.encoder(x) pyramid_outputs.update(encoder_pyramid_outputs) x 
= self.layernorm(x) diff --git a/keras_hub/src/models/dinov2/dinov2_backbone_test.py b/keras_hub/src/models/dinov2/dinov2_backbone_test.py index 198612d993..3b1fd2c252 100644 --- a/keras_hub/src/models/dinov2/dinov2_backbone_test.py +++ b/keras_hub/src/models/dinov2/dinov2_backbone_test.py @@ -36,7 +36,7 @@ def test_backbone_basics(self): init_kwargs=self.init_kwargs, input_data=self.input_data, expected_output_shape=(2, sequence_length, hidden_dim), - expected_pyramid_output_keys=["Stem", "Stage1", "Stage2"], + expected_pyramid_output_keys=["stem", "stage1", "stage2"], expected_pyramid_image_sizes=[(sequence_length, hidden_dim)] * 3, run_quantization_check=False, run_data_format_check=False, @@ -132,7 +132,7 @@ def test_backbone_basics(self): init_kwargs=self.init_kwargs, input_data=self.input_data, expected_output_shape=(2, sequence_length, hidden_dim), - expected_pyramid_output_keys=["Stem", "Stage1", "Stage2"], + expected_pyramid_output_keys=["stem", "stage1", "stage2"], expected_pyramid_image_sizes=[(sequence_length, hidden_dim)] * 3, run_quantization_check=False, run_data_format_check=False, diff --git a/keras_hub/src/models/dinov2/dinov2_layers.py b/keras_hub/src/models/dinov2/dinov2_layers.py index 1cedbe3c7c..564ae4145d 100644 --- a/keras_hub/src/models/dinov2/dinov2_layers.py +++ b/keras_hub/src/models/dinov2/dinov2_layers.py @@ -870,7 +870,7 @@ def call(self, inputs, training=None): x = inputs for layer_index, layer in enumerate(self.layers, start=1): x = layer(x, training=training) - pyramid_outputs[f"Stage{str(layer_index)}"] = x + pyramid_outputs[f"stage{str(layer_index)}"] = x return x, pyramid_outputs def get_config(self): @@ -892,5 +892,5 @@ def get_config(self): def compute_output_shape(self, input_shape): pyramid_outputs = {} for layer_index in range(1, len(self.layers) + 1): - pyramid_outputs[f"Stage{str(layer_index)}"] = input_shape + pyramid_outputs[f"stage{str(layer_index)}"] = input_shape return input_shape, pyramid_outputs diff --git a/tools/checkpoint_conversion/convert_depth_anything_checkpoints.py b/tools/checkpoint_conversion/convert_depth_anything_checkpoints.py index 845e06a7e3..c85baffde6 100644 --- a/tools/checkpoint_conversion/convert_depth_anything_checkpoints.py +++ b/tools/checkpoint_conversion/convert_depth_anything_checkpoints.py @@ -64,7 +64,6 @@ def convert_model(hf_model, dtype=None): model_config = hf_model.config.to_dict() # In KerasHub, the stage names are capitalized. feature_keys = model_config["backbone_config"]["out_features"] - feature_keys = [key.replace("stage", "Stage") for key in feature_keys] assert model_config["depth_estimation_type"] == "relative" assert model_config["max_depth"] in (None, 1.0) return DepthAnythingBackbone( From 37f82c61d69498ac8c39ee6fafc251c1f7c6407d Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Fri, 12 Sep 2025 08:17:56 +0800 Subject: [PATCH 8/9] Resolve comments. 
--- .../models/depth_anything/depth_anything_backbone.py | 4 ++-- keras_hub/src/models/dinov2/dinov2_layers.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/keras_hub/src/models/depth_anything/depth_anything_backbone.py b/keras_hub/src/models/depth_anything/depth_anything_backbone.py index 9206075ab1..77e8a61e36 100644 --- a/keras_hub/src/models/depth_anything/depth_anything_backbone.py +++ b/keras_hub/src/models/depth_anything/depth_anything_backbone.py @@ -58,7 +58,7 @@ class DepthAnythingBackbone(Backbone): input_data = { "images": np.ones(shape=(1, 518, 518, 3), dtype="float32"), } - model = keras_hub.models.DINOV2Backbone.from_preset( + model = keras_hub.models.DepthAnythingBackbone.from_preset( "depth_anything_v2_small" ) model(input_data) @@ -67,7 +67,7 @@ class DepthAnythingBackbone(Backbone): input_data = { "images": np.ones(shape=(1, 224, 224, 3), dtype="float32"), } - model = keras_hub.models.DINOV2Backbone.from_preset( + model = keras_hub.models.DepthAnythingBackbone.from_preset( "depth_anything_v2_small", image_shape=(224, 224, 3) ) model(input_data) diff --git a/keras_hub/src/models/dinov2/dinov2_layers.py b/keras_hub/src/models/dinov2/dinov2_layers.py index 564ae4145d..1124b57a50 100644 --- a/keras_hub/src/models/dinov2/dinov2_layers.py +++ b/keras_hub/src/models/dinov2/dinov2_layers.py @@ -277,16 +277,16 @@ def get_config(self): ) return config - def compute_output_shape(self, inputs_shape): - output_shape = [inputs_shape[0], None, self.hidden_dim] + def compute_output_shape(self, input_shape): + output_shape = [input_shape[0], None, self.hidden_dim] if self.data_format == "channels_last": - if inputs_shape[1] is not None and inputs_shape[2] is not None: - patch_num = inputs_shape[1] // self.patch_size + if input_shape[1] is not None and input_shape[2] is not None: + patch_num = input_shape[1] // self.patch_size # 1 is for cls token. output_shape[1] = 1 + self.num_register_tokens + patch_num**2 else: - if inputs_shape[2] is not None and inputs_shape[3] is not None: - patch_num = inputs_shape[2] // self.patch_size + if input_shape[2] is not None and input_shape[3] is not None: + patch_num = input_shape[2] // self.patch_size # 1 is for cls token. output_shape[1] = 1 + self.num_register_tokens + patch_num**2 return output_shape From 4b63a587e21a00ce69840414b630acb23afa6b85 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Tue, 16 Sep 2025 07:58:22 +0800 Subject: [PATCH 9/9] Reenable the quantization check. --- keras_hub/src/models/dinov2/dinov2_backbone_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/keras_hub/src/models/dinov2/dinov2_backbone_test.py b/keras_hub/src/models/dinov2/dinov2_backbone_test.py index 3b1fd2c252..ca4edcafc0 100644 --- a/keras_hub/src/models/dinov2/dinov2_backbone_test.py +++ b/keras_hub/src/models/dinov2/dinov2_backbone_test.py @@ -38,7 +38,6 @@ def test_backbone_basics(self): expected_output_shape=(2, sequence_length, hidden_dim), expected_pyramid_output_keys=["stem", "stage1", "stage2"], expected_pyramid_image_sizes=[(sequence_length, hidden_dim)] * 3, - run_quantization_check=False, run_data_format_check=False, ) @@ -134,7 +133,6 @@ def test_backbone_basics(self): expected_output_shape=(2, sequence_length, hidden_dim), expected_pyramid_output_keys=["stem", "stage1", "stage2"], expected_pyramid_image_sizes=[(sequence_length, hidden_dim)] * 3, - run_quantization_check=False, run_data_format_check=False, )