
Commit b8cc442

Add MobileCLIP2 image encoders

1 parent 818663b commit b8cc442

File tree: 3 files changed, +116 −4 lines changed

timm/models/fastvit.py

Lines changed: 101 additions & 4 deletions
@@ -12,9 +12,17 @@
 import torch
 import torch.nn as nn
 
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \
-    ClassifierHead
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+from timm.layers import (
+    DropPath,
+    trunc_normal_,
+    create_conv2d,
+    ConvNormAct,
+    SqueezeExcite,
+    use_fused_attn,
+    ClassifierHead,
+    LayerNorm2d,
+)
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq
@@ -427,7 +435,8 @@ def convolutional_stem(
         in_chs: int,
         out_chs: int,
         act_layer: Type[nn.Module] = nn.GELU,
-        inference_mode: bool = False
+        inference_mode: bool = False,
+        use_scale_branch: bool = True,
 ) -> nn.Sequential:
     """Build convolutional stem with MobileOne blocks.
@@ -447,6 +456,7 @@ def convolutional_stem(
             stride=2,
             act_layer=act_layer,
             inference_mode=inference_mode,
+            use_scale_branch=use_scale_branch,
         ),
         MobileOneBlock(
             in_chs=out_chs,
@@ -456,6 +466,7 @@ def convolutional_stem(
             group_size=1,
             act_layer=act_layer,
             inference_mode=inference_mode,
+            use_scale_branch=use_scale_branch,
         ),
         MobileOneBlock(
             in_chs=out_chs,
@@ -464,6 +475,7 @@ def convolutional_stem(
             stride=1,
             act_layer=act_layer,
             inference_mode=inference_mode,
+            use_scale_branch=use_scale_branch,
         ),
     )
 
@@ -1118,6 +1130,7 @@ def __init__(
             drop_path_rate: float = 0.0,
             layer_scale_init_value: float = 1e-5,
             lkc_use_act: bool = False,
+            stem_use_scale_branch: bool = True,
             fork_feat: bool = False,
             cls_ratio: float = 2.0,
             global_pool: str = 'avg',
@@ -1137,6 +1150,7 @@ def __init__(
             embed_dims[0],
             act_layer,
             inference_mode,
+            use_scale_branch=stem_use_scale_branch,
         )
 
         # Build the main stages of the network architecture
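The new use_scale_branch / stem_use_scale_branch flags let FastViT build its stem MobileOne blocks without the extra reparameterizable scale branch, which the MobileCLIP2 variants below set to False. A minimal sketch of exercising the new argument directly; the input/output shapes are assumptions for illustration, not part of the commit:

import torch
from timm.models.fastvit import convolutional_stem

# Stem as configured by the new MobileCLIP2 variants: scale branch disabled
# in the stem's MobileOne blocks (stem_use_scale_branch=False).
stem = convolutional_stem(in_chs=3, out_chs=64, use_scale_branch=False)

x = torch.randn(1, 3, 256, 256)
print(stem(x).shape)  # two stride-2 blocks -> roughly (1, 64, 64, 64)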
@@ -1412,6 +1426,35 @@ def _cfg(url="", **kwargs):
         num_classes=512,  # CLIP proj dim
         mean=(0., 0., 0.), std=(1., 1., 1.)
     ),
+
+    "fastvit_mci0.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=1.0,
+        num_classes=512,  # CLIP proj dim
+        mean=(0., 0., 0.), std=(1., 1., 1.),
+        license='apple-amlr'
+    ),
+    "fastvit_mci2.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95,
+        num_classes=512,  # CLIP proj dim
+        mean=(0., 0., 0.), std=(1., 1., 1.),
+        license='apple-amlr'
+    ),
+    "fastvit_mci3.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95,
+        num_classes=768,  # CLIP proj dim
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        license='apple-amlr'
+    ),
+    "fastvit_mci4.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95,
+        num_classes=768,  # CLIP proj dim
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        license='apple-amlr'
+    ),
 })
 
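With these pretrained cfgs registered, the MobileCLIP2 FastViT image towers can be pulled like any other timm model. A usage sketch, assuming the 'timm/' hub checkpoints referenced above have been published:

import timm
import torch
from timm.data import resolve_model_data_config, create_transform

model = timm.create_model('fastvit_mci0.apple_mclip2_dfndr2b', pretrained=True)
model.eval()

# Build the matching eval transform from the pretrained cfg (crop_pct, mean/std above).
cfg = resolve_model_data_config(model)
transform = create_transform(**cfg, is_training=False)

with torch.no_grad():
    emb = model(torch.randn(1, *cfg['input_size']))
print(emb.shape)  # 512-dim CLIP image embedding per num_classes in the mci0 cfg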

@@ -1420,6 +1463,9 @@ def checkpoint_filter_fn(state_dict, model):
     if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed
 
+    if 'module.visual.trunk.stem.0.conv_kxk.0.conv.weight' in state_dict:
+        return {k.replace('module.visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('module.visual.trunk')}
+
     state_dict = state_dict.get('state_dict', state_dict)
     if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
         # remap MobileCLIP checkpoints
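The new branch handles MobileCLIP2 checkpoints where the image tower sits under a 'module.visual.trunk.' prefix: keys under that prefix are kept with the prefix stripped, and everything else (text tower, logit scale, etc.) is dropped. A toy illustration of the same remapping; the keys and values here are made up:

ckpt = {
    'module.visual.trunk.stem.0.conv_kxk.0.conv.weight': 'stem weight',
    'module.visual.trunk.head.fc.weight': 'head weight',
    'module.text.transformer.resblocks.0.attn.in_proj_weight': 'text weight',
}
remapped = {
    k.replace('module.visual.trunk.', ''): v
    for k, v in ckpt.items()
    if k.startswith('module.visual.trunk')
}
print(remapped)
# {'stem.0.conv_kxk.0.conv.weight': 'stem weight', 'head.fc.weight': 'head weight'}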
@@ -1632,3 +1678,54 @@ def fastvit_mci2(pretrained=False, **kwargs):
         lkc_use_act=True,
     )
     return _create_fastvit('fastvit_mci2', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def fastvit_mci3(pretrained=False, **kwargs):
+    """Instantiate L model variant."""
+    model_args = dict(
+        layers=(2, 12, 24, 4, 2),
+        embed_dims=(96, 192, 384, 768, 1536),
+        mlp_ratios=(4, 4, 4, 4, 4),
+        se_downsamples=(False, False, False, False, False),
+        downsamples=(False, True, True, True, True),
+        pos_embs=(
+            None,
+            None,
+            None,
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7))
+        ),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention", "attention"),
+        lkc_use_act=True,
+        norm_layer=partial(LayerNorm2d, eps=1e-5),
+        stem_use_scale_branch=False,
+    )
+    model = _create_fastvit('fastvit_mci3', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def fastvit_mci4(pretrained=False, **kwargs):
+    """Instantiate XL model variant."""
+    model_args = dict(
+        layers=(2, 12, 24, 4, 4),
+        embed_dims=(128, 256, 512, 1024, 2048),
+        mlp_ratios=(4, 4, 4, 4, 4),
+        se_downsamples=(False, False, False, False, False),
+        downsamples=(False, True, True, True, True),
+        pos_embs=(
+            None,
+            None,
+            None,
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7))
+        ),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention", "attention"),
+        lkc_use_act=True,
+        norm_layer=partial(LayerNorm2d, eps=1e-5),
+        stem_use_scale_branch=False,
+    )
+
+    model = _create_fastvit('fastvit_mci4', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
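fastvit_mci3 and fastvit_mci4 register 5-stage L and XL MCi backbones (LayerNorm2d norm layer, no stem scale branch) alongside the existing mci0-mci2. A quick smoke-test sketch with randomly initialized weights; the output dimension is only printed, not asserted:

import timm
import torch

# num_classes=0 returns pooled features instead of classifier logits.
model = timm.create_model('fastvit_mci3', pretrained=False, num_classes=0)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 256, 256))
print(feats.shape)  # pooled features from the final conv stage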

timm/models/vision_transformer.py

Lines changed: 9 additions & 0 deletions
@@ -1416,6 +1416,8 @@ def checkpoint_filter_fn(
         # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
         out_dict['head.weight'] = state_dict['visual.head.proj.weight']
         out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
+    elif 'module.visual.trunk.pos_embed' in state_dict:
+        prefix = 'module.visual.trunk.'
     elif 'preprocessor.patchifier.proj.weight' in state_dict:
         state_dict = _convert_aimv2(state_dict, model)
 
@@ -2007,6 +2009,13 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         crop_pct=1.0, input_size=(3, 336, 336), num_classes=768),
 
+    'vit_large_patch14_clip_224.apple_mclip2_dfndr2b': _cfg(
+        hf_hub_id='timm/',
+        num_classes=768,
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
+        license='apple-amlr'
+    ),
+
     # experimental (may be removed)
     'vit_base_patch32_plus_256.untrained': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
     'vit_base_patch16_plus_240.untrained': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95),
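On the ViT side, MobileCLIP2's ViT-L/14 image tower is exposed as a new pretrained tag of the existing vit_large_patch14_clip_224 architecture, with the head acting as the 768-dim CLIP projection. A usage sketch, again assuming the hub weights are available:

import timm
import torch

encoder = timm.create_model('vit_large_patch14_clip_224.apple_mclip2_dfndr2b', pretrained=True)
encoder.eval()

with torch.no_grad():
    image_emb = encoder(torch.randn(1, 3, 224, 224))
print(image_emb.shape)  # (1, 768) -- CLIP projection dim per the cfg above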

timm/models/vision_transformer_hybrid.py

Lines changed: 6 additions & 0 deletions
@@ -230,6 +230,12 @@ def _cfg(url='', **kwargs):
         num_classes=512,
         mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv',
     ),
+    'vit_base_mci_224.apple_mclip2_dfndr2b': _cfg(
+        hf_hub_id='timm/',
+        num_classes=512,
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv',
+        license='apple-amlr'
+    ),
 })
 
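The hybrid vit_base_mci_224 (ViT-B with an MCi convolutional stem) likewise gains a MobileCLIP2 pretrained tag with a 512-dim projection head. A minimal retrieval-style sketch, assuming the corresponding 'timm/' checkpoint exists:

import timm
import torch
import torch.nn.functional as F

model = timm.create_model('vit_base_mci_224.apple_mclip2_dfndr2b', pretrained=True).eval()
with torch.no_grad():
    emb = model(torch.randn(2, 3, 224, 224))  # (2, 512) image embeddings
    emb = F.normalize(emb, dim=-1)            # unit-normalize as for CLIP similarity
print(emb @ emb.T)                            # cosine similarity matrix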
