Commit 77a91b1 (parent db8670d), committed Jan 29, 2021

feat: support multi-teacher kd

Summary: support multi-teacher knowledge distillation, covering both logit distillation and overhaul (feature) distillation.
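In essence, the change generalizes `Distiller` from a single frozen teacher to a list of frozen teachers and averages the logit-distillation loss over them. A minimal sketch of that reduction (the JS-divergence criterion is passed in as `jsdiv_loss`, since its definition sits outside this diff):

```python
# Sketch of the multi-teacher reduction this commit introduces: the student's
# KD loss is the mean of a divergence against each teacher's logits.
def multi_teacher_kd_loss(s_logits, t_logits_list, jsdiv_loss):
    loss = 0.
    for t_logits in t_logits_list:
        # Teachers are frozen, so their logits carry no gradient.
        loss += jsdiv_loss(s_logits, t_logits.detach())
    return loss / len(t_logits_list)
```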

File tree: 7 files changed (+118 -55 lines)

‎.gitignore (+35 -6)

```diff
@@ -1,8 +1,37 @@
-.idea
+
+logs
+
+# compilation and distribution
 __pycache__
-.DS_Store
-.vscode
+_ext
+*.pyc
+*.pyd
 *.so
-logs/
-.ipynb_checkpoints
-logs
+*.dll
+*.egg-info/
+build/
+dist/
+wheels/
+
+# pytorch/python/numpy formats
+*.pth
+*.pkl
+*.npy
+*.ts
+model_ts*.txt
+
+# ipython/jupyter notebooks
+*.ipynb
+**/.ipynb_checkpoints/
+
+# Editor temporaries
+*.swn
+*.swo
+*.swp
+*~
+
+# editor settings
+.idea
+.vscode
+_darcs
+.DS_Store
```

‎CHANGELOG.md (+12 -0)

```diff
@@ -0,0 +1,12 @@
+# Changelog
+
+### v1.1 (29/01/2021)
+
+#### New Features
+
+- NAIC20 (reid track) [1st-place solution](https://github.com/JDAI-CV/fast-reid/tree/master/projects/NAIC20)
+- Multi-teacher Knowledge Distillation
+
+#### Bug Fixes
+
+#### Improvements
```

‎README.md (+17 -13)

````diff
@@ -4,48 +4,52 @@ FastReID is a research platform that implements state-of-the-art re-identificati

 ## What's New

-- [Jan 2021] NAIC20 (reid track) [1st-place solution](https://github.com/JDAI-CV/fast-reid/tree/master/projects/NAIC20) based on fastreid has been released!
+- [Jan 2021] NAIC20 (reid track) [1st-place solution](projects/NAIC20) based on fastreid has been released!
 - [Jan 2021] FastReID V1.0 has been released!🎉
   Support many tasks beyond reid, such as image retrieval and face recognition. See [release notes](https://github.com/JDAI-CV/fast-reid/releases/tag/v1.0.0).
-- [Oct 2020] Added the [Hyper-Parameter Optimization](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastTune) based on fastreid. See `projects/FastTune`.
-- [Sep 2020] Added the [person attribute recognition](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastAttr) based on fastreid. See `projects/FastAttr`.
+- [Oct 2020] Added the [Hyper-Parameter Optimization](projects/FastTune) based on fastreid. See `projects/FastTune`.
+- [Sep 2020] Added the [person attribute recognition](projects/FastAttr) based on fastreid. See `projects/FastAttr`.
 - [Sep 2020] Automatic Mixed Precision training is supported with `apex`. Set `cfg.SOLVER.FP16_ENABLED=True` to switch it on.
-- [Aug 2020] [Model Distillation](https://github.com/JDAI-CV/fast-reid/tree/master/projects/FastDistill) is supported, thanks to [guan'an wang](https://github.com/wangguanan)'s contribution.
+- [Aug 2020] [Model Distillation](projects/FastDistill) is supported, thanks to [guan'an wang](https://github.com/wangguanan)'s contribution.
 - [Aug 2020] ONNX/TensorRT converter is supported.
 - [Jul 2020] Distributed training with multiple GPUs; it trains much faster.
 - Includes more features such as circle loss, abundant visualization methods and evaluation metrics, SoTA results on conventional, cross-domain, partial and vehicle re-id, testing on multi-datasets simultaneously, etc.
-- Can be used as a library to support [different projects](https://github.com/JDAI-CV/fast-reid/tree/master/projects) on top of it. We'll open source more research projects in this way.
+- Can be used as a library to support [different projects](projects) on top of it. We'll open source more research projects in this way.
 - Removed the [ignite](https://github.com/pytorch/ignite) (a high-level library) dependency; now powered by [PyTorch](https://pytorch.org/).

 We have written a [Chinese blog](https://l1aoxingyu.github.io/blogpages/reid/2020/05/29/fastreid.html) about this toolbox.

+## Changelog
+
+Please refer to [changelog.md](CHANGELOG.md) for details and release history.
+
 ## Installation

-See [INSTALL.md](https://github.com/JDAI-CV/fast-reid/blob/master/INSTALL.md).
+See [INSTALL.md](INSTALL.md).

 ## Quick Start

 The designed architecture follows this guide [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template), so you can check each folder's purpose yourself.

-See [GETTING_STARTED.md](https://github.com/JDAI-CV/fast-reid/blob/master/GETTING_STARTED.md).
+See [GETTING_STARTED.md](GETTING_STARTED.md).

-Learn more at our [documentation](https://fast-reid.readthedocs.io/). And see [projects/](https://github.com/JDAI-CV/fast-reid/tree/master/projects) for some projects that are built on top of fastreid.
+Learn more at our [documentation](https://fast-reid.readthedocs.io/). And see [projects/](projects) for some projects that are built on top of fastreid.

 ## Model Zoo and Baselines

-We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](https://github.com/JDAI-CV/fast-reid/blob/master/MODEL_ZOO.md).
+We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](MODEL_ZOO.md).

 ## Deployment

-We provide some examples and scripts to convert fastreid models to Caffe, ONNX and TensorRT formats in [Fastreid deploy](https://github.com/JDAI-CV/fast-reid/blob/master/tools/deploy).
+We provide some examples and scripts to convert fastreid models to Caffe, ONNX and TensorRT formats in [Fastreid deploy](tools/deploy).

 ## License

-Fastreid is released under the [Apache 2.0 license](https://github.com/JDAI-CV/fast-reid/blob/master/LICENSE).
+Fastreid is released under the [Apache 2.0 license](LICENSE).

-## Citing Fastreid
+## Citing FastReID

-If you use Fastreid in your research or wish to refer to the baseline results published in the Model Zoo, please use the following BibTeX entry.
+If you use FastReID in your research or wish to refer to the baseline results published in the Model Zoo, please use the following BibTeX entry.

 ```BibTeX
 @article{he2020fastreid,
````

‎fastreid/config/defaults.py (+2 -2)

```diff
@@ -128,8 +128,8 @@
 # -----------------------------------------------------------------------------

 _C.KD = CN()
-_C.KD.MODEL_CONFIG = ""
-_C.KD.MODEL_WEIGHTS = ""
+_C.KD.MODEL_CONFIG = ['',]
+_C.KD.MODEL_WEIGHTS = ['',]

 # -----------------------------------------------------------------------------
 # INPUT
```
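With this change, `KD.MODEL_CONFIG` and `KD.MODEL_WEIGHTS` become parallel lists: the weights at index `i` are loaded into the teacher built from the config at index `i`. A hedged example of setting two teachers programmatically (the paths are placeholders, not files from the repo):

```python
from fastreid.config import get_cfg

cfg = get_cfg()
# Parallel lists: MODEL_CONFIG[i] pairs with MODEL_WEIGHTS[i].
# Paths below are illustrative only.
cfg.KD.MODEL_CONFIG = ["logs/teacher_a/config.yaml", "logs/teacher_b/config.yaml"]
cfg.KD.MODEL_WEIGHTS = ["logs/teacher_a/model_best.pth", "logs/teacher_b/model_best.pth"]
```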

‎fastreid/modeling/meta_arch/distiller.py (+24 -14)

```diff
@@ -22,22 +22,25 @@ def __init__(self, cfg):
         super(Distiller, self).__init__(cfg)

         # Get teacher model config
-        cfg_t = get_cfg()
-        cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG)
+        model_ts = []
+        for i in range(len(cfg.KD.MODEL_CONFIG)):
+            cfg_t = get_cfg()
+            cfg_t.merge_from_file(cfg.KD.MODEL_CONFIG[i])

-        model_t = build_model(cfg_t)
-        logger.info("Teacher model:\n{}".format(model_t))
+            model_t = build_model(cfg_t)

-        # No gradients for teacher model
-        for param in model_t.parameters():
-            param.requires_grad_(False)
+            # No gradients for the teacher model
+            for param in model_t.parameters():
+                param.requires_grad_(False)

-        logger.info("Loading teacher model weights ...")
-        Checkpointer(model_t).load(cfg.KD.MODEL_WEIGHTS)
+            logger.info("Loading teacher model weights ...")
+            Checkpointer(model_t).load(cfg.KD.MODEL_WEIGHTS[i])
+
+            model_ts.append(model_t)

         # The teacher models are not registered as `nn.Module`s; this
         # makes sure the teacher weights are not saved with the student
-        self.model_t = [model_t.backbone, model_t.heads]
+        self.model_ts = model_ts

     def forward(self, batched_inputs):
         if self.training:
@@ -51,10 +54,13 @@ def forward(self, batched_inputs):

             s_outputs = self.heads(s_feat, targets)

+            t_outputs = []
             # teacher model forward
             with torch.no_grad():
-                t_feat = self.model_t[0](images)
-                t_outputs = self.model_t[1](t_feat, targets)
+                for model_t in self.model_ts:
+                    t_feat = model_t.backbone(images)
+                    t_output = model_t.heads(t_feat, targets)
+                    t_outputs.append(t_output)

             losses = self.losses(s_outputs, t_outputs, targets)
             return losses
@@ -71,8 +77,12 @@ def losses(self, s_outputs, t_outputs, gt_labels):
         loss_dict = super(Distiller, self).losses(s_outputs, gt_labels)

         s_logits = s_outputs["pred_class_logits"]
-        t_logits = t_outputs["pred_class_logits"].detach()
-        loss_dict["loss_jsdiv"] = self.jsdiv_loss(s_logits, t_logits)
+        loss_jsdiv = 0.
+        for t_output in t_outputs:
+            t_logits = t_output["pred_class_logits"].detach()
+            loss_jsdiv += self.jsdiv_loss(s_logits, t_logits)
+
+        loss_dict["loss_jsdiv"] = loss_jsdiv / len(t_outputs)

         return loss_dict

```
‎projects/FastDistill/configs/kd-sbs_r101ibn-sbs_r34.yml (+2 -2)

```diff
@@ -8,8 +8,8 @@ MODEL:
     WITH_IBN: False

 KD:
-  MODEL_CONFIG: projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml
-  MODEL_WEIGHTS: projects/FastDistill/logs/dukemtmc/r101_ibn/model_best.pth
+  MODEL_CONFIG: ("projects/FastDistill/logs/dukemtmc/r101_ibn/config.yaml",)
+  MODEL_WEIGHTS: ("projects/FastDistill/logs/dukemtmc/r101_ibn/model_best.pth",)

 DATASETS:
   NAMES: ("DukeMTMC",)
```

‎projects/FastDistill/fastdistill/overhaul.py (+26 -18)

```diff
@@ -61,16 +61,18 @@ def __init__(self, cfg):
         super().__init__(cfg)

         s_channels = self.backbone.get_channel_nums()
-        t_channels = self.model_t[0].get_channel_nums()

-        self.connectors = nn.ModuleList(
-            [build_feature_connector(t, s) for t, s in zip(t_channels, s_channels)])
+        for i in range(len(self.model_ts)):
+            t_channels = self.model_ts[i].backbone.get_channel_nums()

-        teacher_bns = self.model_t[0].get_bn_before_relu()
-        margins = [get_margin_from_BN(bn) for bn in teacher_bns]
-        for i, margin in enumerate(margins):
-            self.register_buffer("margin%d" % (i + 1),
-                                 margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())
+            setattr(self, "connectors_{}".format(i), nn.ModuleList(
+                [build_feature_connector(t, s) for t, s in zip(t_channels, s_channels)]))
+
+            teacher_bns = self.model_ts[i].backbone.get_bn_before_relu()
+            margins = [get_margin_from_BN(bn) for bn in teacher_bns]
+            for j, margin in enumerate(margins):
+                self.register_buffer("margin{}_{}".format(i, j + 1),
+                                     margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())

     def forward(self, batched_inputs):
         if self.training:
@@ -84,20 +86,25 @@ def forward(self, batched_inputs):

             s_outputs = self.heads(s_feat, targets)

+            t_feats_list = []
+            t_outputs = []
             # teacher model forward
             with torch.no_grad():
-                t_feats, t_feat = self.model_t[0].extract_feature(images, preReLU=True)
-                t_outputs = self.model_t[1](t_feat, targets)
+                for model_t in self.model_ts:
+                    t_feats, t_feat = model_t.backbone.extract_feature(images, preReLU=True)
+                    t_output = model_t.heads(t_feat, targets)
+                    t_feats_list.append(t_feats)
+                    t_outputs.append(t_output)

-            losses = self.losses(s_outputs, s_feats, t_outputs, t_feats, targets)
+            losses = self.losses(s_outputs, s_feats, t_outputs, t_feats_list, targets)
             return losses

         else:
             outputs = super(DistillerOverhaul, self).forward(batched_inputs)
             return outputs

-    def losses(self, s_outputs, s_feats, t_outputs, t_feats, gt_labels):
-        r"""
+    def losses(self, s_outputs, s_feats, t_outputs, t_feats_list, gt_labels):
+        """
         Compute losses from the model's outputs; the loss function's input
         arguments must match the outputs of the model's forward pass.
         """
@@ -106,11 +113,12 @@ def losses(self, s_outputs, s_feats, t_outputs, t_feats_list, gt_labels):
         # Overhaul distillation loss
         feat_num = len(s_feats)
         loss_distill = 0
-        for i in range(feat_num):
-            s_feats[i] = self.connectors[i](s_feats[i])
-            loss_distill += distillation_loss(s_feats[i], t_feats[i].detach(), getattr(
-                self, "margin%d" % (i + 1)).to(s_feats[i].dtype)) / 2 ** (feat_num - i - 1)
+        for i in range(len(t_feats_list)):
+            for j in range(feat_num):
+                s_feats_connect = getattr(self, "connectors_{}".format(i))[j](s_feats[j])
+                loss_distill += distillation_loss(s_feats_connect, t_feats_list[i][j].detach(), getattr(
+                    self, "margin{}_{}".format(i, j + 1)).to(s_feats_connect.dtype)) / 2 ** (feat_num - j - 1)

-        loss_dict["loss_overhaul"] = loss_distill / len(gt_labels) / 10000
+        loss_dict["loss_overhaul"] = loss_distill / len(t_feats_list) / len(gt_labels) / 10000

         return loss_dict
```
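Both `distillation_loss` and `get_margin_from_BN` come from the overhaul-of-feature-distillation method (Heo et al., ICCV 2019) and sit outside this diff. A sketch in the spirit of that paper's reference implementation, where `margin` is the per-channel buffer registered in `__init__` above:

```python
import torch
import torch.nn.functional as F

def distillation_loss(source, target, margin):
    # Margin ReLU: teacher responses below the margin are lifted to it, so the
    # student is not penalized for failing to match inactive teacher units.
    target = torch.max(target, margin)
    loss = F.mse_loss(source, target, reduction="none")
    # Only positions where the student overshoots the target, or where the
    # teacher is active, contribute to the partial-L2 distance.
    loss = loss * ((source > target) | (target > 0)).float()
    return loss.sum()
```

With multiple teachers, the double loop in `losses` sums this partial-L2 distance per teacher and per feature stage, then divides by the number of teachers so the loss scale is independent of the teacher count.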
