from keras import layers

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.dinov3.dinov3_layers import DINOV3Embedding
from keras_hub.src.models.dinov3.dinov3_layers import DINOV3Encoder
from keras_hub.src.models.dinov3.dinov3_layers import (
    DINOV3RopePositionEmbedding,
)
from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
from keras_hub.src.utils.keras_utils import standardize_data_format


@keras_hub_export("keras_hub.models.DINOV3Backbone")
class DINOV3Backbone(FeaturePyramidBackbone):
    """DINOV3 core network with hyperparameters.

    Args:
        patch_size: int. The size of each square patch in the input image.
        num_layers: int. The number of transformer layers.
        hidden_dim: int. The size of the transformer hidden state at the end
            of each transformer layer.
        num_heads: int. The number of attention heads for each transformer.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        layer_scale_init_value: float. The initial value for the layer scale in
            the transformer layers. Defaults to `1.0`.
        num_register_tokens: int. The number of register tokens to use in the
            embedding layer. Defaults to `4`.
        use_mask_token: bool. Whether to use a mask token in the embedding
            layer. Defaults to `True`.
        hidden_activation: str or callable. Activation to use in the MLP.
            Defaults to `"gelu"`.
        use_gated_mlp: bool. Whether to use Gated MLP layers. Defaults to
            `False`.
        use_query_bias: bool. Whether to use a bias for the query projection.
            Defaults to `True`.
        use_key_bias: bool. Whether to use a bias for the key projection.
            Defaults to `True`.
        use_value_bias: bool. Whether to use a bias for the value projection.
            Defaults to `True`.
        use_proj_bias: bool. Whether to use a bias for the output projection.
            Defaults to `True`.
        use_mlp_bias: bool. Whether to use a bias for the dense layers in the
            MLP. Defaults to `True`.
        attention_dropout: float. The dropout rate for the attention
            probabilities. Defaults to `0.0`.
        drop_path_rate: float. The drop path rate to use. Defaults to `0.0`.
        layer_norm_eps: float. The epsilon used by the layer normalization
            layers. Defaults to `1e-5`.
        image_shape: tuple. The input shape without the batch size. Defaults to
            `(518, 518, 3)`.
        rope_theta: float. The base period of the rotary position embeddings.
            Defaults to `100.0`.
        apply_layernorm: bool. Whether to apply layer normalization to the
            outputs of each stage in the feature pyramid. Defaults to `False`.
        data_format: `None` or str. If specified, either `"channels_last"` or
            `"channels_first"`. The ordering of the dimensions in the
            inputs. `"channels_last"` corresponds to inputs with shape
            `(batch_size, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch_size, channels, height, width)`. It defaults to the
            `image_data_format` value found in your Keras config file at
            `~/.keras/keras.json`. If you never set it, then it will be
            `"channels_last"`.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for the model's computations and weights. Note that some
            computations, such as softmax and layer normalization, will always
            be done in float32 precision regardless of dtype.

    Example:
    ```python
    # Pretrained DINOV3 model.
    input_data = {
        "pixel_values": np.ones(shape=(1, 518, 518, 3), dtype="float32"),
    }
    model = keras_hub.models.DINOV3Backbone.from_preset(
        "dinov3_vit_small_lvd1689m"
    )
    model(input_data)

    # Pretrained DINOV3 model with custom image shape.
    input_data = {
        "pixel_values": np.ones(shape=(1, 224, 224, 3), dtype="float32"),
    }
    model = keras_hub.models.DINOV3Backbone.from_preset(
        "dinov3_vit_small_lvd1689m", image_shape=(224, 224, 3)
    )
    model(input_data)

    # Randomly initialized DINOV3 model with custom config.
    model = keras_hub.models.DINOV3Backbone(
        patch_size=14,
        num_layers=2,
        hidden_dim=32,
        num_heads=2,
        intermediate_dim=128,
        image_shape=(224, 224, 3),
    )
    model(input_data)

    # Accessing feature pyramid outputs.
    backbone = keras_hub.models.DINOV3Backbone.from_preset(
        "dinov3_vit_small_lvd1689m", image_shape=(224, 224, 3)
    )
    model = keras.Model(
        inputs=backbone.inputs,
        outputs=backbone.pyramid_outputs,
    )
    features = model(input_data)
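    # A minimal sketch of splitting the sequence output, assuming the token
    # layout is the class token, then the register tokens, then the patch
    # tokens.
    outputs = backbone(input_data)
    num_prefix_tokens = 1 + backbone.num_register_tokens
    class_token = outputs[:, 0, :]
    register_tokens = outputs[:, 1:num_prefix_tokens, :]
    patch_tokens = outputs[:, num_prefix_tokens:, :]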
    ```
    """

    def __init__(
        self,
        patch_size,
        num_layers,
        hidden_dim,
        num_heads,
        intermediate_dim,
        layer_scale_init_value=1.0,
        num_register_tokens=4,
        use_mask_token=True,
        hidden_activation="gelu",
        use_gated_mlp=False,
        use_query_bias=True,
        use_key_bias=True,
        use_value_bias=True,
        use_proj_bias=True,
        use_mlp_bias=True,
        attention_dropout=0.0,
        drop_path_rate=0.0,
        layer_norm_eps=1e-5,
        image_shape=(518, 518, 3),
        rope_theta=100.0,
        apply_layernorm=False,
        data_format=None,
        dtype=None,
        name=None,
        **kwargs,
    ):
        data_format = standardize_data_format(data_format)

        prefix = str(name) + "_" if name is not None else ""

        # === Layers ===
        self.embeddings = DINOV3Embedding(
            hidden_dim=hidden_dim,
            patch_size=patch_size,
            num_register_tokens=num_register_tokens,
            use_mask_token=use_mask_token,
            data_format=data_format,
            dtype=dtype,
            name=f"{prefix}embeddings",
        )
        self.rope_embedding = DINOV3RopePositionEmbedding(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            rope_theta=rope_theta,
            patch_size=patch_size,
            dtype=dtype,
            name=f"{prefix}rope_embedding",
        )
        self.encoder = DINOV3Encoder(
            num_layers=num_layers,
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            intermediate_dim=intermediate_dim,
            layer_scale_init_value=layer_scale_init_value,
            hidden_activation=hidden_activation,
            use_gated_mlp=use_gated_mlp,
            use_query_bias=use_query_bias,
            use_key_bias=use_key_bias,
            use_value_bias=use_value_bias,
            use_proj_bias=use_proj_bias,
            use_mlp_bias=use_mlp_bias,
            attention_dropout=attention_dropout,
            drop_path_rate=drop_path_rate,
            layer_norm_eps=layer_norm_eps,
            dtype=dtype,
            name=f"{prefix}encoder",
        )
        self.layernorm = layers.LayerNormalization(
            epsilon=layer_norm_eps, dtype=dtype, name=f"{prefix}layernorm"
        )

        # === Functional Model ===
        pyramid_outputs = {}
        image_input = layers.Input(shape=image_shape, name="pixel_values")
        x = self.embeddings(image_input)
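        # The embedding output (the "stem" feature below) holds the patch
        # tokens with the class token and any register tokens prepended,
        # before any transformer block is applied.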
        pyramid_outputs["stem"] = x

        position_embeddings = self.rope_embedding(image_input)
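        # The rotary position embeddings cover only the patch grid;
        # `num_prefix_tokens` tells the encoder how many leading tokens (the
        # class token plus the register tokens) precede the patch tokens.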
        num_prefix_tokens = 1 + num_register_tokens

        x, encoder_pyramid_outputs = self.encoder(
            x,
            position_embeddings=position_embeddings,
            num_prefix_tokens=num_prefix_tokens,
        )
        pyramid_outputs.update(encoder_pyramid_outputs)
        x = self.layernorm(x)
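        # Optionally normalize the intermediate features as well, which can be
        # useful when the pyramid outputs feed a downstream head directly.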
        if apply_layernorm:
            for key in pyramid_outputs:
                pyramid_outputs[key] = self.layernorm(pyramid_outputs[key])
        outputs = x
        super().__init__(
            inputs={"pixel_values": image_input},
            outputs=outputs,
            dtype=dtype,
            name=name,
            **kwargs,
        )

        # === Config ===
        self.patch_size = int(patch_size)
        self.num_layers = int(num_layers)
        self.hidden_dim = int(hidden_dim)
        self.num_heads = int(num_heads)
        self.intermediate_dim = int(intermediate_dim)
        self.layer_scale_init_value = float(layer_scale_init_value)
        self.num_register_tokens = int(num_register_tokens)
        self.use_mask_token = bool(use_mask_token)
        self.hidden_activation = hidden_activation
        self.use_gated_mlp = bool(use_gated_mlp)
        self.use_query_bias = bool(use_query_bias)
        self.use_key_bias = bool(use_key_bias)
        self.use_value_bias = bool(use_value_bias)
        self.use_proj_bias = bool(use_proj_bias)
        self.use_mlp_bias = bool(use_mlp_bias)
        self.attention_dropout = float(attention_dropout)
        self.drop_path_rate = float(drop_path_rate)
        self.layer_norm_eps = float(layer_norm_eps)
        self.image_shape = image_shape
        self.rope_theta = rope_theta
        self.apply_layernorm = apply_layernorm
        self.pyramid_outputs = pyramid_outputs

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "patch_size": self.patch_size,
                "num_layers": self.num_layers,
                "hidden_dim": self.hidden_dim,
                "num_heads": self.num_heads,
                "intermediate_dim": self.intermediate_dim,
                "num_register_tokens": self.num_register_tokens,
                "use_mask_token": self.use_mask_token,
                "layer_scale_init_value": self.layer_scale_init_value,
                "hidden_activation": self.hidden_activation,
                "use_gated_mlp": self.use_gated_mlp,
                "use_query_bias": self.use_query_bias,
                "use_key_bias": self.use_key_bias,
                "use_value_bias": self.use_value_bias,
                "use_proj_bias": self.use_proj_bias,
                "use_mlp_bias": self.use_mlp_bias,
                "attention_dropout": self.attention_dropout,
                "drop_path_rate": self.drop_path_rate,
                "layer_norm_eps": self.layer_norm_eps,
                "image_shape": self.image_shape,
                "rope_theta": self.rope_theta,
                "apply_layernorm": self.apply_layernorm,
            }
        )
        return config