Skip to content

Commit 99aa8f5

Browse files
author
greatlog
authored
feat(keypoints): Keypoints support 1.2 (#60)
1 parent 7036b00 commit 99aa8f5

File tree

14 files changed

+311
-539
lines changed

14 files changed

+311
-539
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ wheels/
3131

3232
# vscode editor settings
3333
.vscode
34+

hubconf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
from official.vision.detection.tools.utils import DetEvaluator
4848
from official.vision.keypoints.inference import KeypointEvaluator
4949
from official.vision.keypoints.models import (
50-
mspn_4stage,
5150
simplebaseline_res50,
5251
simplebaseline_res101,
5352
simplebaseline_res152,

official/vision/keypoints/README.md

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
# Human Pose Esimation
22

3-
本目录包含了采用MegEngine实现的经典[SimpleBaseline](https://arxiv.org/pdf/1804.06208.pdf)[MSPN](https://arxiv.org/pdf/1901.00148.pdf)网络结构,同时提供了在COCO数据集上的完整训练和测试代码。
3+
本目录包含了采用MegEngine实现的经典[SimpleBaseline](https://arxiv.org/pdf/1804.06208.pdf)的网络结构,同时提供了在COCO数据集上的完整训练和测试代码。
44

55
本目录使用了在COCO val2017上的Human AP为56.4的人体检测结果,最后在COCO val2017上人体关节点估计结果为
66
|Methods|Backbone|Input Size| AP | Ap .5 | AP .75 | AP (M) | AP (L) | AR | AR .5 | AR .75 | AR (M) | AR (L) |
77
|---|:---:|---|---|---|---|---|---|---|---|---|---|---|
8-
| SimpleBaseline |Res50 |256x192| 0.712 | 0.887 | 0.779 | 0.673 | 0.785 | 0.782 | 0.932 | 0.839 | 0.730 | 0.854 |
9-
| SimpleBaseline |Res101|256x192| 0.722 | 0.891 | 0.795 | 0.687 | 0.795 | 0.794 | 0.936 | 0.855 | 0.745 | 0.863 |
10-
| SimpleBaseline |Res152|256x192| 0.724 | 0.888 | 0.794 | 0.688 | 0.795 | 0.795 | 0.934 | 0.856 | 0.746 | 0.863 |
11-
| MSPN_4stage |MSPN|256x192| 0.752 | 0.900 | 0.819 | 0.716 | 0.825 | 0.819 | 0.943 | 0.875 | 0.770 | 0.887 |
8+
| SimpleBaseline |Res50 |256x192| 0.711 | 0.885 | 0.779 | 0.674 | 0.783 | 0.782 | 0.930 | 0.839 | 0.731 | 0.852 |
9+
| SimpleBaseline |Res101|256x192| 0.718 | 0.892 | 0.788 | 0.681 | 0.793 | 0.790 | 0.937 | 0.848 | 0.739 | 0.861 |
10+
| SimpleBaseline |Res152|256x192| 0.723 | 0.888 | 0.794 | 0.688 | 0.795 | 0.795 | 0.934 | 0.856 | 0.746 | 0.863 |
1211

1312
## 安装和环境配置
1413

@@ -65,16 +64,7 @@ ${COCO_DATA_ROOT}
6564
python3 train.py --arch simplebaseline_res50 \
6665
--resume /path/to/model \
6766
--ngpus 8 \
68-
--multi_scale_supervision False
6967

70-
```
71-
训练MSPN:
72-
```bash
73-
python3 train.py --arch mspn_4stage \
74-
--resume /path/to/model \
75-
--ngpus 8 \
76-
--multi_scale_supervision True
77-
7868
```
7969

8070
## 如何测试
@@ -83,12 +73,10 @@ python3 train.py --arch mspn_4stage \
8373
```bash
8474
python3 test.py --arch name/of/network \
8575
--model /path/to/model.pkl \
86-
--dt_file /name/human/detection/results
8776
```
8877
`test.py`的命令行参数如下:
8978
- `--arch`, 网络的名字;
90-
- `--model`, 待检测的模;
91-
- `--dt_path`,人体检测结果.
79+
- `--model`, 待检测的模型;
9280

9381
也可以连续验证多个模型的性能:
9482

@@ -119,5 +107,4 @@ python3 inference.py --arch /name/of/tested/network \
119107

120108
## 参考文献
121109

122-
- [Simple Baselines for Human Pose Estimation and Tracking](https://arxiv.org/pdf/1804.06208.pdf) Bin Xiao, Haiping Wu, and Yichen Wei
123-
- [Rethinking on Multi-Stage Networks for Human Pose Estimation](https://arxiv.org/pdf/1901.00148.pdf) Wenbo Li1, Zhicheng Wang, Binyi Yin, Qixiang Peng, Yuming Du, Tianzi Xiao, Gang Yu, Hongtao Lu, Yichen Wei and Jian Sun
110+
- [Simple Baselines for Human Pose Estimation and Tracking](https://arxiv.org/abs/1804.06208) Bin Xiao, Haiping Wu, and Yichen Wei. European Conference on Computer Vision (ECCV), 2018.

official/vision/keypoints/config.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,47 @@
66
# Unless required by applicable law or agreed to in writing,
77
# software distributed under the License is distributed on an
88
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
10+
911
class Config:
10-
##############3## train ##############################################
12+
# model
13+
model_choices = [
14+
"simplebaseline_res50",
15+
"simplebaseline_res101",
16+
"simplebaseline_res152",
17+
]
18+
19+
# train
1120
initial_lr = 3e-4
1221
lr_ratio = 0.1
1322

1423
batch_size = 32
1524
epochs = 200
16-
warm_epochs = 1
17-
weight_decay = 1e-5
25+
warm_epochs = 0
26+
weight_decay = 0
27+
28+
report_freq = 10
29+
save_freq = 1
1830

19-
################## data ###############################################
31+
# data
2032
# path
2133
data_root = "/data/coco_data/"
2234

2335
# normalize
24-
img_mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
25-
img_std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
36+
img_mean = [103.530, 116.280, 123.675]
37+
img_std = [57.375, 57.120, 58.395]
2638

2739
# shape
2840
input_shape = (256, 192)
2941
output_shape = (64, 48)
3042

3143
# heat maps
3244
keypoint_num = 17
33-
heat_kernel = [2.6, 2.0, 1.7, 1.4]
45+
heat_kernels = [k * 4 for k in [2.6, 2.0, 1.7, 1.4]]
3446
heat_thr = 1e-2
3547
heat_range = 255
3648

37-
##################### augumentation #####################################
38-
49+
# augmentation
3950
half_body_transform = True
4051
extend_boxes = True
4152

@@ -53,19 +64,21 @@ class Config:
5364

5465
# scale
5566
scale_prob = 1
56-
scale_range = [0.7, 1.3]
67+
scale_range = 0.3
5768

5869
# rorate
5970
rotation_prob = 0.6
60-
rotate_range = [-45, 45]
71+
rotate_range = 40
6172

62-
############## testing settings ##########################################
73+
# test settings
6374
test_aug_border = 10
6475
test_x_ext = 0.10
6576
test_y_ext = 0.10
6677
test_gaussian_kernel = 17
6778
second_value_aug = True
6879

80+
# inference settings
81+
nms_thr = 0.7
6982
vis_colors = [
7083
[255, 0, 0],
7184
[255, 85, 0],
@@ -95,6 +108,8 @@ class Config:
95108
[0, 2],
96109
[1, 3],
97110
[2, 4],
111+
[3, 5],
112+
[4, 6],
98113
[5, 6],
99114
[5, 7],
100115
[7, 9],

official/vision/keypoints/dataset.py

Lines changed: 23 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,21 @@
66
# Unless required by applicable law or agreed to in writing,
77
# software distributed under the License is distributed on an
88
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9-
import megengine as mge
10-
from megengine.data.dataset.vision.meta_vision import VisionDataset
11-
from megengine.data import Collator
9+
import json
10+
import os.path as osp
11+
from collections import OrderedDict, defaultdict
1212

13-
import numpy as np
1413
import cv2
15-
import os.path as osp
16-
import json
17-
from collections import defaultdict, OrderedDict
14+
import numpy as np
15+
16+
from megengine.data import Collator
17+
from megengine.data.dataset.vision.meta_vision import VisionDataset
1818

1919

2020
class COCOJoints(VisionDataset):
2121
"""
2222
we cannot use the official implementation of COCO dataset here.
23-
The output of __getitem__ function here should be a single person instead of a single image.
23+
The output of __getitem__ function here should be a single person instead of a single image.
2424
"""
2525

2626
supported_order = ("image", "keypoints", "boxes", "info")
@@ -47,7 +47,7 @@ class COCOJoints(VisionDataset):
4747

4848
min_bbox_h = 0
4949
min_bbox_w = 0
50-
min_box_area = 1500
50+
min_bbox_area = 1500
5151
min_bbox_score = 1e-10
5252

5353
def __init__(
@@ -87,8 +87,6 @@ def __init__(
8787

8888
selected_anns = []
8989
for ann in dataset["annotations"]:
90-
if "image_id" in ann.keys() and ann["image_id"] not in self.ids:
91-
continue
9290

9391
if "iscrowd" in ann.keys() and ann["iscrowd"]:
9492
continue
@@ -129,8 +127,8 @@ def __getitem__(self, index):
129127
img_id = ann["image_id"]
130128
target = []
131129
for k in self.order:
132-
if k == "image":
133130

131+
if k == "image":
134132
file_name = self.imgs[img_id]["file_name"]
135133
img_path = osp.join(self.root, self.image_set, file_name)
136134
image = cv2.imread(img_path, cv2.IMREAD_COLOR)
@@ -186,13 +184,9 @@ def __init__(
186184

187185
self.stride = image_shape[1] // heatmap_shape[1]
188186

189-
x = np.arange(0, heatmap_shape[1], 1)
190-
y = np.arange(0, heatmap_shape[0], 1)
191-
192-
grid_x, grid_y = np.meshgrid(x, y)
193-
194-
self.grid_x = grid_x[None].repeat(keypoint_num, 0)
195-
self.grid_y = grid_y[None].repeat(keypoint_num, 0)
187+
ax = (np.arange(0, heatmap_shape[1]) + 0.5) * self.stride - 0.5
188+
ay = (np.arange(0, heatmap_shape[0]) + 0.5) * self.stride - 0.5
189+
self.grid_x, self.grid_y = np.meshgrid(ax, ay)
196190

197191
def apply(self, inputs):
198192
"""
@@ -204,27 +198,21 @@ def apply(self, inputs):
204198

205199
batch_data["data"].append(image)
206200

207-
joints = (keypoints[0, :, :2] + 0.5) / self.stride - 0.5
208-
heat_valid = np.array(keypoints[0, :, -1]).astype(np.float32)
209-
dis = (self.grid_x - joints[:, 0, np.newaxis, np.newaxis]) ** 2 + (
210-
self.grid_y - joints[:, 1, np.newaxis, np.newaxis]
201+
joint = keypoints[0, :, :2]
202+
dis = (self.grid_x[None] - joint[:, 0, None, None]) ** 2 + (
203+
self.grid_y[None] - joint[:, 1, None, None]
211204
) ** 2
205+
heat_valid = np.array(keypoints[0, :, -1]).astype(np.float32)
206+
212207
heatmaps = []
213208
for k in self.heat_kernel:
209+
214210
heatmap = np.exp(-dis / 2 / k ** 2)
211+
heatmap[heat_valid < 0.1] = 0
215212
heatmap[heatmap < self.heat_thr] = 0
216-
heatmap[heat_valid == 0] = 0
217-
sum_for_norm = heatmap.sum((1, 2))
218-
heatmap[sum_for_norm > 0] = (
219-
heatmap[sum_for_norm > 0]
220-
/ sum_for_norm[sum_for_norm > 0][:, None, None]
221-
)
222-
maxi = np.max(heatmap, (1, 2))
223-
heatmap[maxi > 1e-5] = (
224-
heatmap[maxi > 1e-5]
225-
/ maxi[:, None, None][maxi > 1e-5]
226-
* self.heat_range
227-
)
213+
214+
heatmap *= self.heat_range
215+
228216
heatmaps.append(heatmap)
229217

230218
batch_data["heatmap"].append(np.array(heatmaps))

0 commit comments

Comments
 (0)