Skip to content

Commit 780919d

Browse files
Add DINOV3 with assistance from the Gemini CLI. (#2444)
* Add DINOV3 with the help from Gemini CLI. * Add tests and docstrings. * Update DINOV3 impls. * Resolves Gemini comments. * Skip the HF conversion test.
1 parent 9fb764a commit 780919d

File tree

13 files changed

+1724
-1
lines changed

13 files changed

+1724
-1
lines changed

keras_hub/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@
9393
from keras_hub.src.models.dinov2.dinov2_image_converter import (
9494
DINOV2ImageConverter as DINOV2ImageConverter,
9595
)
96+
from keras_hub.src.models.dinov3.dinov3_image_converter import (
97+
DINOV3ImageConverter as DINOV3ImageConverter,
98+
)
9699
from keras_hub.src.models.efficientnet.efficientnet_image_converter import (
97100
EfficientNetImageConverter as EfficientNetImageConverter,
98101
)

keras_hub/api/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@
184184
from keras_hub.src.models.dinov2.dinov2_backbone import (
185185
DINOV2Backbone as DINOV2Backbone,
186186
)
187+
from keras_hub.src.models.dinov3.dinov3_backbone import (
188+
DINOV3Backbone as DINOV3Backbone,
189+
)
187190
from keras_hub.src.models.distil_bert.distil_bert_backbone import (
188191
DistilBertBackbone as DistilBertBackbone,
189192
)

keras_hub/src/models/dinov2/dinov2_layers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,9 @@ def call(self, inputs, training=None):
502502

503503
def get_config(self):
504504
config = super().get_config()
505-
config.update({"hidden_dim": self.hidden_dim})
505+
config.update(
506+
{"hidden_dim": self.hidden_dim, "init_values": self.init_values}
507+
)
506508
return config
507509

508510
def compute_output_shape(self, input_shape):
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from keras_hub.src.models.dinov3.dinov3_backbone import DINOV3Backbone
2+
from keras_hub.src.models.dinov3.dinov3_presets import backbone_presets
3+
from keras_hub.src.utils.preset_utils import register_presets
4+
5+
register_presets(backbone_presets, DINOV3Backbone)
Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
from keras import layers
2+
3+
from keras_hub.src.api_export import keras_hub_export
4+
from keras_hub.src.models.dinov3.dinov3_layers import DINOV3Embedding
5+
from keras_hub.src.models.dinov3.dinov3_layers import DINOV3Encoder
6+
from keras_hub.src.models.dinov3.dinov3_layers import (
7+
DINOV3RopePositionEmbedding,
8+
)
9+
from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
10+
from keras_hub.src.utils.keras_utils import standardize_data_format
11+
12+
13+
@keras_hub_export("keras_hub.models.DINOV3Backbone")
14+
class DINOV3Backbone(FeaturePyramidBackbone):
15+
"""DINOV3 core network with hyperparameters.
16+
17+
Args:
18+
patch_size: int. The size of each square patch in the input image.
19+
num_layers: int. The number of transformer layers.
20+
hidden_dim: int. The size of the transformer hidden state at the end
21+
of each transformer layer.
22+
num_heads: int. The number of attention heads for each transformer.
23+
intermediate_dim: int. The output dimension of the first Dense layer in
24+
a two-layer feedforward network for each transformer.
25+
layer_scale_init_value: float. The initial value for the layer scale in
26+
the transformer layers. Defaults to `1.0`.
27+
num_register_tokens: int. The number of register tokens to use in the
28+
embedding layer. Defaults to `0`.
29+
use_mask_token: bool. Whether to use a mask token in the embedding
30+
layer. Defaults to `True`.
31+
hidden_activation: str or callable. Activation to use in the MLP.
32+
Defaults to `"gelu"`.
33+
use_gated_mlp: bool. Whether to use Gated MLP layers. Defaults to
34+
`False`.
35+
use_query_bias: bool. Whether to use a bias for the query projection.
36+
Defaults to `True`.
37+
use_key_bias: bool. Whether to use a bias for the key projection.
38+
Defaults to `True`.
39+
use_value_bias: bool. Whether to use a bias for the value projection.
40+
Defaults to `True`.
41+
use_proj_bias: bool. Whether to use a bias for the output projection.
42+
Defaults to `True`.
43+
use_mlp_bias: bool. Whether to use a bias for the dense layers in MLP.
44+
Defaults to `True`.
45+
attention_dropout: float. The dropout rate for the attention
46+
probabilities. Defaults to `0.0`.
47+
drop_path_rate: float. The drop path rate to use. Defaults to `0.0`.
48+
image_shape: tuple. The input shape without the batch size. Defaults to
49+
`(518, 518, 3)`.
50+
rope_theta: float. The base period of the rotary position embeddings.
51+
Defaults to `100.0`.
52+
apply_layernorm: bool. Whether to apply layer normalization to the
53+
outputs of each stage in the feature pyramid. Defaults to `False`.
54+
data_format: `None` or str. If specified, either `"channels_last"` or
55+
`"channels_first"`. The ordering of the dimensions in the
56+
inputs. `"channels_last"` corresponds to inputs with shape
57+
`(batch_size, height, width, channels)`
58+
while `"channels_first"` corresponds to inputs with shape
59+
`(batch_size, channels, height, width)`. It defaults to the
60+
`image_data_format` value found in your Keras config file at
61+
`~/.keras/keras.json`. If you never set it, then it will be
62+
`"channels_last"`.
63+
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
64+
for the models computations and weights. Note that some
65+
computations, such as softmax and layer normalization will always
66+
be done a float32 precision regardless of dtype.
67+
68+
Example:
69+
```python
70+
# Pretrained DINOV3 model.
71+
input_data = {
72+
"images": np.ones(shape=(1, 518, 518, 3), dtype="float32"),
73+
}
74+
model = keras_hub.models.DINOV3Backbone.from_preset(
75+
"dinov3_vit_small_lvd1689m"
76+
)
77+
model(input_data)
78+
79+
# Pretrained DINOV3 model with custom image shape.
80+
input_data = {
81+
"images": np.ones(shape=(1, 224, 224, 3), dtype="float32"),
82+
}
83+
model = keras_hub.models.DINOV3Backbone.from_preset(
84+
"dinov3_vit_small_lvd1689m", image_shape=(224, 224, 3)
85+
)
86+
model(input_data)
87+
88+
# Randomly initialized DINOV3 model with custom config.
89+
model = keras_hub.models.DINOV3Backbone(
90+
patch_size=14,
91+
num_layers=2,
92+
hidden_dim=32,
93+
num_heads=2,
94+
intermediate_dim=128,
95+
image_shape=(224, 224, 3),
96+
)
97+
model(input_data)
98+
99+
# Accessing feature pyramid outputs.
100+
backbone = keras_hub.models.DINOV3Backbone.from_preset(
101+
"dinov3_vit_small_lvd1689m", image_shape=(224, 224, 3)
102+
)
103+
model = keras.Model(
104+
inputs=backbone.inputs,
105+
outputs=backbone.pyramid_outputs,
106+
)
107+
features = model(input_data)
108+
```
109+
"""
110+
111+
def __init__(
112+
self,
113+
patch_size,
114+
num_layers,
115+
hidden_dim,
116+
num_heads,
117+
intermediate_dim,
118+
layer_scale_init_value=1.0,
119+
num_register_tokens=4,
120+
use_mask_token=True,
121+
hidden_activation="gelu",
122+
use_gated_mlp=False,
123+
use_query_bias=True,
124+
use_key_bias=True,
125+
use_value_bias=True,
126+
use_proj_bias=True,
127+
use_mlp_bias=True,
128+
attention_dropout=0.0,
129+
drop_path_rate=0.0,
130+
layer_norm_eps=1e-5,
131+
image_shape=(518, 518, 3),
132+
rope_theta=100.0,
133+
apply_layernorm=False,
134+
data_format=None,
135+
dtype=None,
136+
name=None,
137+
**kwargs,
138+
):
139+
data_format = standardize_data_format(data_format)
140+
141+
prefix = str(name) + "_" if name is not None else ""
142+
143+
# === Layers ===
144+
self.embeddings = DINOV3Embedding(
145+
hidden_dim=hidden_dim,
146+
patch_size=patch_size,
147+
num_register_tokens=num_register_tokens,
148+
use_mask_token=use_mask_token,
149+
data_format=data_format,
150+
dtype=dtype,
151+
name=f"{prefix}embeddings",
152+
)
153+
self.rope_embedding = DINOV3RopePositionEmbedding(
154+
hidden_dim=hidden_dim,
155+
num_heads=num_heads,
156+
rope_theta=rope_theta,
157+
patch_size=patch_size,
158+
dtype=dtype,
159+
name=f"{prefix}rope_embedding",
160+
)
161+
self.encoder = DINOV3Encoder(
162+
num_layers=num_layers,
163+
hidden_dim=hidden_dim,
164+
num_heads=num_heads,
165+
intermediate_dim=intermediate_dim,
166+
layer_scale_init_value=layer_scale_init_value,
167+
hidden_activation=hidden_activation,
168+
use_gated_mlp=use_gated_mlp,
169+
use_query_bias=use_query_bias,
170+
use_key_bias=use_key_bias,
171+
use_value_bias=use_value_bias,
172+
use_proj_bias=use_proj_bias,
173+
use_mlp_bias=use_mlp_bias,
174+
attention_dropout=attention_dropout,
175+
drop_path_rate=drop_path_rate,
176+
layer_norm_eps=layer_norm_eps,
177+
dtype=dtype,
178+
name=f"{prefix}encoder",
179+
)
180+
self.layernorm = layers.LayerNormalization(
181+
epsilon=layer_norm_eps, dtype=dtype, name=f"{prefix}layernorm"
182+
)
183+
184+
# === Functional Model ===
185+
pyramid_outputs = {}
186+
image_input = layers.Input(shape=image_shape, name="pixel_values")
187+
x = self.embeddings(image_input)
188+
pyramid_outputs["stem"] = x
189+
190+
position_embeddings = self.rope_embedding(image_input)
191+
num_prefix_tokens = 1 + num_register_tokens
192+
193+
x, encoder_pyramid_outputs = self.encoder(
194+
x,
195+
position_embeddings=position_embeddings,
196+
num_prefix_tokens=num_prefix_tokens,
197+
)
198+
pyramid_outputs.update(encoder_pyramid_outputs)
199+
x = self.layernorm(x)
200+
if apply_layernorm:
201+
for key in pyramid_outputs:
202+
pyramid_outputs[key] = self.layernorm(pyramid_outputs[key])
203+
outputs = x
204+
super().__init__(
205+
inputs={"pixel_values": image_input},
206+
outputs=outputs,
207+
dtype=dtype,
208+
name=name,
209+
**kwargs,
210+
)
211+
212+
# === Config ===
213+
self.patch_size = int(patch_size)
214+
self.num_layers = int(num_layers)
215+
self.hidden_dim = int(hidden_dim)
216+
self.num_heads = int(num_heads)
217+
self.intermediate_dim = int(intermediate_dim)
218+
self.layer_scale_init_value = float(layer_scale_init_value)
219+
self.num_register_tokens = int(num_register_tokens)
220+
self.use_mask_token = bool(use_mask_token)
221+
self.hidden_activation = hidden_activation
222+
self.use_gated_mlp = bool(use_gated_mlp)
223+
self.use_query_bias = bool(use_query_bias)
224+
self.use_key_bias = bool(use_key_bias)
225+
self.use_value_bias = bool(use_value_bias)
226+
self.use_proj_bias = bool(use_proj_bias)
227+
self.use_mlp_bias = bool(use_mlp_bias)
228+
self.attention_dropout = float(attention_dropout)
229+
self.drop_path_rate = float(drop_path_rate)
230+
self.layer_norm_eps = float(layer_norm_eps)
231+
self.image_shape = image_shape
232+
self.rope_theta = rope_theta
233+
self.apply_layernorm = apply_layernorm
234+
self.pyramid_outputs = pyramid_outputs
235+
236+
def get_config(self):
237+
config = super().get_config()
238+
config.update(
239+
{
240+
"patch_size": self.patch_size,
241+
"num_layers": self.num_layers,
242+
"hidden_dim": self.hidden_dim,
243+
"num_heads": self.num_heads,
244+
"intermediate_dim": self.intermediate_dim,
245+
"num_register_tokens": self.num_register_tokens,
246+
"use_mask_token": self.use_mask_token,
247+
"layer_scale_init_value": self.layer_scale_init_value,
248+
"hidden_activation": self.hidden_activation,
249+
"use_gated_mlp": self.use_gated_mlp,
250+
"use_query_bias": self.use_query_bias,
251+
"use_key_bias": self.use_key_bias,
252+
"use_value_bias": self.use_value_bias,
253+
"use_proj_bias": self.use_proj_bias,
254+
"use_mlp_bias": self.use_mlp_bias,
255+
"attention_dropout": self.attention_dropout,
256+
"drop_path_rate": self.drop_path_rate,
257+
"layer_norm_eps": self.layer_norm_eps,
258+
"image_shape": self.image_shape,
259+
"rope_theta": self.rope_theta,
260+
"apply_layernorm": self.apply_layernorm,
261+
}
262+
)
263+
return config

0 commit comments

Comments
 (0)