Skip to content

Commit 4182170

Browse files
committed
Add frozen_stages kwarg to InternImage backbone (#283)
1 parent fe6cdd2 commit 4182170

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ segmentation/convertor/
66
checkpoint_dir/
77
demo/
88
pretrained/
9-
upload.py
9+
upload.py

detection/mmdet_custom/models/backbones/intern_image.py

+16
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,7 @@ def __init__(self,
575575
center_feature_scale=False, # for InternImage-H/G
576576
use_dcn_v4_op=False,
577577
out_indices=(0, 1, 2, 3),
578+
frozen_stages=-1,
578579
init_cfg=None,
579580
**kwargs):
580581
super().__init__()
@@ -588,6 +589,8 @@ def __init__(self,
588589
self.init_cfg = init_cfg
589590
self.out_indices = out_indices
590591
self.level2_post_norm_block_ids = level2_post_norm_block_ids
592+
self.frozen_stages = frozen_stages
593+
591594
logger = get_root_logger()
592595
logger.info(f'using core type: {core_op}')
593596
logger.info(f'using activation layer: {act_layer}')
@@ -642,6 +645,19 @@ def __init__(self,
642645
self.num_layers = len(depths)
643646
self.apply(self._init_weights)
644647
self.apply(self._init_deform_weights)
648+
self._freeze_stages()
649+
650+
def train(self, mode=True):
    """Set the module's training mode, then re-apply stage freezing.

    `nn.Module.train(True)` flips every submodule back into training
    mode, which would undo the `.eval()` calls made on frozen stages;
    re-running `_freeze_stages()` afterwards restores them so frozen
    levels keep their normalization layers in inference behaviour.

    Args:
        mode (bool): ``True`` for training mode, ``False`` for eval mode.
    """
    # Delegate to the parent implementation first.
    super(InternImage, self).train(mode)
    # Then re-freeze, since the call above may have un-frozen stages.
    self._freeze_stages()
654+
655+
def _freeze_stages(self):
656+
if self.frozen_stages >= 0:
657+
for level in self.levels[:self.frozen_stages]:
658+
level.eval()
659+
for param in level.parameters():
660+
param.requires_grad = False
645661

646662
def init_weights(self):
647663
logger = get_root_logger()

segmentation/mmseg_custom/models/backbones/intern_image.py

+16
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,7 @@ def __init__(self,
575575
center_feature_scale=False, # for InternImage-H/G
576576
use_dcn_v4_op=False,
577577
out_indices=(0, 1, 2, 3),
578+
frozen_stages=-1,
578579
init_cfg=None,
579580
**kwargs):
580581
super().__init__()
@@ -588,6 +589,8 @@ def __init__(self,
588589
self.init_cfg = init_cfg
589590
self.out_indices = out_indices
590591
self.level2_post_norm_block_ids = level2_post_norm_block_ids
592+
self.frozen_stages = frozen_stages
593+
591594
logger = get_root_logger()
592595
logger.info(f'using core type: {core_op}')
593596
logger.info(f'using activation layer: {act_layer}')
@@ -642,6 +645,19 @@ def __init__(self,
642645
self.num_layers = len(depths)
643646
self.apply(self._init_weights)
644647
self.apply(self._init_deform_weights)
648+
self._freeze_stages()
649+
650+
def train(self, mode=True):
    """Set the module's training mode, then re-apply stage freezing.

    `nn.Module.train(True)` flips every submodule back into training
    mode, which would undo the `.eval()` calls made on frozen stages;
    re-running `_freeze_stages()` afterwards restores them so frozen
    levels keep their normalization layers in inference behaviour.

    Args:
        mode (bool): ``True`` for training mode, ``False`` for eval mode.
    """
    # Delegate to the parent implementation first.
    super(InternImage, self).train(mode)
    # Then re-freeze, since the call above may have un-frozen stages.
    self._freeze_stages()
654+
655+
def _freeze_stages(self):
656+
if self.frozen_stages >= 0:
657+
for level in self.levels[:self.frozen_stages]:
658+
level.eval()
659+
for param in level.parameters():
660+
param.requires_grad = False
645661

646662
def init_weights(self):
647663
logger = get_root_logger()

0 commit comments

Comments
 (0)