@@ -55,6 +55,9 @@ def __init__(
55
55
need_rescale_bboxes = True ,
56
56
per_category_metrics = False ,
57
57
remove_invalid_boxes = False ,
58
+ include_keypoint = False ,
59
+ need_rescale_keypoints = False ,
60
+ kpt_oks_sigmas = None
58
61
):
59
62
"""Constructs COCO evaluation class.
60
63
@@ -80,6 +83,13 @@ def __init__(
80
83
per_category_metrics: Whether to return per category metrics.
81
84
remove_invalid_boxes: A boolean indicating whether to remove invalid box
82
85
during evaluation.
86
+ include_keypoint: a boolean to indicate whether or not to include the
87
+ keypoint eval.
88
+ need_rescale_keypoints: If true keypoints in `predictions` will be
89
+ rescaled back to absolute values (`image_info` is needed in this case).
90
+ kpt_oks_sigmas: The sigmas used to calculate keypoint OKS. See
91
+ http://cocodataset.org/#keypoints-eval. When None, it will use the
92
+ defaults in COCO.
83
93
"""
84
94
if annotation_file :
85
95
if annotation_file .startswith ('gs://' ):
@@ -105,7 +115,8 @@ def __init__(
105
115
'detection_boxes'
106
116
]
107
117
self ._need_rescale_bboxes = need_rescale_bboxes
108
- if self ._need_rescale_bboxes :
118
+ self ._need_rescale_keypoints = need_rescale_keypoints
119
+ if self ._need_rescale_bboxes or self ._need_rescale_keypoints :
109
120
self ._required_prediction_fields .append ('image_info' )
110
121
self ._required_groundtruth_fields = [
111
122
'source_id' , 'height' , 'width' , 'classes' , 'boxes'
@@ -116,6 +127,18 @@ def __init__(
116
127
self ._required_prediction_fields .extend (['detection_masks' ])
117
128
self ._required_groundtruth_fields .extend (['masks' ])
118
129
self .remove_invalid_boxes = remove_invalid_boxes
130
+ self ._include_keypoint = include_keypoint
131
+ self ._kpt_oks_sigmas = kpt_oks_sigmas
132
+ if self ._include_keypoint :
133
+ keypoint_metric_names = [
134
+ 'AP' , 'AP50' , 'AP75' , 'APm' , 'APl' , 'ARmax1' , 'ARmax10' , 'ARmax100' ,
135
+ 'ARm' , 'ARl'
136
+ ]
137
+ keypoint_metric_names = ['keypoint_' + x for x in keypoint_metric_names ]
138
+ self ._metric_names .extend (keypoint_metric_names )
139
+ self ._required_prediction_fields .extend (['detection_keypoints' ])
140
+ self ._required_groundtruth_fields .extend (['keypoints' ])
141
+
119
142
self .reset ()
120
143
121
144
def reset (self ):
@@ -168,7 +191,7 @@ def evaluate(self):
168
191
coco_eval .evaluate ()
169
192
coco_eval .accumulate ()
170
193
coco_eval .summarize ()
171
- coco_metrics = coco_eval .stats
194
+ metrics = coco_eval .stats
172
195
173
196
if self ._include_mask :
174
197
mcoco_eval = cocoeval .COCOeval (coco_gt , coco_dt , iouType = 'segm' )
@@ -177,11 +200,17 @@ def evaluate(self):
177
200
mcoco_eval .accumulate ()
178
201
mcoco_eval .summarize ()
179
202
mask_coco_metrics = mcoco_eval .stats
180
-
181
- if self ._include_mask :
182
- metrics = np .hstack ((coco_metrics , mask_coco_metrics ))
183
- else :
184
- metrics = coco_metrics
203
+ metrics = np .hstack ((metrics , mask_coco_metrics ))
204
+
205
+ if self ._include_keypoint :
206
+ kcoco_eval = cocoeval .COCOeval (coco_gt , coco_dt , iouType = 'keypoints' ,
207
+ kpt_oks_sigmas = self ._kpt_oks_sigmas )
208
+ kcoco_eval .params .imgIds = image_ids
209
+ kcoco_eval .evaluate ()
210
+ kcoco_eval .accumulate ()
211
+ kcoco_eval .summarize ()
212
+ keypoint_coco_metrics = kcoco_eval .stats
213
+ metrics = np .hstack ((metrics , keypoint_coco_metrics ))
185
214
186
215
# Cleans up the internal variables in order for a fresh eval next time.
187
216
self .reset ()
@@ -192,46 +221,64 @@ def evaluate(self):
192
221
193
222
# Adds metrics per category.
194
223
if self ._per_category_metrics and hasattr (coco_eval , 'category_stats' ):
195
- for category_index , category_id in enumerate (coco_eval .params .catIds ):
196
- metrics_dict ['Precision mAP ByCategory/{}' .format (
197
- category_id )] = coco_eval .category_stats [0 ][category_index ].astype (
198
- np .float32 )
199
- metrics_dict ['Precision mAP ByCategory@50IoU/{}' .format (
200
- category_id )] = coco_eval .category_stats [1 ][category_index ].astype (
201
- np .float32 )
202
- metrics_dict ['Precision mAP ByCategory@75IoU/{}' .format (
203
- category_id )] = coco_eval .category_stats [2 ][category_index ].astype (
204
- np .float32 )
205
- metrics_dict ['Precision mAP ByCategory (small) /{}' .format (
206
- category_id )] = coco_eval .category_stats [3 ][category_index ].astype (
207
- np .float32 )
208
- metrics_dict ['Precision mAP ByCategory (medium) /{}' .format (
209
- category_id )] = coco_eval .category_stats [4 ][category_index ].astype (
210
- np .float32 )
211
- metrics_dict ['Precision mAP ByCategory (large) /{}' .format (
212
- category_id )] = coco_eval .category_stats [5 ][category_index ].astype (
213
- np .float32 )
214
- metrics_dict ['Recall AR@1 ByCategory/{}' .format (
215
- category_id )] = coco_eval .category_stats [6 ][category_index ].astype (
216
- np .float32 )
217
- metrics_dict ['Recall AR@10 ByCategory/{}' .format (
218
- category_id )] = coco_eval .category_stats [7 ][category_index ].astype (
219
- np .float32 )
220
- metrics_dict ['Recall AR@100 ByCategory/{}' .format (
221
- category_id )] = coco_eval .category_stats [8 ][category_index ].astype (
222
- np .float32 )
223
- metrics_dict ['Recall AR (small) ByCategory/{}' .format (
224
- category_id )] = coco_eval .category_stats [9 ][category_index ].astype (
225
- np .float32 )
226
- metrics_dict ['Recall AR (medium) ByCategory/{}' .format (
227
- category_id )] = coco_eval .category_stats [10 ][category_index ].astype (
228
- np .float32 )
229
- metrics_dict ['Recall AR (large) ByCategory/{}' .format (
230
- category_id )] = coco_eval .category_stats [11 ][category_index ].astype (
231
- np .float32 )
224
+ metrics_dict .update (self ._retrieve_per_category_metrics (coco_eval ))
225
+
226
+ if self ._include_keypoint :
227
+ metrics_dict .update (self ._retrieve_per_category_metrics (
228
+ kcoco_eval , prefix = 'keypoints' ))
232
229
return metrics_dict
233
230
234
- def _process_predictions (self , predictions ):
231
+ def _retrieve_per_category_metrics (self , coco_eval , prefix = '' ):
232
+ """Retrieves and per-category metrics and returns them in a dict.
233
+
234
+ Args:
235
+ coco_eval: a cocoeval.COCOeval object containing evaluation data.
236
+ prefix: str, A string used to prefix metric names.
237
+
238
+ Returns:
239
+ metrics_dict: A dictionary with per category metrics.
240
+ """
241
+
242
+ metrics_dict = {}
243
+ if prefix :
244
+ prefix = prefix + ' '
245
+
246
+ for category_index , category_id in enumerate (coco_eval .params .catIds ):
247
+ if 'keypoints' in prefix :
248
+ metrics_dict_keys = [
249
+ 'Precision mAP ByCategory' ,
250
+ 'Precision mAP ByCategory@50IoU' ,
251
+ 'Precision mAP ByCategory@75IoU' ,
252
+ 'Precision mAP ByCategory (medium)' ,
253
+ 'Precision mAP ByCategory (large)' ,
254
+ 'Recall AR@1 ByCategory' ,
255
+ 'Recall AR@10 ByCategory' ,
256
+ 'Recall AR@100 ByCategory' ,
257
+ 'Recall AR (medium) ByCategory' ,
258
+ 'Recall AR (large) ByCategory' ,
259
+ ]
260
+ else :
261
+ metrics_dict_keys = [
262
+ 'Precision mAP ByCategory' ,
263
+ 'Precision mAP ByCategory@50IoU' ,
264
+ 'Precision mAP ByCategory@75IoU' ,
265
+ 'Precision mAP ByCategory (small)' ,
266
+ 'Precision mAP ByCategory (medium)' ,
267
+ 'Precision mAP ByCategory (large)' ,
268
+ 'Recall AR@1 ByCategory' ,
269
+ 'Recall AR@10 ByCategory' ,
270
+ 'Recall AR@100 ByCategory' ,
271
+ 'Recall AR (small) ByCategory' ,
272
+ 'Recall AR (medium) ByCategory' ,
273
+ 'Recall AR (large) ByCategory' ,
274
+ ]
275
+ for idx , key in enumerate (metrics_dict_keys ):
276
+ metrics_dict [prefix + key + '/{}' .format (
277
+ category_id )] = coco_eval .category_stats [idx ][
278
+ category_index ].astype (np .float32 )
279
+ return metrics_dict
280
+
281
+ def _process_bboxes_predictions (self , predictions ):
235
282
image_scale = np .tile (predictions ['image_info' ][:, 2 :3 , :], (1 , 1 , 2 ))
236
283
predictions ['detection_boxes' ] = (
237
284
predictions ['detection_boxes' ].astype (np .float32 ))
@@ -241,6 +288,13 @@ def _process_predictions(self, predictions):
241
288
predictions ['detection_outer_boxes' ].astype (np .float32 ))
242
289
predictions ['detection_outer_boxes' ] /= image_scale
243
290
291
+ def _process_keypoints_predictions (self , predictions ):
292
+ image_scale = tf .reshape (predictions ['image_info' ][:, 2 :3 , :],
293
+ [- 1 , 1 , 1 , 2 ])
294
+ predictions ['detection_keypoints' ] = (
295
+ predictions ['detection_keypoints' ].astype (np .float32 ))
296
+ predictions ['detection_keypoints' ] /= image_scale
297
+
244
298
def update (self , predictions , groundtruths = None ):
245
299
"""Update and aggregate detection results and groundtruth data.
246
300
@@ -286,7 +340,9 @@ def update(self, predictions, groundtruths=None):
286
340
raise ValueError (
287
341
'Missing the required key `{}` in predictions!' .format (k ))
288
342
if self ._need_rescale_bboxes :
289
- self ._process_predictions (predictions )
343
+ self ._process_bboxes_predictions (predictions )
344
+ if self ._need_rescale_keypoints :
345
+ self ._process_keypoints_predictions (predictions )
290
346
for k , v in six .iteritems (predictions ):
291
347
if k not in self ._predictions :
292
348
self ._predictions [k ] = [v ]
@@ -305,6 +361,20 @@ def update(self, predictions, groundtruths=None):
305
361
else :
306
362
self ._groundtruths [k ].append (v )
307
363
364
+ def merge (self , other ):
365
+ """Merges the states from the other CocoEvaluator."""
366
+ for k , v in other ._predictions .items (): # pylint: disable=protected-access
367
+ if k not in self ._predictions :
368
+ self ._predictions [k ] = v
369
+ else :
370
+ self ._predictions [k ].extend (v )
371
+
372
+ for k , v in other ._groundtruths .items (): # pylint: disable=protected-access
373
+ if k not in self ._groundtruths :
374
+ self ._groundtruths [k ] = v
375
+ else :
376
+ self ._groundtruths [k ].extend (v )
377
+
308
378
309
379
class ShapeMaskCOCOEvaluator (COCOEvaluator ):
310
380
"""COCO evaluation metric class for ShapeMask."""
@@ -463,6 +533,7 @@ def __init__(
463
533
self ._metric_names .extend (mask_metric_names )
464
534
self ._required_prediction_fields .extend (['detection_masks' ])
465
535
self ._required_groundtruth_fields .extend (['masks' ])
536
+ self ._need_rescale_keypoints = False
466
537
467
538
self .reset ()
468
539
0 commit comments