Skip to content

Commit 2f0fbb5

Browse files
committed
TinyViT weights on HF hub
1 parent 507cb08 commit 2f0fbb5

File tree

1 file changed

+26
-21
lines changed

1 file changed

+26
-21
lines changed

timm/models/tiny_vit.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ def __init__(
451451
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
452452

453453
# build stages
454-
stages = nn.ModuleList()
454+
self.stages = nn.Sequential()
455455
stride = self.patch_embed.stride
456456
prev_dim = embed_dims[0]
457457
self.feature_info = []
@@ -482,9 +482,8 @@ def __init__(
482482
)
483483
prev_dim = out_dim
484484
stride *= 2
485-
stages.append(stage)
485+
self.stages.append(stage)
486486
self.feature_info += [dict(num_chs=prev_dim, reduction=stride, module=f'stages.{stage_idx}')]
487-
self.stages = nn.Sequential(*stages)
488487

489488
# Classifier head
490489
self.num_features = embed_dims[-1]
@@ -549,22 +548,17 @@ def forward(self, x):
549548

550549

551550
def checkpoint_filter_fn(state_dict, model):
552-
# TODO: temporary use for testing, need change after weight convert
553551
if 'model' in state_dict.keys():
554552
state_dict = state_dict['model']
555553
target_sd = model.state_dict()
556-
target_keys = list(target_sd.keys())
557554
out_dict = {}
558-
i = 0
559555
for k, v in state_dict.items():
560556
if k.endswith('attention_bias_idxs'):
561557
continue
562-
tk = target_keys[i]
563558
if 'attention_biases' in k:
564559
# TODO: whether move this func into model for dynamic input resolution? (high risk)
565-
v = resize_rel_pos_bias_table_levit(v.T, target_sd[tk].shape[::-1]).T
566-
out_dict[tk] = v
567-
i += 1
560+
v = resize_rel_pos_bias_table_levit(v.T, target_sd[k].shape[::-1]).T
561+
out_dict[k] = v
568562
return out_dict
569563

570564

@@ -585,41 +579,52 @@ def _cfg(url='', **kwargs):
585579

586580
default_cfgs = generate_default_cfgs({
587581
'tiny_vit_5m_224.dist_in22k': _cfg(
588-
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22k_distill.pth',
582+
hf_hub_id='timm/',
583+
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22k_distill.pth',
589584
num_classes=21841
590585
),
591586
'tiny_vit_5m_224.dist_in22k_ft_in1k': _cfg(
592-
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22kto1k_distill.pth'
587+
hf_hub_id='timm/',
588+
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22kto1k_distill.pth'
593589
),
594590
'tiny_vit_5m_224.in1k': _cfg(
595-
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_1k.pth'
591+
hf_hub_id='timm/',
592+
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_1k.pth'
596593
),
597594
'tiny_vit_11m_224.dist_in22k': _cfg(
598-
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22k_distill.pth',
595+
hf_hub_id='timm/',
596+
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22k_distill.pth',
599597
num_classes=21841
600598
),
601599
'tiny_vit_11m_224.dist_in22k_ft_in1k': _cfg(
602-
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22kto1k_distill.pth'
600+
hf_hub_id='timm/',
601+
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22kto1k_distill.pth'
603602
),
604603
'tiny_vit_11m_224.in1k': _cfg(
605-
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_1k.pth'
604+
hf_hub_id='timm/',
605+
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_1k.pth'
606606
),
607607
'tiny_vit_21m_224.dist_in22k': _cfg(
608-
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22k_distill.pth',
608+
hf_hub_id='timm/',
609+
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22k_distill.pth',
609610
num_classes=21841
610611
),
611612
'tiny_vit_21m_224.dist_in22k_ft_in1k': _cfg(
612-
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_distill.pth'
613+
hf_hub_id='timm/',
614+
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_distill.pth'
613615
),
614616
'tiny_vit_21m_224.in1k': _cfg(
615-
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_1k.pth'
617+
hf_hub_id='timm/',
618+
#url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_1k.pth'
616619
),
617620
'tiny_vit_21m_384.dist_in22k_ft_in1k': _cfg(
618-
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_384_distill.pth',
621+
hf_hub_id='timm/',
622+
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_384_distill.pth',
619623
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
620624
),
621625
'tiny_vit_21m_512.dist_in22k_ft_in1k': _cfg(
622-
url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_512_distill.pth',
626+
hf_hub_id='timm/',
627+
# url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_512_distill.pth',
623628
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash',
624629
),
625630
})

0 commit comments

Comments
 (0)