From a1794ccc18e3ec6d06811ed4438f7fbcfd8d5a7c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 11 Sep 2025 11:17:36 -0700 Subject: [PATCH 01/12] Add DINOv3 ConvNeXt weights. --- timm/models/_hub.py | 9 +++++++-- timm/models/convnext.py | 27 +++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/timm/models/_hub.py b/timm/models/_hub.py index 408d2b8faf..0dbf377e52 100644 --- a/timm/models/_hub.py +++ b/timm/models/_hub.py @@ -457,13 +457,18 @@ def push_to_hf_hub( ) -def generate_readme(model_card: dict, model_name: str): - tags = model_card.get('tags', None) or ['image-classification', 'timm', 'transformers'] +def generate_readme( + model_card: dict, + model_name: str, + task_name: str = 'image-classification', +): + tags = model_card.get('tags', None) or [task_name, 'timm', 'transformers'] readme_text = "---\n" if tags: readme_text += "tags:\n" for t in tags: readme_text += f"- {t}\n" + readme_text += f"pipeline_tag: {task_name}\n" readme_text += f"library_name: {model_card.get('library_name', 'timm')}\n" readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n" if 'license_name' in model_card: diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 509c49a2c6..cdc34eba2f 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -1061,6 +1061,33 @@ def _cfgv2(url='', **kwargs): mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024), + # NOTE dinov3 convnext weights are under a specific license, and downstream outputs must be shared with this + # https://ai.meta.com/resources/models-and-libraries/dinov3-license/ + 'convnext_tiny.dinov3_lvd1689m': _cfg( + hf_hub_id='timm/', + crop_pct=1.0, + num_classes=0, + license='dinov3', + ), + 'convnext_small.dinov3_lvd1689m': _cfg( + hf_hub_id='timm/', + crop_pct=1.0, + num_classes=0, + license='dinov3', + ), + 'convnext_base.dinov3_lvd1689m': _cfg( + hf_hub_id='timm/', + crop_pct=1.0, + num_classes=0, + license='dinov3', + ), + 'convnext_large.dinov3_lvd1689m': _cfg( + hf_hub_id='timm/', + crop_pct=1.0, + num_classes=0, + license='dinov3', + ), + "test_convnext.r160_in1k": _cfg( hf_hub_id='timm/', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), From fe843da68e9ab1d119f47f9864839026c1bb4908 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 15 Sep 2025 11:29:53 -0700 Subject: [PATCH 02/12] DINOv3 rotary position embedding impl --- timm/layers/__init__.py | 2 + timm/layers/pos_embed_sincos.py | 315 +++++++++++++++++++++++++++++++- 2 files changed, 309 insertions(+), 8 deletions(-) diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 3b8e6d4c85..0c52123c76 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -133,7 +133,9 @@ RotaryEmbedding, RotaryEmbeddingCat, RotaryEmbeddingMixed, + RotaryEmbeddingDinoV3, get_mixed_freqs, + create_rope_embed, ) from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite from .selective_kernel import SelectiveKernel diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index 9c6dc6e6ae..39cb404bf7 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -214,24 +214,73 @@ def forward(self, x): def rot(x): + # x: [ x0 x1 x2 x3 x4 x5] + # out: [-x1 x0 -x3 x2 -x5 x4] return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape) -def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb): - if sin_emb.ndim == 3: - return x * 
cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x) - return x * cos_emb + rot(x) * sin_emb +def rope_rotate_half(x: torch.Tensor) -> torch.Tensor: + # x: [ x0 x1 x2 x3 x4 x5] + # out: [-x3 -x4 -x5 x0 x1 x2] + x1, x2 = x.chunk(2, dim=-1) + return torch.cat([-x2, x1], dim=-1) -def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb): +def apply_rot_embed( + x: torch.Tensor, + emb: torch.Tensor, + half: bool = False, +) -> torch.Tensor: + # x: [..., D], eg [x0, x1, x2, x3, x4, x5] + if half: + # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2] + # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2 + # rope_rotate_half(x): eg [-x3, -x4, -x5, x0, x1, x2] + return x * cos_emb + rope_rotate_half(x) * sin_emb + else: + # sin: [..., D], eg [sin0, sin0, sin1, sin1, sin2, sin2] + # cos: [..., D], eg [cos0, cos0, cos1, cos1, cos2, cos2] + # rot(x): eg [-x1, x0, -x3, x2, -x5, x4] + return x * cos_emb + rot(x) * sin_emb + + +def apply_rot_embed_list( + x: List[torch.Tensor], + emb: torch.Tensor, + half: bool = False +) -> List[torch.Tensor]: if isinstance(x, torch.Tensor): x = [x] - return [t * cos_emb + rot(t) * sin_emb for t in x] + # x: [..., D], eg [x0, x1, x2, x3, x4, x5] + if half: + # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2] + # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2 + # rope_rotate_half(x): eg [-x3, -x4, -x5, x0, x1, x2] + return [t * cos_emb + rope_rotate_half(t) * sin_emb for t in x] + else: + # sin: [..., D], eg [sin0, sin0, sin1, sin1, sin2, sin2] + # cos: [..., D], eg [cos0, cos0, cos1, cos1, cos2, cos2] + # rot(x): eg [-x1, x0, -x3, x2, -x5, x4] + return [t * cos_emb + rot(t) * sin_emb for t in x] -def apply_rot_embed_cat(x: torch.Tensor, emb): +def apply_rot_embed_cat( + x: torch.Tensor, + emb: torch.Tensor, + half: bool = False +) -> torch.Tensor: sin_emb, cos_emb = emb.tensor_split(2, -1) - return x * cos_emb + rot(x) * sin_emb + # x: [..., D], eg [x0, x1, x2, x3, x4, x5] + if half: + # sin: [..., D], eg [sin0, sin1, sin2, sin0, sin1, sin2] + # cos: [..., D], eg [cos0, cos1, cos2, cos0, cos1, cos2 + # rope_rotate_half(x), eg [-x3, -x4, -x5, x0, x1, x2] + return x * cos_emb + rope_rotate_half(x) * sin_emb + else: + # sin: [..., D], eg [sin0, sin0, sin1, sin1, sin2, sin2] + # cos: [..., D], eg [cos0, cos0, cos1, cos1, cos2, cos2] + # rot(x), eg [-x1, x0, -x3, x2, -x5, x4] + return x * cos_emb + rot(x) * sin_emb def apply_keep_indices_nlc( @@ -834,3 +883,253 @@ def forward(self, x): def no_weight_decay(self): """Exclude frequency parameters from weight decay.""" return {'freqs'} + + +class RotaryEmbeddingDinoV3(nn.Module): + """RoPE for timm DinoV3 port, numerically matching original. 
+ + Math is aligned to original DinoV3 RopePositionEmbedding at https://github.com/facebookresearch/dinov3: + - 0.5-centered coords normalized by H/W (or min/max), mapped to [-1,1] + - training-time augmentations (shift/jitter/rescale) + - periods schedule equals Rope's temperature (base) or min/max period + """ + + def __init__( + self, + dim: int, + temperature: Optional[float] = 100.0, + min_period: Optional[float] = 0.5, + max_period: Optional[float] = 90., + feat_shape: Optional[List[int]] = None, + ref_feat_shape: Optional[List[int]] = None, + normalize_coords: str = "separate", # 'min', 'max', 'separate' + grid_offset: float = 0.0, + grid_indexing: str = "ij", + rotate_half: bool = True, + shift_coords: Optional[float] = None, + jitter_coords: Optional[float] = None, # interpreted as factor J >= 1 + rescale_coords: Optional[float] = None, # interpreted as factor R >= 1 + ): + super().__init__() + + # Dimensions / output format + self.dim = dim # equal to head_dim for most vit applications + self.rotate_half = rotate_half + + # Period schedule parameters + self.temperature = float(temperature) + self.min_period = min_period + self.max_period = max_period + + # Coord processing + augs + self.normalize_coords = normalize_coords + self.shift_coords = shift_coords + self.jitter_coords = jitter_coords + self.rescale_coords = rescale_coords + self.aug_active = any([a is not None for a in [self.shift_coords, self.jitter_coords, self.rescale_coords]]) + + # Grid config + self.feat_shape = feat_shape + self.ref_feat_shape = ref_feat_shape + self.grid_offset = grid_offset + self.grid_indexing = grid_indexing + + # Precompute periods + periods = self._compute_periods() + self.register_buffer("periods", periods, persistent=False) + + if feat_shape is not None: + self._cache_embed(feat_shape) + else: + self.register_buffer("pos_embed_cached", None, persistent=False) + self.feat_shape = None + + def _compute_periods(self, device='cpu', dtype=torch.float32) -> torch.Tensor: + """Construct periods from either min/max or temperature.""" + dim = self.dim // 4 + + if self.min_period is not None and self.max_period is not None: + exponents = torch.linspace(0, 1, dim, dtype=torch.float32) + periods = self.min_period * ((self.max_period / self.min_period) ** exponents) + else: + if self.temperature is None: + raise ValueError("Provide either min/max periods or `temperature`.") + exponents = 2.0 * torch.arange(dim, device=device, dtype=dtype) / (self.dim // 2) + periods = self.temperature ** exponents + + # NOTE: original has periods downcast to bfloat16 in persistent buffers, so loaded models + # BTW orignal and timm might differ a bit here + + return periods + + def _make_coords( + self, + height: int, + width: int, + device: torch.device = 'cpu', + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """Make coordinate grid matching offset and normalization of original. + Returns: coords with shape (HW, 2) in [-1, 1]. 
+ """ + # 0.5-centered indices with optional offset + coords_h = torch.arange(0.5, height, device=device, dtype=dtype) + self.grid_offset + coords_w = torch.arange(0.5, width, device=device, dtype=dtype) + self.grid_offset + + # Normalization denominators + if self.normalize_coords == "max": + denom = float(max(height, width)) + h_denom = denom + w_denom = denom + elif self.normalize_coords == "min": + denom = float(min(height, width)) + h_denom = denom + w_denom = denom + elif self.normalize_coords == "separate": + h_denom = float(height) + w_denom = float(width) + else: + raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}") + + # Normalize to [0, 1] + coords_h = coords_h / h_denom + coords_w = coords_w / w_denom + + # Create grid then map to [-1, 1] + if self.grid_indexing == "xy": + grid_w, grid_h = torch.meshgrid(coords_w, coords_h, indexing="xy") + coords = torch.stack([grid_h, grid_w], dim=-1) # (H, W, 2) -> (h, w order) + else: + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) # (H, W, 2) + coords = coords.flatten(0, 1) # (HW, 2) + coords = 2.0 * coords - 1.0 # (H, W, 2) in [-1, 1] + return coords + + def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor: + """Apply shift/jitter/rescale train time augmentations.""" + if not self.training or not self.aug_active: + return coords + + device = coords.device + dtype = coords.dtype + + # Shift per-axis in [-s, +s] + if self.shift_coords is not None: + shift = float(self.shift_coords) + shift_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-shift, +shift) + coords = coords + shift_hw[None, :] + + # Jitter: per-axis log-uniform factor in [1/J, J] + if self.jitter_coords is not None: + jitter_factor = float(self.jitter_coords) + if jitter_factor <= 0: + raise ValueError("jitter_coords must be > 0 (interpreted as multiplicative factor).") + jitter_max = math.log(jitter_factor) + jitter_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-jitter_max, +jitter_max).exp() + coords = coords * jitter_hw[None, :] + + # Rescale: shared scalar log-uniform factor in [1/R, R] + if self.rescale_coords is not None: + rescale_factor = float(self.rescale_coords) + if rescale_factor <= 0: + raise ValueError("rescale_coords must be > 0 (interpreted as multiplicative factor).") + rescale_max = math.log(rescale_factor) + rescale = torch.empty(1, device=device, dtype=dtype).uniform_(-rescale_max, +rescale_max).exp() + coords = coords * rescale + + return coords + + def _get_pos_embed_from_coords(self, coords: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Return sin/cos embeddings with either 'half' or 'interleaved' layout.""" + # coords: (HW, 2); periods: (dim) + dim = self.dim // 4 + device = self.periods.device + dtype = self.periods.dtype + assert self.periods.numel() == dim + + # NOTE this is a slightly later device/dtype switch than original + coords = coords[:, :, None].to(device=device, dtype=dtype) + angles = 2 * math.pi * coords / self.periods[None, None, :] + angles = angles.flatten(1) # (HW, dim // 2) + + if self.rotate_half: + # Tile (half layout) (HW, dim // 2) -> (HW, dim) + angles = angles.tile(2) + else: + # Interleaved layout (HW, dim // 2) -> (HW, dim) + angles = angles.repeat_interleave(2, dim=-1) + + sin = torch.sin(angles) + cos = torch.cos(angles) + return sin, cos + + def _create_embed(self, feat_shape: List[int], no_aug: bool = False) -> torch.Tensor: + H, W = feat_shape + coords = self._make_coords(H, W) # (HW, 2) + if not no_aug: + coords = 
self._apply_coord_augs(coords) + sin, cos = self._get_pos_embed_from_coords(coords) # 2 * (HW, dim) + rope_embed = torch.cat([sin, cos], dim=-1) # (HW, 2*dim) + return rope_embed + + def _cache_embed(self, feat_shape: List[int]): + rope_embed = self._create_embed(feat_shape, no_aug=True) # create non-augmented embeds for cache + self.register_buffer("pos_embed_cached", rope_embed, persistent=False) + self.feat_shape = feat_shape + + def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor: + """Generate rope_embed matching DINOv3 RopePositionEmbedding numerics. + + Returns: (HW, num_heads, 2 * head_dim) with last dim = [sin, cos] cat. + """ + if shape is not None: + rope_embed = self._create_embed(shape) + else: + need_create = self.pos_embed_cached is None or (self.training and self.aug_active) + if need_create: + assert self.feat_shape is not None, 'feature shape must be cached on create' + rope_embed = self._create_embed(self.feat_shape) + else: + rope_embed = self.pos_embed_cached + + return rope_embed + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Get and apply rotary embeddings to x""" + # assuming channel-first tensor where spatial dim are >= 2 + pos_embed = self.get_embed(x.shape[2:]) + return apply_rot_embed_cat(x, pos_embed, half=self.rotate_half) + + +def create_rope_embed( + rope_type: str = 'cat', + dim: int = 768, + num_heads: int = 12, + **kwargs +) -> nn.Module: + """Factory function for creating rotary position embeddings. + + Args: + rope_type: Type of RoPE to create. Options: + - 'base': Basic RotaryEmbedding + - 'cat': RotaryEmbeddingCat (concatenated sin/cos) + - 'mixed': RotaryEmbeddingMixed (learnable per-depth frequencies) + - 'dinov3': RotaryEmbeddingDinoV3 (with coordinate transforms) + dim: Total embedding dimension + num_heads: Number of attention heads + **kwargs: Additional arguments passed to the specific RoPE class + + Returns: + Rotary embedding module + """ + if rope_type == 'base': + return RotaryEmbedding(dim=dim // num_heads, **kwargs) + elif rope_type == 'cat': + return RotaryEmbeddingCat(dim=dim // num_heads, **kwargs) + elif rope_type == 'mixed': + # Mixed requires depth parameter, generates differing embeddings per layer and head + return RotaryEmbeddingMixed(dim=dim, num_heads=num_heads, **kwargs) + elif rope_type == 'dinov3': + return RotaryEmbeddingDinoV3(dim=dim // num_heads, **kwargs) + else: + raise ValueError(f"Unknown RoPE type: {rope_type}") From a9018fffbc1092ec166eea98eb0cd4df55e4fee5 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 15 Sep 2025 12:32:10 -0700 Subject: [PATCH 03/12] Remove hard-coded min/max period values that were used for testing, clarify comment --- timm/layers/pos_embed_sincos.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index 39cb404bf7..b2cee4fe9b 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -898,8 +898,8 @@ def __init__( self, dim: int, temperature: Optional[float] = 100.0, - min_period: Optional[float] = 0.5, - max_period: Optional[float] = 90., + min_period: Optional[float] = None, + max_period: Optional[float] = None, feat_shape: Optional[List[int]] = None, ref_feat_shape: Optional[List[int]] = None, normalize_coords: str = "separate", # 'min', 'max', 'separate' @@ -957,8 +957,8 @@ def _compute_periods(self, device='cpu', dtype=torch.float32) -> torch.Tensor: exponents = 2.0 * torch.arange(dim, device=device, dtype=dtype) / (self.dim // 2) 
periods = self.temperature ** exponents - # NOTE: original has periods downcast to bfloat16 in persistent buffers, so loaded models - # BTW orignal and timm might differ a bit here + # NOTE: The original dinv3 model weights have periods downcast to bfloat16 in persistent buffers, + # loaded models will differ a bit vs timm as periods is not persistent and generated in float32 by default return periods From e0592d466917c781e876274750548e0eba0444b0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 15 Sep 2025 16:38:05 -0700 Subject: [PATCH 04/12] Add DINOv3 model defs, add EVA additions to support DINOv3. Currently testing... --- timm/layers/attention.py | 8 +- timm/layers/mlp.py | 4 + timm/models/eva.py | 500 ++++++++++++++++++++++++++++++++++----- 3 files changed, 455 insertions(+), 57 deletions(-) diff --git a/timm/layers/attention.py b/timm/layers/attention.py index 3cd084d3c8..3fbbec342a 100644 --- a/timm/layers/attention.py +++ b/timm/layers/attention.py @@ -121,6 +121,7 @@ def __init__( qk_norm: bool = False, scale_norm: bool = False, proj_bias: bool = True, + rotate_half: bool = False, ): """Initialize the Attention module. @@ -136,6 +137,7 @@ def __init__( norm_layer: Normalization layer constructor to use for QK and scale normalization qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer scale_norm: Enable normalization (scaling) of attention output with norm_layer + rotate_half: Use 'half' ROPE layout instead of default 'interleaved' """ super().__init__() if scale_norm or qk_norm: @@ -148,6 +150,7 @@ def __init__( self.scale = head_dim ** -0.5 self.num_prefix_tokens = num_prefix_tokens self.fused_attn = use_fused_attn() + self.rotate_half = rotate_half if qkv_fused: self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias) @@ -196,8 +199,9 @@ def forward( if rope is not None: npt = self.num_prefix_tokens - q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v) - k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v) + half = getattr(self, 'rotate_half', False) + q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope, half=half)], dim=2).type_as(v) + k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope, half=half)], dim=2).type_as(v) if self.fused_attn: x = F.scaled_dot_product_attention( diff --git a/timm/layers/mlp.py b/timm/layers/mlp.py index 188c6b530b..a8b1cc0d2e 100644 --- a/timm/layers/mlp.py +++ b/timm/layers/mlp.py @@ -119,6 +119,7 @@ def __init__( norm_layer=None, bias=True, drop=0., + align_to=0, ): super().__init__() out_features = out_features or in_features @@ -126,6 +127,9 @@ def __init__( bias = to_2tuple(bias) drop_probs = to_2tuple(drop) + if align_to: + hidden_features = hidden_features + (-hidden_features % align_to) + self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0]) self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0]) self.act = act_layer() diff --git a/timm/models/eva.py b/timm/models/eva.py index bcfa3ee2cb..2842238c4b 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -2,6 +2,13 @@ EVA ViT from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636 +This file contains a number of ViT variants the utilise ROPE position embeddings, SwiGLU and other additions: + * EVA & EVA02 model implementations that evolved from BEiT, additional models in vision_transformer.py. 
+ * `timm` original SBB ViT w/ ROPE position embeddings + * Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181) + * ROPE-ViT from Naver AI (https://arxiv.org/abs/2403.13298) + * DINOv3 from META AI Research (https://arxiv.org/abs/2508.10104) + @article{EVA, title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale}, author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang, @@ -35,11 +42,21 @@ organization={Springer} } -This file contains a number of ViT variants the utilise ROPE position embeddings, SwiGLU and other additions: - * EVA & EVA02 model implementations that evolved from BEiT, additional models in vision_transformer.py. - * `timm` original SBB ViT w/ ROPE position embeddings - * Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181) - * ROPE-ViT from Naver AI (https://arxiv.org/abs/2403.13298) +@article{simeoni2025dinov3, + title={{DINOv3}}, + author={Sim{\'e}oni, Oriane and Vo, Huy V. and Seitzer, Maximilian and Baldassarre, Federico and Oquab, Maxime + and Jose, Cijo and Khalidov, Vasil and Szafraniec, Marc and Yi, Seungeun and Ramamonjisoa, Micha{\"e}l + and Massa, Francisco and Haziza, Daniel and Wehrstedt, Luca and Wang, Jianyuan and Darcet, Timoth{\'e}e + and Moutakanni, Th{\'e}o and Sentana, Leonel and Roberts, Claire and Vedaldi, Andrea and Tolan, Jamie + and Brandt, John and Couprie, Camille and Mairal, Julien and J{\'e}gou, Herv{\'e} and Labatut, Patrick + and Bojanowski, Piotr}, + year={2025}, + eprint={2508.10104}, + url={https://arxiv.org/abs/2508.10104}, +} + +DINOv3 code was a modification of existing EVA model and support modules, so licensed under Apache-2.0 like timm. +Weights from META remain under DINOv3 License (https://ai.meta.com/resources/models-and-libraries/dinov3-license/). 
Modifications by / Copyright 2023 Ross Wightman, original copyrights below """ @@ -63,8 +80,7 @@ LayerNorm, DropPath, PatchDropoutWithIndices, - RotaryEmbeddingCat, - RotaryEmbeddingMixed, + create_rope_embed, apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, @@ -73,6 +89,7 @@ global_pool_nlc, to_2tuple, use_fused_attn, + maybe_add_mask, AttentionRope, AttentionPoolLatent, ) @@ -103,6 +120,7 @@ def __init__( norm_layer: Optional[Callable] = None, qk_norm: bool = False, scale_norm: bool = True, + rotate_half: bool = False, ): """ Args: @@ -119,6 +137,7 @@ def __init__( norm_layer: Normalization layer constructor to use for QK and scale normalization qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer scale_norm: Enable normalization (scaling) of attention output with norm_layer + rotate_half: Use half rotation layout instead of interleaved """ super().__init__() if scale_norm or qk_norm: @@ -132,6 +151,7 @@ def __init__( self.num_prefix_tokens = num_prefix_tokens self.fused_attn = use_fused_attn() self.qkv_bias_separate = qkv_bias_separate + self.rotate_half = rotate_half if qkv_fused: self.qkv = nn.Linear(dim, attn_dim * 3, bias=False) @@ -194,8 +214,9 @@ def forward( if rope is not None: npt = self.num_prefix_tokens - q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope)], dim=2).type_as(v) - k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope)], dim=2).type_as(v) + half = getattr(self, 'rotate_half', False) + q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope, half=half)], dim=2).type_as(v) + k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope, half=half)], dim=2).type_as(v) if self.fused_attn: x = F.scaled_dot_product_attention( @@ -206,10 +227,7 @@ def forward( else: q = q * self.scale attn = (q @ k.transpose(-2, -1)) - - if attn_mask is not None: - attn_mask = attn_mask.to(torch.bool) - attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf")) + attn = maybe_add_mask(attn, attn_mask) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) @@ -232,10 +250,12 @@ def __init__( qkv_fused: bool = True, mlp_ratio: float = 4., swiglu_mlp: bool = False, + swiglu_align_to: int = 0, scale_mlp: bool = False, scale_attn_inner: bool = False, num_prefix_tokens: int = 1, attn_type: str = 'eva', + rotate_half: bool = False, proj_drop: float = 0., attn_drop: float = 0., drop_path: float = 0., @@ -280,6 +300,7 @@ def __init__( attn_head_dim=attn_head_dim, norm_layer=norm_layer, scale_norm=scale_attn_inner, + rotate_half=rotate_half, ) self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() @@ -287,16 +308,17 @@ def __init__( self.norm2 = norm_layer(dim) hidden_features = int(dim * mlp_ratio) if swiglu_mlp: - if scale_mlp: - # when norm in SwiGLU used, an impl with separate fc for gate & x is used + if scale_mlp or swiglu_align_to: + # when norm in SwiGLU used or alignment enabled, an impl with separate fc for gate & x is used self.mlp = SwiGLU( in_features=dim, hidden_features=hidden_features, norm_layer=norm_layer if scale_mlp else None, drop=proj_drop, + align_to=swiglu_align_to, ) else: - # w/o any extra norm, an impl with packed weights is used, matches existing GluMLP + # w/o any extra norm, an impl with packed weights is used self.mlp = GluMlp( in_features=dim, hidden_features=hidden_features * 2, @@ -341,7 +363,9 @@ def __init__( qkv_fused: bool = True, mlp_ratio: float = 4., attn_type: str = 'eva', + rotate_half: bool = False, swiglu_mlp: bool = False, + swiglu_aligh_to: int = 0, scale_mlp: bool = False, scale_attn_inner: bool = False, num_prefix_tokens: int = 1, @@ -387,6 +411,7 @@ def __init__( attn_head_dim=attn_head_dim, norm_layer=norm_layer, scale_norm=scale_attn_inner, + rotate_half=rotate_half, ) self.norm1 = norm_layer(dim) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -455,6 +480,7 @@ def __init__( qkv_fused: bool = True, mlp_ratio: float = 4., swiglu_mlp: bool = False, + swiglu_align_to: int = 0, scale_mlp: bool = False, scale_attn_inner: bool = False, attn_type: str = 'eva', @@ -471,10 +497,11 @@ def __init__( no_embed_class: bool = False, use_abs_pos_emb: bool = True, use_rot_pos_emb: bool = False, - rope_mixed_mode: bool = False, + rope_type: Optional[str] = 'cat', rope_grid_offset: float = 0., rope_grid_indexing: str = 'ij', rope_temperature: float = 10000., + rope_rotate_half: bool = False, use_post_norm: bool = False, use_pre_transformer_norm: bool = False, use_post_transformer_norm: Optional[bool] = None, @@ -517,10 +544,11 @@ def __init__( no_embed_class: Don't include position embeddings for class (or reg) tokens use_abs_pos_emb: Use absolute (learned) positional embeddings use_rot_pos_emb: Use rotary position embeddings - rope_mixed_mode: Use mixed mode ROPE with per-layer learnable frequencies + rope_type: Type of RoPE to use ('cat', 'mixed', 'dinov3', etc.). 
rope_grid_offset: Offset for rotary position embedding grid rope_grid_indexing: Indexing mode for rotary position embeddings ('ij' or 'xy') rope_temperature: Temperature parameter for ROPE frequency computation + rope_rotate_half: Use half rotation layout (rotate D/2 dims), else use interleaved rotation layout use_post_norm: Use post-norm transformer block type use_pre_transformer_norm: Use normalization layer before transformer blocks use_post_transformer_norm: Use normalization layer after transformer blocks @@ -581,32 +609,35 @@ def __init__( else: self.patch_drop = None + self.rope_mixed = False if use_rot_pos_emb: ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None - if rope_mixed_mode: + + # Setup RoPE kwargs + rope_kwargs = dict( + dim=embed_dim, + num_heads=num_heads, + feat_shape=None if dynamic_img_size else self.patch_embed.grid_size, + temperature=rope_temperature, + grid_indexing=rope_grid_indexing, + ) + if rope_type == 'mixed': + rope_kwargs.update(dict(depth=depth)) self.rope_mixed = True - # Mixed mode to supports depth-dependent frequencies - self.rope = RotaryEmbeddingMixed( - dim=embed_dim, - depth=depth, - num_heads=num_heads, - temperature=rope_temperature, - feat_shape=None if dynamic_img_size else self.patch_embed.grid_size, - grid_indexing=rope_grid_indexing, - ) - else: - self.rope_mixed = False - self.rope = RotaryEmbeddingCat( - dim=embed_dim // num_heads, - temperature=rope_temperature, - in_pixels=False, - feat_shape=None if dynamic_img_size else self.patch_embed.grid_size, + elif rope_type == 'dinov3': + rope_kwargs.update(dict( + grid_offset=rope_grid_offset, ref_feat_shape=ref_feat_shape, + )) + else: # 'cat' or 'base' + rope_kwargs.update(dict( + in_pixels=False, grid_offset=rope_grid_offset, - grid_indexing=rope_grid_indexing, - ) + ref_feat_shape=ref_feat_shape, + )) + + self.rope = create_rope_embed(rope_type=rope_type, **rope_kwargs) else: - self.rope_mixed = False self.rope = None self.norm_pre = norm_layer(embed_dim) if activate_pre_norm else nn.Identity() @@ -621,9 +652,11 @@ def __init__( qkv_fused=qkv_fused, mlp_ratio=mlp_ratio, swiglu_mlp=swiglu_mlp, + swiglu_align_to=swiglu_align_to, scale_mlp=scale_mlp, scale_attn_inner=scale_attn_inner, attn_type=attn_type, + rotate_half=rope_rotate_half, num_prefix_tokens=self.num_prefix_tokens, proj_drop=proj_drop_rate, attn_drop=attn_drop_rate, @@ -635,7 +668,7 @@ def __init__( self.feature_info = [ dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)] - self.norm = norm_layer(embed_dim) if activate_post_norm else nn.Identity() + self.norm = norm_layer(embed_dim) if activate_post_norm else nn.Identity() if global_pool == 'map': self.attn_pool = AttentionPoolLatent( @@ -1074,7 +1107,9 @@ def checkpoint_filter_fn( prefix = 'visual.' 
else: prefix = '' - mim_weights = prefix + 'mask_token' in state_dict + + dinov3_weights = 'storage_tokens' in state_dict + mim_weights = not dinov3_weights and prefix + 'mask_token' in state_dict no_qkv = prefix + 'blocks.0.attn.q_proj.weight' in state_dict len_prefix = len(prefix) @@ -1088,6 +1123,40 @@ def checkpoint_filter_fn( # fixed embedding no need to load buffer from checkpoint continue + if dinov3_weights: + if any([k.endswith(f) for f in ['.periods', '.bias_mask', 'mask_token']]): + # discard unused/non-persistent/pretrain only params + continue + if k.startswith('local_cls_norm'): + # discard, only used for 7b dinov3 pretrain w/ local crops + continue + if k.endswith('qkv.bias'): + q_bias_k = k.replace('qkv.bias', 'q_bias') + try: + # the distilled b,l,h models ended up with all zero biases, so timm + # has both qkv_bias=True and qkv_bias=False impl, test which + model.get_parameter(q_bias_k) + except Exception as e: + print(e) + # skip as target model has no bias parameter + continue + # split bias into components and skip the k as its supposed to be fixed at 0 + qv, kv, vv = v.chunk(3, dim=-1) + out_dict[q_bias_k] = qv + out_dict[k.replace('qkv.bias', 'v_bias')] = vv + continue + k = k.replace('ls1.gamma', 'gamma_1') # match EVA ls naming + k = k.replace('ls2.gamma', 'gamma_2') # match EVA ls naming + k = k.replace('storage_tokens', 'reg_token') # rename storage to existing register naming + + elif mim_weights and k in ('mask_token', 'lm_head.weight', 'lm_head.bias', 'norm.weight', 'norm.bias'): + if k == 'norm.weight' or k == 'norm.bias': + # try moving norm -> fc norm on fine-tune, probably a better starting point than new init + k = k.replace('norm', 'fc_norm') + else: + # skip pretrain mask token & head weights + continue + if 'patch_embed.proj.weight' in k: _, _, H, W = model.patch_embed.proj.weight.shape if v.shape[-1] != W or v.shape[-2] != H: @@ -1120,14 +1189,6 @@ def checkpoint_filter_fn( k = k.replace('q_bias', 'q_proj.bias') k = k.replace('v_bias', 'v_proj.bias') - if mim_weights and k in ('mask_token', 'lm_head.weight', 'lm_head.bias', 'norm.weight', 'norm.bias'): - if k == 'norm.weight' or k == 'norm.bias': - # try moving norm -> fc norm on fine-tune, probably a better starting point than new init - k = k.replace('norm', 'fc_norm') - else: - # skip pretrain mask token & head weights - continue - out_dict[k] = v return out_dict @@ -1150,7 +1211,7 @@ def _create_eva(variant: str, pretrained: bool = False, **kwargs) -> Eva: if use_naflex is None: use_naflex = _USE_NAFLEX_DEFAULT if use_naflex: - # Import here to avoid circular imports + # Import here to avoid circular import from .naflexvit import _create_naflexvit_from_eva return _create_naflexvit_from_eva(variant, pretrained, **kwargs) @@ -1567,6 +1628,94 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]: std=IMAGENET_DEFAULT_STD, license='apache-2.0', ), + + # DINOv3 weights are under a specific license with redistribution terms, please see + # https://github.com/facebookresearch/dinov3/blob/main/LICENSE.md + 'vit_small_patch16_dinov3_224.lvdm_1689m': _cfg( + # hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + license='dinov3', + ), + 'vit_small_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg( + # hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + license='dinov3', + ), + 'vit_small_plus_patch16_dinov3_224.lvdm_1689m': _cfg( + # hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, 
+ license='dinov3', + ), + 'vit_small_plus_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg( + # hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + license='dinov3', + ), + 'vit_base_patch16_dinov3_224.lvdm_1689m': _cfg( + #hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + license='dinov3', + ), + 'vit_base_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg( + #hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + license='dinov3', + ), + 'vit_large_patch16_dinov3_224.lvdm_1689m': _cfg( + # hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + license='dinov3', + ), + 'vit_large_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg( + # hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + license='dinov3', + ), + 'vit_large_patch16_dinov3_224.sat_493m': _cfg( + # hf_hub_id='timm/', + mean=(0.430, 0.411, 0.296), + std=(0.213, 0.156, 0.143), + num_classes=0, + license='dinov3', + ), + 'vit_huge_plus_patch16_dinov3_224.lvdm_1689m': _cfg( + # hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + license='dinov3', + ), + 'vit_7b_patch16_dinov3_224.lvdm_1689m': _cfg( + # hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_classes=0, + license='dinov3', + ), + 'vit_7b_patch16_dinov3_224.sat_493m': _cfg( + # hf_hub_id='timm/', + mean=(0.430, 0.411, 0.296), + std=(0.213, 0.156, 0.143), + num_classes=0, + license='dinov3', + ), + }) @@ -2302,7 +2451,7 @@ def vit_small_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva: use_rot_pos_emb=True, rope_grid_indexing='xy', rope_temperature=10.0, - rope_mixed_mode=True, + rope_type='mixed' ) model = _create_eva('vit_small_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2326,7 +2475,7 @@ def vit_base_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva: use_rot_pos_emb=True, rope_grid_indexing='xy', rope_temperature=10.0, - rope_mixed_mode=True, + rope_type='mixed' ) model = _create_eva('vit_base_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2350,7 +2499,7 @@ def vit_large_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva: use_rot_pos_emb=True, rope_grid_indexing='xy', rope_temperature=10.0, - rope_mixed_mode=True, + rope_type='mixed' ) model = _create_eva('vit_large_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2450,7 +2599,7 @@ def vit_small_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> use_rot_pos_emb=True, rope_grid_indexing='xy', rope_temperature=10.0, - rope_mixed_mode=True, + rope_type='mixed' ) model = _create_eva('vit_small_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -2476,7 +2625,7 @@ def vit_base_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> E use_rot_pos_emb=True, rope_grid_indexing='xy', rope_temperature=10.0, - rope_mixed_mode=True, + rope_type='mixed' ) model = _create_eva('vit_base_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2501,8 +2650,249 @@ def vit_large_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> use_rot_pos_emb=True, rope_grid_indexing='xy', rope_temperature=10.0, - rope_mixed_mode=True, + rope_type='mixed' ) model = 
_create_eva('vit_large_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model + +@register_model +def vit_small_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: + model_args = dict( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + qkv_bias=False, + init_values=1.0e-05, # layer-scale + rope_type='dinov3', + rope_temperature=100, + #rope_rescale_coords=2, # haven't added to interface + rope_rotate_half=True, + use_rot_pos_emb=True, + use_abs_pos_emb=False, + num_reg_tokens=4, + use_fc_norm=False, + norm_layer=partial(LayerNorm, eps=1e-5), + ) + model = _create_eva('vit_small_patch16_dinov3_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch16_dinov3_qkvb_224(pretrained: bool = False, **kwargs) -> Eva: + model_args = dict( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + qkv_bias=True, + init_values=1.0e-05, # layer-scale + rope_type='dinov3', + rope_temperature=100, + #rope_rescale_coords=2, # haven't added to interface + rope_rotate_half=True, + use_rot_pos_emb=True, + use_abs_pos_emb=False, + num_reg_tokens=4, + use_fc_norm=False, + norm_layer=partial(LayerNorm, eps=1e-5), + ) + model = _create_eva('vit_small_patch16_dinov3_qkvb_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_plus_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: + model_args = dict( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + qkv_bias=False, + init_values=1.0e-05, # layer-scale + rope_type='dinov3', + rope_temperature=100, + #rope_rescale_coords=2, # haven't added to interface + rope_rotate_half=True, + use_rot_pos_emb=True, + use_abs_pos_emb=False, + swiglu_mlp=True, + swiglu_align_to=8, + num_reg_tokens=4, + use_fc_norm=False, + norm_layer=partial(LayerNorm, eps=1e-5), + ) + model = _create_eva('vit_small_plus_patch16_dinov3_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_plus_patch16_dinov3_qkvb_224(pretrained: bool = False, **kwargs) -> Eva: + model_args = dict( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=6, + qkv_bias=True, + init_values=1.0e-05, # layer-scale + rope_type='dinov3', + rope_temperature=100, + #rope_rescale_coords=2, # haven't added to interface + rope_rotate_half=True, + use_rot_pos_emb=True, + use_abs_pos_emb=False, + swiglu_mlp=True, + swiglu_align_to=8, + num_reg_tokens=4, + use_fc_norm=False, + norm_layer=partial(LayerNorm, eps=1e-5), + ) + model = _create_eva('vit_small_plus_patch16_dinov3_qkvb_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: + model_args = dict( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + qkv_bias=False, + init_values=1.0e-05, # layer-scale + rope_type='dinov3', + rope_temperature=100, + #rope_rescale_coords=2, # haven't added to interface + rope_rotate_half=True, + use_rot_pos_emb=True, + use_abs_pos_emb=False, + num_reg_tokens=4, + use_fc_norm=False, + norm_layer=partial(LayerNorm, eps=1e-5), + ) + model = _create_eva('vit_base_patch16_dinov3_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_dinov3_qkvb_224(pretrained: bool = False, **kwargs) -> Eva: + # DINOv3 Base variant w/ qkv_bias enabled (zero'd in weights) + model_args = dict( + patch_size=16, + 
embed_dim=768, + depth=12, + num_heads=12, + qkv_bias=True, + init_values=1.0e-05, # layer-scale + rope_type='dinov3', + rope_temperature=100, + #rope_rescale_coords=2, # haven't added to interface + rope_rotate_half=True, + use_rot_pos_emb=True, + use_abs_pos_emb=False, + num_reg_tokens=4, + use_fc_norm=False, + norm_layer=partial(LayerNorm, eps=1e-5), + ) + model = _create_eva('vit_base_patch16_dinov3_qkvb_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: + model_args = dict( + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + qkv_bias=False, + init_values=1.0e-5, # layer-scale + rope_type='dinov3', + rope_temperature=100, + use_rot_pos_emb=True, + use_abs_pos_emb=False, + rope_rotate_half=True, + #rope_rescale_coords=2, # haven't added to interface + num_reg_tokens=4, + use_fc_norm=False, + norm_layer=partial(LayerNorm, eps=1e-5), + ) + model = _create_eva('vit_large_patch16_dinov3_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_dinov3_qkvb_224(pretrained: bool = False, **kwargs) -> Eva: + model_args = dict( + patch_size=16, + embed_dim=768, + depth=24, + num_heads=16, + qkv_bias=True, + init_values=1.0e-5, # layer-scale + rope_type='dinov3', + rope_temperature=100, + use_rot_pos_emb=True, + use_abs_pos_emb=False, + rope_rotate_half=True, + #rope_rescale_coords=2, # haven't added to interface + num_reg_tokens=4, + use_fc_norm=False, + norm_layer=partial(LayerNorm, eps=1e-5), + ) + model = _create_eva('vit_large_patch16_dinov3_qkvb_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_plus_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: + model_args = dict( + patch_size=16, + embed_dim=1280, + depth=32, + num_heads=20, + qkv_bias=False, + init_values=1.0e-5, # layer-scale + rope_type='dinov3', + rope_temperature=100, + use_rot_pos_emb=True, + use_abs_pos_emb=False, + rope_rotate_half=True, + swiglu_mlp=True, + swiglu_align_to=8, + #rope_rescale_coords=2, # haven't added to interface + num_reg_tokens=4, + use_fc_norm=False, + norm_layer=partial(LayerNorm, eps=1e-5), + ) + + model = _create_eva('vit_huge_plus_patch16_dinov3_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_7b_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: + model_args = dict( + patch_size=16, + embed_dim=4096, + depth=40, + num_heads=32, + qkv_bias=False, + mlp_ratio=2, + init_values=1.0e-5, # layer-scale + rope_type='dinov3', + rope_temperature=100, + use_rot_pos_emb=True, + use_abs_pos_emb=False, + rope_rotate_half=True, + swiglu_mlp=True, + swiglu_align_to=64, + #rope_rescale_coords=2, # haven't added to interface + num_reg_tokens=4, + use_fc_norm=False, + norm_layer=partial(LayerNorm, eps=1e-5), + ) + + model = _create_eva('vit_7b_patch16_dinov3_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model From f133fb3f5c26de9e81f0b579739d661e7f682fc6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 15 Sep 2025 17:13:30 -0700 Subject: [PATCH 05/12] Exclude 7b from unit tests --- tests/test_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 028440179d..af8c4087cd 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -76,10 +76,10 @@ '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', 
'*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', '*huge*', '*giant*', '*gigantic*', '*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560', '*_1b_*', '*_3b_*'] - NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*', '*_1b_*', '*_3b_*'] + NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*', '*_1b_*', '*_3b_*', '*_7b_*'] else: EXCLUDE_FILTERS = ['*enormous*'] - NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*'] + NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*', '*_7b_*'] EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*'] From 28278f14c4e3b369434050b7af614926224e9ed8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 16 Sep 2025 10:57:05 -0700 Subject: [PATCH 06/12] Further 7b test filtering. Remove ref_feat_shape from DINOv3 RoPE as it's normalized. Fix torchscript issue with + unary op --- tests/test_models.py | 4 +- timm/layers/pos_embed_sincos.py | 10 ++-- timm/models/eva.py | 91 ++++++++++++++------------------- 3 files changed, 43 insertions(+), 62 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index af8c4087cd..e1275d4cbb 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -78,10 +78,10 @@ '*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560', '*_1b_*', '*_3b_*'] NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*', '*_1b_*', '*_3b_*', '*_7b_*'] else: - EXCLUDE_FILTERS = ['*enormous*'] + EXCLUDE_FILTERS = ['*enormous*', '*_7b_*'] NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*', '*_7b_*'] -EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*'] +EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*', '*_7b_*'] TARGET_FWD_SIZE = MAX_FWD_SIZE = 384 TARGET_BWD_SIZE = 128 diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index b2cee4fe9b..3a56141a9d 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -901,7 +901,6 @@ def __init__( min_period: Optional[float] = None, max_period: Optional[float] = None, feat_shape: Optional[List[int]] = None, - ref_feat_shape: Optional[List[int]] = None, normalize_coords: str = "separate", # 'min', 'max', 'separate' grid_offset: float = 0.0, grid_indexing: str = "ij", @@ -930,7 +929,6 @@ def __init__( # Grid config self.feat_shape = feat_shape - self.ref_feat_shape = ref_feat_shape self.grid_offset = grid_offset self.grid_indexing = grid_indexing @@ -944,7 +942,7 @@ def __init__( self.register_buffer("pos_embed_cached", None, persistent=False) self.feat_shape = None - def _compute_periods(self, device='cpu', dtype=torch.float32) -> torch.Tensor: + def _compute_periods(self, device: torch.device = 'cpu', dtype: torch.dtype = torch.float32) -> torch.Tensor: """Construct periods from either min/max or temperature.""" dim = self.dim // 4 @@ -1016,7 +1014,7 @@ def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor: # Shift per-axis in [-s, +s] if self.shift_coords is not None: shift = float(self.shift_coords) - shift_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-shift, +shift) + shift_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-shift, shift) coords = coords + shift_hw[None, :] # Jitter: per-axis log-uniform factor in [1/J, J] @@ -1025,7 +1023,7 @@ def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor: if jitter_factor <= 0: raise ValueError("jitter_coords must be > 0 (interpreted as multiplicative factor).") jitter_max = math.log(jitter_factor) - jitter_hw = 
torch.empty(2, device=device, dtype=dtype).uniform_(-jitter_max, +jitter_max).exp() + jitter_hw = torch.empty(2, device=device, dtype=dtype).uniform_(-jitter_max, jitter_max).exp() coords = coords * jitter_hw[None, :] # Rescale: shared scalar log-uniform factor in [1/R, R] @@ -1034,7 +1032,7 @@ def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor: if rescale_factor <= 0: raise ValueError("rescale_coords must be > 0 (interpreted as multiplicative factor).") rescale_max = math.log(rescale_factor) - rescale = torch.empty(1, device=device, dtype=dtype).uniform_(-rescale_max, +rescale_max).exp() + rescale = torch.empty(1, device=device, dtype=dtype).uniform_(-rescale_max, rescale_max).exp() coords = coords * rescale return coords diff --git a/timm/models/eva.py b/timm/models/eva.py index 2842238c4b..68d5a014fd 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -624,12 +624,7 @@ def __init__( if rope_type == 'mixed': rope_kwargs.update(dict(depth=depth)) self.rope_mixed = True - elif rope_type == 'dinov3': - rope_kwargs.update(dict( - grid_offset=rope_grid_offset, - ref_feat_shape=ref_feat_shape, - )) - else: # 'cat' or 'base' + elif rope_type == 'cat': rope_kwargs.update(dict( in_pixels=False, grid_offset=rope_grid_offset, @@ -1558,74 +1553,62 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]: # RoPE-ViT models from Naver 'vit_small_patch16_rope_224.naver_in1k': _cfg( hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, license='apache-2.0', ), 'vit_base_patch16_rope_224.naver_in1k': _cfg( hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, license='apache-2.0', ), 'vit_large_patch16_rope_224.naver_in1k': _cfg( hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, license='apache-2.0', ), 'vit_small_patch16_rope_mixed_224.naver_in1k': _cfg( hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, license='apache-2.0', ), 'vit_base_patch16_rope_mixed_224.naver_in1k': _cfg( hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, license='apache-2.0', ), 'vit_large_patch16_rope_mixed_224.naver_in1k': _cfg( hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, license='apache-2.0', ), 'vit_small_patch16_rope_ape_224.naver_in1k': _cfg( hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, license='apache-2.0', ), 'vit_base_patch16_rope_ape_224.naver_in1k': _cfg( hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, license='apache-2.0', ), 'vit_large_patch16_rope_ape_224.naver_in1k': _cfg( hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, license='apache-2.0', ), 'vit_small_patch16_rope_mixed_ape_224.naver_in1k': _cfg( hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, license='apache-2.0', ), 'vit_base_patch16_rope_mixed_ape_224.naver_in1k': _cfg( hf_hub_id='timm/', - 
mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, license='apache-2.0', ), 'vit_large_patch16_rope_mixed_ape_224.naver_in1k': _cfg( hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, license='apache-2.0', ), @@ -1633,85 +1616,85 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]: # https://github.com/facebookresearch/dinov3/blob/main/LICENSE.md 'vit_small_patch16_dinov3_224.lvdm_1689m': _cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + crop_pct=1.0, num_classes=0, license='dinov3', ), 'vit_small_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + crop_pct=1.0, num_classes=0, license='dinov3', ), 'vit_small_plus_patch16_dinov3_224.lvdm_1689m': _cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + crop_pct=1.0, num_classes=0, license='dinov3', ), 'vit_small_plus_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + crop_pct=1.0, num_classes=0, license='dinov3', ), 'vit_base_patch16_dinov3_224.lvdm_1689m': _cfg( #hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + crop_pct=1.0, num_classes=0, license='dinov3', ), 'vit_base_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg( #hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + crop_pct=1.0, num_classes=0, license='dinov3', ), 'vit_large_patch16_dinov3_224.lvdm_1689m': _cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + crop_pct=1.0, num_classes=0, license='dinov3', ), 'vit_large_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + crop_pct=1.0, num_classes=0, license='dinov3', ), 'vit_large_patch16_dinov3_224.sat_493m': _cfg( # hf_hub_id='timm/', - mean=(0.430, 0.411, 0.296), - std=(0.213, 0.156, 0.143), + mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143), + crop_pct=1.0, num_classes=0, license='dinov3', ), 'vit_huge_plus_patch16_dinov3_224.lvdm_1689m': _cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + crop_pct=1.0, num_classes=0, license='dinov3', ), 'vit_7b_patch16_dinov3_224.lvdm_1689m': _cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + crop_pct=1.0, num_classes=0, license='dinov3', ), 'vit_7b_patch16_dinov3_224.sat_493m': _cfg( # hf_hub_id='timm/', - mean=(0.430, 0.411, 0.296), - std=(0.213, 0.156, 0.143), + mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143), + crop_pct=1.0, num_classes=0, license='dinov3', ), From 2b1e26637e8ad49600b9113cb8f89742d1c6300d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 16 Sep 2025 11:38:16 -0700 Subject: [PATCH 07/12] Missed 7b filter from GITHUB specific --- 
tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index e1275d4cbb..d73d4512fc 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -75,7 +75,7 @@ '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', '*huge*', '*giant*', '*gigantic*', - '*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560', '*_1b_*', '*_3b_*'] + '*enormous*', 'maxvit_xlarge*', 'regnet*1280', 'regnet*2560', '*_1b_*', '*_3b_*', '*_7b_*'] NON_STD_EXCLUDE_FILTERS = ['*huge*', '*giant*', '*gigantic*', '*enormous*', '*_1b_*', '*_3b_*', '*_7b_*'] else: EXCLUDE_FILTERS = ['*enormous*', '*_7b_*'] From 508df64fd96d009560a0c1df897929adc032c55a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 16 Sep 2025 14:20:34 -0700 Subject: [PATCH 08/12] Remove '224' from dinov3 model names and enable dynamic_img_size=True by default. Add assert in DINOv3 ROPE module to avoid torchscript failure (hopefully no more?) --- timm/layers/pos_embed_sincos.py | 1 + timm/models/eva.py | 139 ++++++++++++++------------------ 2 files changed, 62 insertions(+), 78 deletions(-) diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index 3a56141a9d..60ab9aeb56 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -1088,6 +1088,7 @@ def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor: assert self.feat_shape is not None, 'feature shape must be cached on create' rope_embed = self._create_embed(self.feat_shape) else: + assert self.pos_embed_cached is not None rope_embed = self.pos_embed_cached return rope_embed diff --git a/timm/models/eva.py b/timm/models/eva.py index 68d5a014fd..422a33e485 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -1260,6 +1260,25 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]: } +def _dinov3_cfg(url: str = '', **kwargs) -> Dict[str, Any]: + """Generate default configuration for DINOv3 models. + + Args: + url: Model weights URL. + **kwargs: Additional configuration parameters. + + Returns: + Model configuration dictionary. 
+ """ + return { + 'url': url, + 'num_classes': 0, 'input_size': (3, 256, 256), 'pool_size': None, + 'crop_pct': 1.0, 'interpolation': 'bicubic', 'min_input_size': (3, 128, 128), + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + 'license': 'dinov3', **kwargs + } + default_cfgs = generate_default_cfgs({ # EVA 01 CLIP fine-tuned on imagenet-1k @@ -1614,89 +1633,43 @@ def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]: # DINOv3 weights are under a specific license with redistribution terms, please see # https://github.com/facebookresearch/dinov3/blob/main/LICENSE.md - 'vit_small_patch16_dinov3_224.lvdm_1689m': _cfg( + 'vit_small_patch16_dinov3.lvdm_1689m': _dinov3_cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - crop_pct=1.0, - num_classes=0, - license='dinov3', ), - 'vit_small_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg( + 'vit_small_patch16_dinov3_qkvb.lvdm_1689m': _dinov3_cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - crop_pct=1.0, - num_classes=0, - license='dinov3', ), - 'vit_small_plus_patch16_dinov3_224.lvdm_1689m': _cfg( + 'vit_small_plus_patch16_dinov3.lvdm_1689m': _dinov3_cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - crop_pct=1.0, - num_classes=0, - license='dinov3', ), - 'vit_small_plus_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg( + 'vit_small_plus_patch16_dinov3_qkvb.lvdm_1689m': _dinov3_cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - crop_pct=1.0, - num_classes=0, - license='dinov3', ), - 'vit_base_patch16_dinov3_224.lvdm_1689m': _cfg( + 'vit_base_patch16_dinov3.lvdm_1689m': _dinov3_cfg( #hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - crop_pct=1.0, - num_classes=0, - license='dinov3', ), - 'vit_base_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg( + 'vit_base_patch16_dinov3_qkvb.lvdm_1689m': _dinov3_cfg( #hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - crop_pct=1.0, - num_classes=0, - license='dinov3', ), - 'vit_large_patch16_dinov3_224.lvdm_1689m': _cfg( + 'vit_large_patch16_dinov3.lvdm_1689m': _dinov3_cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - crop_pct=1.0, - num_classes=0, - license='dinov3', ), - 'vit_large_patch16_dinov3_qkvb_224.lvdm_1689m': _cfg( + 'vit_large_patch16_dinov3_qkvb.lvdm_1689m': _dinov3_cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - crop_pct=1.0, - num_classes=0, - license='dinov3', ), - 'vit_large_patch16_dinov3_224.sat_493m': _cfg( + 'vit_large_patch16_dinov3.sat_493m': _dinov3_cfg( # hf_hub_id='timm/', mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143), - crop_pct=1.0, - num_classes=0, - license='dinov3', ), - 'vit_huge_plus_patch16_dinov3_224.lvdm_1689m': _cfg( + 'vit_huge_plus_patch16_dinov3.lvdm_1689m': _dinov3_cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - crop_pct=1.0, - num_classes=0, - license='dinov3', ), - 'vit_7b_patch16_dinov3_224.lvdm_1689m': _cfg( + 'vit_7b_patch16_dinov3.lvdm_1689m': _dinov3_cfg( # hf_hub_id='timm/', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - crop_pct=1.0, - num_classes=0, - license='dinov3', ), - 'vit_7b_patch16_dinov3_224.sat_493m': _cfg( + 'vit_7b_patch16_dinov3.sat_493m': _dinov3_cfg( # hf_hub_id='timm/', mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143), - crop_pct=1.0, - num_classes=0, - license='dinov3', ), }) @@ 
-2640,9 +2613,10 @@ def vit_large_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> @register_model -def vit_small_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: +def vit_small_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( patch_size=16, + dynamic_img_size=True, embed_dim=384, depth=12, num_heads=6, @@ -2658,14 +2632,15 @@ def vit_small_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: use_fc_norm=False, norm_layer=partial(LayerNorm, eps=1e-5), ) - model = _create_eva('vit_small_patch16_dinov3_224', pretrained=pretrained, **dict(model_args, **kwargs)) + model = _create_eva('vit_small_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model -def vit_small_patch16_dinov3_qkvb_224(pretrained: bool = False, **kwargs) -> Eva: +def vit_small_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( patch_size=16, + dynamic_img_size=True, embed_dim=384, depth=12, num_heads=6, @@ -2681,14 +2656,15 @@ def vit_small_patch16_dinov3_qkvb_224(pretrained: bool = False, **kwargs) -> Eva use_fc_norm=False, norm_layer=partial(LayerNorm, eps=1e-5), ) - model = _create_eva('vit_small_patch16_dinov3_qkvb_224', pretrained=pretrained, **dict(model_args, **kwargs)) + model = _create_eva('vit_small_patch16_dinov3_qkvb', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model -def vit_small_plus_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: +def vit_small_plus_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( patch_size=16, + dynamic_img_size=True, embed_dim=384, depth=12, num_heads=6, @@ -2706,14 +2682,15 @@ def vit_small_plus_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva use_fc_norm=False, norm_layer=partial(LayerNorm, eps=1e-5), ) - model = _create_eva('vit_small_plus_patch16_dinov3_224', pretrained=pretrained, **dict(model_args, **kwargs)) + model = _create_eva('vit_small_plus_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model -def vit_small_plus_patch16_dinov3_qkvb_224(pretrained: bool = False, **kwargs) -> Eva: +def vit_small_plus_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( patch_size=16, + dynamic_img_size=True, embed_dim=384, depth=12, num_heads=6, @@ -2731,14 +2708,15 @@ def vit_small_plus_patch16_dinov3_qkvb_224(pretrained: bool = False, **kwargs) - use_fc_norm=False, norm_layer=partial(LayerNorm, eps=1e-5), ) - model = _create_eva('vit_small_plus_patch16_dinov3_qkvb_224', pretrained=pretrained, **dict(model_args, **kwargs)) + model = _create_eva('vit_small_plus_patch16_dinov3_qkvb', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model -def vit_base_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: +def vit_base_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( patch_size=16, + dynamic_img_size=True, embed_dim=768, depth=12, num_heads=12, @@ -2754,15 +2732,16 @@ def vit_base_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: use_fc_norm=False, norm_layer=partial(LayerNorm, eps=1e-5), ) - model = _create_eva('vit_base_patch16_dinov3_224', pretrained=pretrained, **dict(model_args, **kwargs)) + model = _create_eva('vit_base_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model -def vit_base_patch16_dinov3_qkvb_224(pretrained: bool = False, **kwargs) -> Eva: +def 
vit_base_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva: # DINOv3 Base variant w/ qkv_bias enabled (zero'd in weights) model_args = dict( patch_size=16, + dynamic_img_size=True, embed_dim=768, depth=12, num_heads=12, @@ -2778,14 +2757,15 @@ def vit_base_patch16_dinov3_qkvb_224(pretrained: bool = False, **kwargs) -> Eva: use_fc_norm=False, norm_layer=partial(LayerNorm, eps=1e-5), ) - model = _create_eva('vit_base_patch16_dinov3_qkvb_224', pretrained=pretrained, **dict(model_args, **kwargs)) + model = _create_eva('vit_base_patch16_dinov3_qkvb', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model -def vit_large_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: +def vit_large_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( patch_size=16, + dynamic_img_size=True, embed_dim=1024, depth=24, num_heads=16, @@ -2801,14 +2781,15 @@ def vit_large_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: use_fc_norm=False, norm_layer=partial(LayerNorm, eps=1e-5), ) - model = _create_eva('vit_large_patch16_dinov3_224', pretrained=pretrained, **dict(model_args, **kwargs)) + model = _create_eva('vit_large_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model -def vit_large_patch16_dinov3_qkvb_224(pretrained: bool = False, **kwargs) -> Eva: +def vit_large_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( patch_size=16, + dynamic_img_size=True, embed_dim=768, depth=24, num_heads=16, @@ -2824,14 +2805,15 @@ def vit_large_patch16_dinov3_qkvb_224(pretrained: bool = False, **kwargs) -> Eva use_fc_norm=False, norm_layer=partial(LayerNorm, eps=1e-5), ) - model = _create_eva('vit_large_patch16_dinov3_qkvb_224', pretrained=pretrained, **dict(model_args, **kwargs)) + model = _create_eva('vit_large_patch16_dinov3_qkvb', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model -def vit_huge_plus_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: +def vit_huge_plus_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( patch_size=16, + dynamic_img_size=True, embed_dim=1280, depth=32, num_heads=20, @@ -2850,14 +2832,15 @@ def vit_huge_plus_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: norm_layer=partial(LayerNorm, eps=1e-5), ) - model = _create_eva('vit_huge_plus_patch16_dinov3_224', pretrained=pretrained, **dict(model_args, **kwargs)) + model = _create_eva('vit_huge_plus_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs)) return model @register_model -def vit_7b_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: +def vit_7b_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva: model_args = dict( patch_size=16, + dynamic_img_size=True, embed_dim=4096, depth=40, num_heads=32, @@ -2877,5 +2860,5 @@ def vit_7b_patch16_dinov3_224(pretrained: bool = False, **kwargs) -> Eva: norm_layer=partial(LayerNorm, eps=1e-5), ) - model = _create_eva('vit_7b_patch16_dinov3_224', pretrained=pretrained, **dict(model_args, **kwargs)) + model = _create_eva('vit_7b_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs)) return model From 7fbebaabd179092781598d1666d7788b6b039a94 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 16 Sep 2025 16:29:01 -0700 Subject: [PATCH 09/12] Move make_coords in dinov3 RoPE to a free fn so it can be wrapped for fx --- timm/layers/pos_embed_sincos.py | 98 ++++++++++++++++++--------------- 1 file changed, 54 
insertions(+), 44 deletions(-) diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index 60ab9aeb56..fd8bb1416e 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -885,6 +885,54 @@ def no_weight_decay(self): return {'freqs'} +@torch.fx.wrap +@register_notrace_function +def make_coords_dinov3( + height: int, + width: int, + normalize_coords: str = 'separate', + grid_indexing: str = 'ij', + grid_offset: float = 0., + device: torch.device = 'cpu', + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Make coordinate grid matching offset and normalization of original. + Returns: coords with shape (HW, 2) in [-1, 1]. + """ + # 0.5-centered indices with optional offset + coords_h = torch.arange(0.5, height, device=device, dtype=dtype) + grid_offset + coords_w = torch.arange(0.5, width, device=device, dtype=dtype) + grid_offset + + # Normalization denominators + if normalize_coords == "max": + denom = float(max(height, width)) + h_denom = denom + w_denom = denom + elif normalize_coords == "min": + denom = float(min(height, width)) + h_denom = denom + w_denom = denom + elif normalize_coords == "separate": + h_denom = float(height) + w_denom = float(width) + else: + raise ValueError(f"Unknown normalize_coords: {normalize_coords}") + + # Normalize to [0, 1] + coords_h = coords_h / h_denom + coords_w = coords_w / w_denom + + # Create grid then map to [-1, 1] + if grid_indexing == "xy": + grid_w, grid_h = torch.meshgrid(coords_w, coords_h, indexing="xy") + coords = torch.stack([grid_h, grid_w], dim=-1) # (H, W, 2) -> (h, w order) + else: + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) # (H, W, 2) + coords = coords.flatten(0, 1) # (HW, 2) + coords = 2.0 * coords - 1.0 # (H, W, 2) in [-1, 1] + return coords + + class RotaryEmbeddingDinoV3(nn.Module): """RoPE for timm DinoV3 port, numerically matching original. @@ -960,49 +1008,6 @@ def _compute_periods(self, device: torch.device = 'cpu', dtype: torch.dtype = to return periods - def _make_coords( - self, - height: int, - width: int, - device: torch.device = 'cpu', - dtype: torch.dtype = torch.float32, - ) -> torch.Tensor: - """Make coordinate grid matching offset and normalization of original. - Returns: coords with shape (HW, 2) in [-1, 1]. 
- """ - # 0.5-centered indices with optional offset - coords_h = torch.arange(0.5, height, device=device, dtype=dtype) + self.grid_offset - coords_w = torch.arange(0.5, width, device=device, dtype=dtype) + self.grid_offset - - # Normalization denominators - if self.normalize_coords == "max": - denom = float(max(height, width)) - h_denom = denom - w_denom = denom - elif self.normalize_coords == "min": - denom = float(min(height, width)) - h_denom = denom - w_denom = denom - elif self.normalize_coords == "separate": - h_denom = float(height) - w_denom = float(width) - else: - raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}") - - # Normalize to [0, 1] - coords_h = coords_h / h_denom - coords_w = coords_w / w_denom - - # Create grid then map to [-1, 1] - if self.grid_indexing == "xy": - grid_w, grid_h = torch.meshgrid(coords_w, coords_h, indexing="xy") - coords = torch.stack([grid_h, grid_w], dim=-1) # (H, W, 2) -> (h, w order) - else: - coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) # (H, W, 2) - coords = coords.flatten(0, 1) # (HW, 2) - coords = 2.0 * coords - 1.0 # (H, W, 2) in [-1, 1] - return coords - def _apply_coord_augs(self, coords: torch.Tensor) -> torch.Tensor: """Apply shift/jitter/rescale train time augmentations.""" if not self.training or not self.aug_active: @@ -1063,7 +1068,12 @@ def _get_pos_embed_from_coords(self, coords: torch.Tensor) -> Tuple[torch.Tenso def _create_embed(self, feat_shape: List[int], no_aug: bool = False) -> torch.Tensor: H, W = feat_shape - coords = self._make_coords(H, W) # (HW, 2) + coords = make_coords_dinov3( + H, W, + normalize_coords=self.normalize_coords, + grid_indexing=self.grid_indexing, + grid_offset=self.grid_offset + ) # (HW, 2) if not no_aug: coords = self._apply_coord_augs(coords) sin, cos = self._get_pos_embed_from_coords(coords) # 2 * (HW, dim) From 9bd8137f802276aea1a019965d1b02017e25ac42 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 17 Sep 2025 11:56:18 -0700 Subject: [PATCH 10/12] Final dinov3 pretrained cfgs, pointing to uploaded weights --- timm/models/eva.py | 91 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 67 insertions(+), 24 deletions(-) diff --git a/timm/models/eva.py b/timm/models/eva.py index 422a33e485..913f8c6f11 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -1633,42 +1633,49 @@ def _dinov3_cfg(url: str = '', **kwargs) -> Dict[str, Any]: # DINOv3 weights are under a specific license with redistribution terms, please see # https://github.com/facebookresearch/dinov3/blob/main/LICENSE.md - 'vit_small_patch16_dinov3.lvdm_1689m': _dinov3_cfg( - # hf_hub_id='timm/', + 'vit_small_patch16_dinov3.lvd_1689m': _dinov3_cfg( + hf_hub_id='timm/', ), - 'vit_small_patch16_dinov3_qkvb.lvdm_1689m': _dinov3_cfg( - # hf_hub_id='timm/', + 'vit_small_patch16_dinov3_qkvb.lvd_1689m': _dinov3_cfg( + hf_hub_id='timm/', ), - 'vit_small_plus_patch16_dinov3.lvdm_1689m': _dinov3_cfg( - # hf_hub_id='timm/', + 'vit_small_plus_patch16_dinov3.lvd_1689m': _dinov3_cfg( + hf_hub_id='timm/', ), - 'vit_small_plus_patch16_dinov3_qkvb.lvdm_1689m': _dinov3_cfg( - # hf_hub_id='timm/', + 'vit_small_plus_patch16_dinov3_qkvb.lvd_1689m': _dinov3_cfg( + hf_hub_id='timm/', ), - 'vit_base_patch16_dinov3.lvdm_1689m': _dinov3_cfg( - #hf_hub_id='timm/', + 'vit_base_patch16_dinov3.lvd_1689m': _dinov3_cfg( + hf_hub_id='timm/', ), - 'vit_base_patch16_dinov3_qkvb.lvdm_1689m': _dinov3_cfg( - #hf_hub_id='timm/', + 'vit_base_patch16_dinov3_qkvb.lvd_1689m': _dinov3_cfg( + hf_hub_id='timm/', 
), - 'vit_large_patch16_dinov3.lvdm_1689m': _dinov3_cfg( - # hf_hub_id='timm/', + 'vit_large_patch16_dinov3.lvd_1689m': _dinov3_cfg( + hf_hub_id='timm/', ), - 'vit_large_patch16_dinov3_qkvb.lvdm_1689m': _dinov3_cfg( - # hf_hub_id='timm/', + 'vit_large_patch16_dinov3_qkvb.lvd_1689m': _dinov3_cfg( + hf_hub_id='timm/', ), 'vit_large_patch16_dinov3.sat_493m': _dinov3_cfg( - # hf_hub_id='timm/', + hf_hub_id='timm/', + mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143), + ), + 'vit_large_patch16_dinov3_qkvb.sat_493m': _dinov3_cfg( + hf_hub_id='timm/', mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143), ), - 'vit_huge_plus_patch16_dinov3.lvdm_1689m': _dinov3_cfg( - # hf_hub_id='timm/', + 'vit_huge_plus_patch16_dinov3.lvd_1689m': _dinov3_cfg( + hf_hub_id='timm/', + ), + 'vit_huge_plus_patch16_dinov3_qkvb.lvd_1689m': _dinov3_cfg( + hf_hub_id='timm/', ), - 'vit_7b_patch16_dinov3.lvdm_1689m': _dinov3_cfg( - # hf_hub_id='timm/', + 'vit_7b_patch16_dinov3.lvd_1689m': _dinov3_cfg( + hf_hub_id='timm/', ), 'vit_7b_patch16_dinov3.sat_493m': _dinov3_cfg( - # hf_hub_id='timm/', + hf_hub_id='timm/', mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143), ), @@ -2614,6 +2621,7 @@ def vit_large_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> @register_model def vit_small_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva: + """DINOv3 S/16 https://arxiv.org/abs/2508.10104""" model_args = dict( patch_size=16, dynamic_img_size=True, @@ -2638,6 +2646,7 @@ def vit_small_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva: @register_model def vit_small_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva: + """DINOv3 S/16 w/ QKV bias enabled (but zero) https://arxiv.org/abs/2508.10104""" model_args = dict( patch_size=16, dynamic_img_size=True, @@ -2662,6 +2671,7 @@ def vit_small_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva: @register_model def vit_small_plus_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva: + """DINOv3 S/16 Plus https://arxiv.org/abs/2508.10104""" model_args = dict( patch_size=16, dynamic_img_size=True, @@ -2688,6 +2698,7 @@ def vit_small_plus_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva: @register_model def vit_small_plus_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva: + """DINOv3 S/16 Plus w/ QKV bias enabled (but 0) https://arxiv.org/abs/2508.10104""" model_args = dict( patch_size=16, dynamic_img_size=True, @@ -2714,6 +2725,7 @@ def vit_small_plus_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Ev @register_model def vit_base_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva: + """DINOv3 B/16 https://arxiv.org/abs/2508.10104""" model_args = dict( patch_size=16, dynamic_img_size=True, @@ -2738,7 +2750,7 @@ def vit_base_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva: @register_model def vit_base_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva: - # DINOv3 Base variant w/ qkv_bias enabled (zero'd in weights) + """DINOv3 B/16 w/ QKV bias enabled (but zero) https://arxiv.org/abs/2508.10104""" model_args = dict( patch_size=16, dynamic_img_size=True, @@ -2763,6 +2775,7 @@ def vit_base_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva: @register_model def vit_large_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva: + """DINOv3 L/16 https://arxiv.org/abs/2508.10104""" model_args = dict( patch_size=16, dynamic_img_size=True, @@ -2787,10 +2800,11 @@ def vit_large_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva: @register_model def 
vit_large_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva:
+    """DINOv3 L/16 w/ QKV bias enabled (but zero) https://arxiv.org/abs/2508.10104"""
     model_args = dict(
         patch_size=16,
         dynamic_img_size=True,
-        embed_dim=768,
+        embed_dim=1024,
         depth=24,
         num_heads=16,
         qkv_bias=True,
@@ -2811,6 +2825,7 @@ def vit_large_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva:
 @register_model
 def vit_huge_plus_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva:
+    """DINOv3 H/16 Plus https://arxiv.org/abs/2508.10104"""
     model_args = dict(
         patch_size=16,
         dynamic_img_size=True,
@@ -2836,8 +2851,36 @@ def vit_huge_plus_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva:
     return model

+@register_model
+def vit_huge_plus_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva:
+    """DINOv3 H/16 Plus w/ QKV bias enabled (but zero) https://arxiv.org/abs/2508.10104"""
+    model_args = dict(
+        patch_size=16,
+        dynamic_img_size=True,
+        embed_dim=1280,
+        depth=32,
+        num_heads=20,
+        qkv_bias=True,
+        init_values=1.0e-5, # layer-scale
+        rope_type='dinov3',
+        rope_temperature=100,
+        use_rot_pos_emb=True,
+        use_abs_pos_emb=False,
+        rope_rotate_half=True,
+        swiglu_mlp=True,
+        swiglu_align_to=8,
+        #rope_rescale_coords=2, # haven't added to interface
+        num_reg_tokens=4,
+        use_fc_norm=False,
+        norm_layer=partial(LayerNorm, eps=1e-5),
+    )
+
+    model = _create_eva('vit_huge_plus_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
 @register_model
 def vit_7b_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva:
+    """DINOv3 7B/16 https://arxiv.org/abs/2508.10104"""
     model_args = dict(
         patch_size=16,
         dynamic_img_size=True,

From 04851a054810bba5148eeb98d4d0626a3204fa24 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 17 Sep 2025 11:56:58 -0700
Subject: [PATCH 11/12] Pass task arg through to hf hub push

---
 timm/models/_hub.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/timm/models/_hub.py b/timm/models/_hub.py
index 0dbf377e52..c77b04d854 100644
--- a/timm/models/_hub.py
+++ b/timm/models/_hub.py
@@ -404,6 +404,7 @@ def push_to_hf_hub(
     model_config: Optional[dict] = None,
     model_card: Optional[dict] = None,
     model_args: Optional[dict] = None,
+    task_name: str = 'image-classification',
     safe_serialization: Union[bool, Literal["both"]] = 'both',
 ):
     """
@@ -444,7 +445,7 @@ def push_to_hf_hub(
         model_card = model_card or {}
         model_name = repo_id.split('/')[-1]
         readme_path = Path(tmpdir) / "README.md"
-        readme_text = generate_readme(model_card, model_name)
+        readme_text = generate_readme(model_card, model_name, task_name=task_name)
         readme_path.write_text(readme_text)

         # Upload model and return

From 9a44a236c13aee520d7c8eef74113771ae9c79ea Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 17 Sep 2025 13:10:48 -0700
Subject: [PATCH 12/12] Fix builder variant for dinov3 qkvb h/16

---
 timm/models/eva.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/timm/models/eva.py b/timm/models/eva.py
index 913f8c6f11..1cd690b06c 100644
--- a/timm/models/eva.py
+++ b/timm/models/eva.py
@@ -2875,7 +2875,7 @@ def vit_huge_plus_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva
         norm_layer=partial(LayerNorm, eps=1e-5),
     )

-    model = _create_eva('vit_huge_plus_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs))
+    model = _create_eva('vit_huge_plus_patch16_dinov3_qkvb', pretrained=pretrained, **dict(model_args, **kwargs))
     return model

 @register_model