@@ -12,9 +12,17 @@
 import torch
 import torch.nn as nn

-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, trunc_normal_, create_conv2d, ConvNormAct, SqueezeExcite, use_fused_attn, \
-    ClassifierHead
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+from timm.layers import (
+    DropPath,
+    trunc_normal_,
+    create_conv2d,
+    ConvNormAct,
+    SqueezeExcite,
+    use_fused_attn,
+    ClassifierHead,
+    LayerNorm2d,
+)
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq
@@ -427,7 +435,8 @@ def convolutional_stem(
         in_chs: int,
         out_chs: int,
         act_layer: Type[nn.Module] = nn.GELU,
-        inference_mode: bool = False
+        inference_mode: bool = False,
+        use_scale_branch: bool = True,
 ) -> nn.Sequential:
     """Build convolutional stem with MobileOne blocks.

@@ -447,6 +456,7 @@ def convolutional_stem(
             stride=2,
             act_layer=act_layer,
             inference_mode=inference_mode,
+            use_scale_branch=use_scale_branch,
         ),
         MobileOneBlock(
             in_chs=out_chs,
@@ -456,6 +466,7 @@ def convolutional_stem(
             group_size=1,
             act_layer=act_layer,
             inference_mode=inference_mode,
+            use_scale_branch=use_scale_branch,
         ),
         MobileOneBlock(
             in_chs=out_chs,
@@ -464,6 +475,7 @@ def convolutional_stem(
             stride=1,
             act_layer=act_layer,
             inference_mode=inference_mode,
+            use_scale_branch=use_scale_branch,
        ),
    )

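
For context: `use_scale_branch` toggles the extra 1x1 "scale" branch inside each reparameterizable MobileOneBlock, and the change threads it through all three stem blocks. A minimal sketch of exercising the patched helper (import path assumed from this module's location in timm):

    from timm.models.fastvit import convolutional_stem

    # Build the three-block stem without the 1x1 scale branch,
    # as the MCi3/MCi4 variants added below do.
    stem = convolutional_stem(in_chs=3, out_chs=96, use_scale_branch=False)
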
@@ -1118,6 +1130,7 @@ def __init__(
             drop_path_rate: float = 0.0,
             layer_scale_init_value: float = 1e-5,
             lkc_use_act: bool = False,
+            stem_use_scale_branch: bool = True,
             fork_feat: bool = False,
             cls_ratio: float = 2.0,
             global_pool: str = 'avg',
@@ -1137,6 +1150,7 @@ def __init__(
             embed_dims[0],
             act_layer,
             inference_mode,
+            use_scale_branch=stem_use_scale_branch,
         )

         # Build the main stages of the network architecture
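
The new constructor flag simply threads down into `convolutional_stem`. Because the variant entrypoints below merge `model_args` with caller kwargs via `**dict(model_args, **kwargs)`, the flag can also be flipped at creation time; a hedged sketch:

    import timm

    # Override the new stem flag through create_model's kwargs passthrough.
    m = timm.create_model('fastvit_mci3', pretrained=False, stem_use_scale_branch=True)
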
@@ -1412,6 +1426,35 @@ def _cfg(url="", **kwargs):
         num_classes=512,  # CLIP proj dim
         mean=(0., 0., 0.), std=(1., 1., 1.)
     ),
+
+    "fastvit_mci0.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=1.0,
+        num_classes=512,  # CLIP proj dim
+        mean=(0., 0., 0.), std=(1., 1., 1.),
+        license='apple-amlr',
+    ),
+    "fastvit_mci2.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95,
+        num_classes=512,  # CLIP proj dim
+        mean=(0., 0., 0.), std=(1., 1., 1.),
+        license='apple-amlr',
+    ),
+    "fastvit_mci3.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95,
+        num_classes=768,  # CLIP proj dim
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        license='apple-amlr',
+    ),
+    "fastvit_mci4.apple_mclip2_dfndr2b": _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.95,
+        num_classes=768,  # CLIP proj dim
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        license='apple-amlr',
+    ),
 })

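
These cfgs register the MobileCLIP2 (DFNDR-2B) image towers, whose heads emit CLIP projection embeddings rather than class logits. A hedged loading sketch (assumes the weights are actually published under the `timm/` hub namespace the cfgs point at):

    import timm

    # Per the cfg above, the output is a 512-dim CLIP image embedding.
    model = timm.create_model('fastvit_mci2.apple_mclip2_dfndr2b', pretrained=True)
    model.eval()
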
@@ -1420,6 +1463,9 @@ def checkpoint_filter_fn(state_dict, model):
     if 'stem.0.conv_kxk.0.conv.weight' in state_dict:
         return state_dict  # non-original checkpoint, no remapping needed

+    if 'module.visual.trunk.stem.0.conv_kxk.0.conv.weight' in state_dict:
+        return {k.replace('module.visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('module.visual.trunk')}
+
     state_dict = state_dict.get('state_dict', state_dict)
     if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:
         # remap MobileCLIP checkpoints
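
The new early-return detects a raw MobileCLIP2 training checkpoint by a key unique to its image trunk, keeps only the `module.visual.trunk.*` entries, and strips that prefix. A toy illustration of the same comprehension (keys fabricated for the example):

    sd = {
        'module.visual.trunk.stem.0.conv_kxk.0.conv.weight': 'w',  # image trunk: kept
        'module.text.transformer.weight': 't',                     # text tower: dropped
        'module.logit_scale': 's',                                 # CLIP scale: dropped
    }
    remapped = {k.replace('module.visual.trunk.', ''): v
                for k, v in sd.items() if k.startswith('module.visual.trunk')}
    assert list(remapped) == ['stem.0.conv_kxk.0.conv.weight']
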
@@ -1632,3 +1678,54 @@ def fastvit_mci2(pretrained=False, **kwargs):
         lkc_use_act=True,
     )
     return _create_fastvit('fastvit_mci2', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def fastvit_mci3(pretrained=False, **kwargs):
+    """Instantiate L model variant."""
+    model_args = dict(
+        layers=(2, 12, 24, 4, 2),
+        embed_dims=(96, 192, 384, 768, 1536),
+        mlp_ratios=(4, 4, 4, 4, 4),
+        se_downsamples=(False, False, False, False, False),
+        downsamples=(False, True, True, True, True),
+        pos_embs=(
+            None,
+            None,
+            None,
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
+        ),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention", "attention"),
+        lkc_use_act=True,
+        norm_layer=partial(LayerNorm2d, eps=1e-5),
+        stem_use_scale_branch=False,
+    )
+    model = _create_fastvit('fastvit_mci3', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def fastvit_mci4(pretrained=False, **kwargs):
+    """Instantiate XL model variant."""
+    model_args = dict(
+        layers=(2, 12, 24, 4, 4),
+        embed_dims=(128, 256, 512, 1024, 2048),
+        mlp_ratios=(4, 4, 4, 4, 4),
+        se_downsamples=(False, False, False, False, False),
+        downsamples=(False, True, True, True, True),
+        pos_embs=(
+            None,
+            None,
+            None,
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
+            partial(RepConditionalPosEnc, spatial_shape=(7, 7)),
+        ),
+        token_mixers=("repmixer", "repmixer", "repmixer", "attention", "attention"),
+        lkc_use_act=True,
+        norm_layer=partial(LayerNorm2d, eps=1e-5),
+        stem_use_scale_branch=False,
+    )
+
+    model = _create_fastvit('fastvit_mci4', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
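
A hedged smoke test for the new entrypoints (random weights, dummy input; the feature channel count follows `embed_dims[-1]` scaled by the default `cls_ratio` in the surrounding file):

    import torch
    import timm

    m = timm.create_model('fastvit_mci4', pretrained=False, num_classes=0).eval()
    with torch.no_grad():
        feats = m.forward_features(torch.randn(1, 3, 256, 256))
    print(feats.shape)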