@@ -451,7 +451,7 @@ def __init__(
451451 dpr = [x .item () for x in torch .linspace (0 , drop_path_rate , sum (depths ))]
452452
453453 # build stages
454- stages = nn .ModuleList ()
454+ self . stages = nn .Sequential ()
455455 stride = self .patch_embed .stride
456456 prev_dim = embed_dims [0 ]
457457 self .feature_info = []
@@ -482,9 +482,8 @@ def __init__(
482482 )
483483 prev_dim = out_dim
484484 stride *= 2
485- stages .append (stage )
485+ self . stages .append (stage )
486486 self .feature_info += [dict (num_chs = prev_dim , reduction = stride , module = f'stages.{ stage_idx } ' )]
487- self .stages = nn .Sequential (* stages )
488487
489488 # Classifier head
490489 self .num_features = embed_dims [- 1 ]
@@ -549,22 +548,17 @@ def forward(self, x):
549548
550549
551550def checkpoint_filter_fn (state_dict , model ):
552- # TODO: temporary use for testing, need change after weight convert
553551 if 'model' in state_dict .keys ():
554552 state_dict = state_dict ['model' ]
555553 target_sd = model .state_dict ()
556- target_keys = list (target_sd .keys ())
557554 out_dict = {}
558- i = 0
559555 for k , v in state_dict .items ():
560556 if k .endswith ('attention_bias_idxs' ):
561557 continue
562- tk = target_keys [i ]
563558 if 'attention_biases' in k :
564559 # TODO: whether move this func into model for dynamic input resolution? (high risk)
565- v = resize_rel_pos_bias_table_levit (v .T , target_sd [tk ].shape [::- 1 ]).T
566- out_dict [tk ] = v
567- i += 1
560+ v = resize_rel_pos_bias_table_levit (v .T , target_sd [k ].shape [::- 1 ]).T
561+ out_dict [k ] = v
568562 return out_dict
569563
570564
@@ -585,41 +579,52 @@ def _cfg(url='', **kwargs):
585579
586580default_cfgs = generate_default_cfgs ({
587581 'tiny_vit_5m_224.dist_in22k' : _cfg (
588- url = 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22k_distill.pth' ,
582+ hf_hub_id = 'timm/' ,
583+ # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22k_distill.pth',
589584 num_classes = 21841
590585 ),
591586 'tiny_vit_5m_224.dist_in22k_ft_in1k' : _cfg (
592- url = 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22kto1k_distill.pth'
587+ hf_hub_id = 'timm/' ,
588+ # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22kto1k_distill.pth'
593589 ),
594590 'tiny_vit_5m_224.in1k' : _cfg (
595- url = 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_1k.pth'
591+ hf_hub_id = 'timm/' ,
592+ # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_1k.pth'
596593 ),
597594 'tiny_vit_11m_224.dist_in22k' : _cfg (
598- url = 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22k_distill.pth' ,
595+ hf_hub_id = 'timm/' ,
596+ # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22k_distill.pth',
599597 num_classes = 21841
600598 ),
601599 'tiny_vit_11m_224.dist_in22k_ft_in1k' : _cfg (
602- url = 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22kto1k_distill.pth'
600+ hf_hub_id = 'timm/' ,
601+ # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22kto1k_distill.pth'
603602 ),
604603 'tiny_vit_11m_224.in1k' : _cfg (
605- url = 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_1k.pth'
604+ hf_hub_id = 'timm/' ,
605+ # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_1k.pth'
606606 ),
607607 'tiny_vit_21m_224.dist_in22k' : _cfg (
608- url = 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22k_distill.pth' ,
608+ hf_hub_id = 'timm/' ,
609+ # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22k_distill.pth',
609610 num_classes = 21841
610611 ),
611612 'tiny_vit_21m_224.dist_in22k_ft_in1k' : _cfg (
612- url = 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_distill.pth'
613+ hf_hub_id = 'timm/' ,
614+ # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_distill.pth'
613615 ),
614616 'tiny_vit_21m_224.in1k' : _cfg (
615- url = 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_1k.pth'
617+ hf_hub_id = 'timm/' ,
618+ #url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_1k.pth'
616619 ),
617620 'tiny_vit_21m_384.dist_in22k_ft_in1k' : _cfg (
618- url = 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_384_distill.pth' ,
621+ hf_hub_id = 'timm/' ,
622+ # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_384_distill.pth',
619623 input_size = (3 , 384 , 384 ), pool_size = (12 , 12 ), crop_pct = 1.0 ,
620624 ),
621625 'tiny_vit_21m_512.dist_in22k_ft_in1k' : _cfg (
622- url = 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_512_distill.pth' ,
626+ hf_hub_id = 'timm/' ,
627+ # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_512_distill.pth',
623628 input_size = (3 , 512 , 512 ), pool_size = (16 , 16 ), crop_pct = 1.0 , crop_mode = 'squash' ,
624629 ),
625630})
0 commit comments