Skip to content

Commit 2be44ad

Browse files
authored
Merge pull request #1052 from tensorflow/sync
Sync (0-30 of 206)
2 parents 906be52 + bb2193e commit 2be44ad

File tree

32 files changed

+554
-319
lines changed

32 files changed

+554
-319
lines changed

models/official/detection/dataloader/tf_example_decoder.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,11 @@ def __init__(
3838
# copypara:strip_end
3939
regenerate_source_id=False,
4040
label_key='image/object/class/label',
41-
label_dtype=tf.int64):
41+
label_dtype=tf.int64,
42+
include_keypoint=False,
43+
num_keypoints_per_instance=0):
4244
self._include_mask = include_mask
45+
self._include_keypoint = include_keypoint
4346
# copypara:strip_begin
4447
self._include_polygon = include_polygon
4548
# copypara:strip_end
@@ -62,6 +65,13 @@ def __init__(
6265
self._label_dtype)
6366
if include_mask:
6467
self._keys_to_features['image/object/mask'] = tf.VarLenFeature(tf.string)
68+
if include_keypoint:
69+
self._num_keypoints_per_instance = num_keypoints_per_instance
70+
self._keys_to_features.update({
71+
'image/object/keypoint/visibility': tf.io.VarLenFeature(tf.int64),
72+
'image/object/keypoint/x': tf.io.VarLenFeature(tf.float32),
73+
'image/object/keypoint/y': tf.io.VarLenFeature(tf.float32),
74+
})
6575

6676
def _decode_image(self, parsed_tensors):
6777
"""Decodes the image and set its static shape."""
@@ -106,6 +116,17 @@ def _decode_areas(self, parsed_tensors):
106116
lambda: parsed_tensors['image/object/area'],
107117
lambda: (xmax - xmin) * (ymax - ymin) * height * width)
108118

119+
def _decode_keypoints(self, parsed_tensors):
  """Builds per-instance keypoint coordinates and visibility flags.

  Args:
    parsed_tensors: a dict holding the 'image/object/keypoint/x',
      'image/object/keypoint/y' and 'image/object/keypoint/visibility'
      entries. NOTE(review): presumably already converted to dense tensors
      by the caller -- confirm.

  Returns:
    A (keypoints, keypoint_visibilities) tuple. `keypoints` pairs each
    coordinate as [y, x] and is reshaped to
    [-1, num_keypoints_per_instance, 2]; `keypoint_visibilities` is reshaped
    to [-1, num_keypoints_per_instance].
  """
  per_instance = self._num_keypoints_per_instance
  # Coordinates are paired in [y, x] order along a new last axis.
  yx_pairs = tf.stack(
      [
          parsed_tensors['image/object/keypoint/y'],
          parsed_tensors['image/object/keypoint/x'],
      ],
      axis=-1)
  grouped_keypoints = tf.reshape(yx_pairs, [-1, per_instance, 2])
  grouped_visibilities = tf.reshape(
      parsed_tensors['image/object/keypoint/visibility'], [-1, per_instance])
  return grouped_keypoints, grouped_visibilities
129+
109130
def decode(self, serialized_example):
110131
"""Decode the serialized example.
111132
@@ -165,6 +186,8 @@ def decode(self, serialized_example):
165186
lambda: _get_source_id_from_encoded_image(parsed_tensors))
166187
if self._include_mask:
167188
masks = self._decode_masks(parsed_tensors)
189+
if self._include_keypoint:
190+
keypoints, keypoint_visibilities = self._decode_keypoints(parsed_tensors)
168191

169192
groundtruth_classes = parsed_tensors[self._label_key]
170193
decoded_tensors = {
@@ -182,4 +205,9 @@ def decode(self, serialized_example):
182205
'groundtruth_instance_masks': masks,
183206
'groundtruth_instance_masks_png': parsed_tensors['image/object/mask'],
184207
})
208+
if self._include_keypoint:
209+
decoded_tensors.update({
210+
'groundtruth_keypoints': keypoints,
211+
'groundtruth_keypoint_visibilities': keypoint_visibilities,
212+
})
185213
return decoded_tensors

models/official/detection/evaluation/coco_evaluator.py

+117-46
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ def __init__(
5555
need_rescale_bboxes=True,
5656
per_category_metrics=False,
5757
remove_invalid_boxes=False,
58+
include_keypoint=False,
59+
need_rescale_keypoints=False,
60+
kpt_oks_sigmas=None
5861
):
5962
"""Constructs COCO evaluation class.
6063
@@ -80,6 +83,13 @@ def __init__(
8083
per_category_metrics: Whether to return per category metrics.
8184
remove_invalid_boxes: A boolean indicating whether to remove invalid box
8285
during evaluation.
86+
include_keypoint: a boolean to indicate whether or not to include the
87+
keypoint eval.
88+
need_rescale_keypoints: If true keypoints in `predictions` will be
89+
rescaled back to absolute values (`image_info` is needed in this case).
90+
kpt_oks_sigmas: The sigmas used to calculate keypoint OKS. See
91+
http://cocodataset.org/#keypoints-eval. When None, it will use the
92+
defaults in COCO.
8393
"""
8494
if annotation_file:
8595
if annotation_file.startswith('gs://'):
@@ -105,7 +115,8 @@ def __init__(
105115
'detection_boxes'
106116
]
107117
self._need_rescale_bboxes = need_rescale_bboxes
108-
if self._need_rescale_bboxes:
118+
self._need_rescale_keypoints = need_rescale_keypoints
119+
if self._need_rescale_bboxes or self._need_rescale_keypoints:
109120
self._required_prediction_fields.append('image_info')
110121
self._required_groundtruth_fields = [
111122
'source_id', 'height', 'width', 'classes', 'boxes'
@@ -116,6 +127,18 @@ def __init__(
116127
self._required_prediction_fields.extend(['detection_masks'])
117128
self._required_groundtruth_fields.extend(['masks'])
118129
self.remove_invalid_boxes = remove_invalid_boxes
130+
self._include_keypoint = include_keypoint
131+
self._kpt_oks_sigmas = kpt_oks_sigmas
132+
if self._include_keypoint:
133+
keypoint_metric_names = [
134+
'AP', 'AP50', 'AP75', 'APm', 'APl', 'ARmax1', 'ARmax10', 'ARmax100',
135+
'ARm', 'ARl'
136+
]
137+
keypoint_metric_names = ['keypoint_' + x for x in keypoint_metric_names]
138+
self._metric_names.extend(keypoint_metric_names)
139+
self._required_prediction_fields.extend(['detection_keypoints'])
140+
self._required_groundtruth_fields.extend(['keypoints'])
141+
119142
self.reset()
120143

121144
def reset(self):
@@ -168,7 +191,7 @@ def evaluate(self):
168191
coco_eval.evaluate()
169192
coco_eval.accumulate()
170193
coco_eval.summarize()
171-
coco_metrics = coco_eval.stats
194+
metrics = coco_eval.stats
172195

173196
if self._include_mask:
174197
mcoco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='segm')
@@ -177,11 +200,17 @@ def evaluate(self):
177200
mcoco_eval.accumulate()
178201
mcoco_eval.summarize()
179202
mask_coco_metrics = mcoco_eval.stats
180-
181-
if self._include_mask:
182-
metrics = np.hstack((coco_metrics, mask_coco_metrics))
183-
else:
184-
metrics = coco_metrics
203+
metrics = np.hstack((metrics, mask_coco_metrics))
204+
205+
if self._include_keypoint:
206+
kcoco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='keypoints',
207+
kpt_oks_sigmas=self._kpt_oks_sigmas)
208+
kcoco_eval.params.imgIds = image_ids
209+
kcoco_eval.evaluate()
210+
kcoco_eval.accumulate()
211+
kcoco_eval.summarize()
212+
keypoint_coco_metrics = kcoco_eval.stats
213+
metrics = np.hstack((metrics, keypoint_coco_metrics))
185214

186215
# Cleans up the internal variables in order for a fresh eval next time.
187216
self.reset()
@@ -192,46 +221,64 @@ def evaluate(self):
192221

193222
# Adds metrics per category.
194223
if self._per_category_metrics and hasattr(coco_eval, 'category_stats'):
195-
for category_index, category_id in enumerate(coco_eval.params.catIds):
196-
metrics_dict['Precision mAP ByCategory/{}'.format(
197-
category_id)] = coco_eval.category_stats[0][category_index].astype(
198-
np.float32)
199-
metrics_dict['Precision mAP ByCategory@50IoU/{}'.format(
200-
category_id)] = coco_eval.category_stats[1][category_index].astype(
201-
np.float32)
202-
metrics_dict['Precision mAP ByCategory@75IoU/{}'.format(
203-
category_id)] = coco_eval.category_stats[2][category_index].astype(
204-
np.float32)
205-
metrics_dict['Precision mAP ByCategory (small) /{}'.format(
206-
category_id)] = coco_eval.category_stats[3][category_index].astype(
207-
np.float32)
208-
metrics_dict['Precision mAP ByCategory (medium) /{}'.format(
209-
category_id)] = coco_eval.category_stats[4][category_index].astype(
210-
np.float32)
211-
metrics_dict['Precision mAP ByCategory (large) /{}'.format(
212-
category_id)] = coco_eval.category_stats[5][category_index].astype(
213-
np.float32)
214-
metrics_dict['Recall AR@1 ByCategory/{}'.format(
215-
category_id)] = coco_eval.category_stats[6][category_index].astype(
216-
np.float32)
217-
metrics_dict['Recall AR@10 ByCategory/{}'.format(
218-
category_id)] = coco_eval.category_stats[7][category_index].astype(
219-
np.float32)
220-
metrics_dict['Recall AR@100 ByCategory/{}'.format(
221-
category_id)] = coco_eval.category_stats[8][category_index].astype(
222-
np.float32)
223-
metrics_dict['Recall AR (small) ByCategory/{}'.format(
224-
category_id)] = coco_eval.category_stats[9][category_index].astype(
225-
np.float32)
226-
metrics_dict['Recall AR (medium) ByCategory/{}'.format(
227-
category_id)] = coco_eval.category_stats[10][category_index].astype(
228-
np.float32)
229-
metrics_dict['Recall AR (large) ByCategory/{}'.format(
230-
category_id)] = coco_eval.category_stats[11][category_index].astype(
231-
np.float32)
224+
metrics_dict.update(self._retrieve_per_category_metrics(coco_eval))
225+
226+
if self._include_keypoint:
227+
metrics_dict.update(self._retrieve_per_category_metrics(
228+
kcoco_eval, prefix='keypoints'))
232229
return metrics_dict
233230

234-
def _process_predictions(self, predictions):
231+
def _retrieve_per_category_metrics(self, coco_eval, prefix=''):
  """Retrieves per-category metrics and returns them in a dict.

  Args:
    coco_eval: a cocoeval.COCOeval object containing evaluation data. Only
      `params.catIds` and `category_stats` are read.
    prefix: str, A string used to prefix metric names. When non-empty, a
      single space separates it from the metric name.

  Returns:
    metrics_dict: A dictionary with per category metrics, keyed by
      '<prefix><metric name>/<category id>' with np.float32 values.
  """
  # The metric names depend only on the prefix, so pick them once instead of
  # rebuilding the list for every category. The keypoint list has no
  # "(small)" entries, so it is two names shorter than the default list.
  if 'keypoints' in prefix:
    metric_names = (
        'Precision mAP ByCategory',
        'Precision mAP ByCategory@50IoU',
        'Precision mAP ByCategory@75IoU',
        'Precision mAP ByCategory (medium)',
        'Precision mAP ByCategory (large)',
        'Recall AR@1 ByCategory',
        'Recall AR@10 ByCategory',
        'Recall AR@100 ByCategory',
        'Recall AR (medium) ByCategory',
        'Recall AR (large) ByCategory',
    )
  else:
    metric_names = (
        'Precision mAP ByCategory',
        'Precision mAP ByCategory@50IoU',
        'Precision mAP ByCategory@75IoU',
        'Precision mAP ByCategory (small)',
        'Precision mAP ByCategory (medium)',
        'Precision mAP ByCategory (large)',
        'Recall AR@1 ByCategory',
        'Recall AR@10 ByCategory',
        'Recall AR@100 ByCategory',
        'Recall AR (small) ByCategory',
        'Recall AR (medium) ByCategory',
        'Recall AR (large) ByCategory',
    )

  if prefix:
    prefix = prefix + ' '

  metrics_dict = {}
  for category_index, category_id in enumerate(coco_eval.params.catIds):
    # The position of a name in `metric_names` is its row in
    # `category_stats`; the category's position in `catIds` is the column.
    for idx, key in enumerate(metric_names):
      metrics_dict[prefix + key + '/{}'.format(category_id)] = (
          coco_eval.category_stats[idx][category_index].astype(np.float32))
  return metrics_dict
280+
281+
def _process_bboxes_predictions(self, predictions):
235282
image_scale = np.tile(predictions['image_info'][:, 2:3, :], (1, 1, 2))
236283
predictions['detection_boxes'] = (
237284
predictions['detection_boxes'].astype(np.float32))
@@ -241,6 +288,13 @@ def _process_predictions(self, predictions):
241288
predictions['detection_outer_boxes'].astype(np.float32))
242289
predictions['detection_outer_boxes'] /= image_scale
243290

291+
def _process_keypoints_predictions(self, predictions):
  """Rescales the predicted keypoints in place by the per-image scale.

  Args:
    predictions: a dict containing the `image_info` and `detection_keypoints`
      NumPy arrays. `detection_keypoints` is replaced by a float32 copy
      divided by the scale read from row 2 of `image_info`.
      NOTE(review): assumes `image_info` is [batch_size, 4, 2] and
      `detection_keypoints` is [batch_size, K, num_keypoints, 2] -- confirm
      against the caller.
  """
  # Stay in NumPy: the predictions are NumPy arrays, and a TensorFlow reshape
  # would hand back a tensor (symbolic in graph mode) that cannot be used as
  # the divisor of an in-place ndarray division.
  image_scale = np.reshape(predictions['image_info'][:, 2:3, :],
                           [-1, 1, 1, 2])
  predictions['detection_keypoints'] = (
      predictions['detection_keypoints'].astype(np.float32))
  # Broadcasts one two-value scale per image over all detections and
  # keypoints of that image.
  predictions['detection_keypoints'] /= image_scale
297+
244298
def update(self, predictions, groundtruths=None):
245299
"""Update and aggregate detection results and groundtruth data.
246300
@@ -286,7 +340,9 @@ def update(self, predictions, groundtruths=None):
286340
raise ValueError(
287341
'Missing the required key `{}` in predictions!'.format(k))
288342
if self._need_rescale_bboxes:
289-
self._process_predictions(predictions)
343+
self._process_bboxes_predictions(predictions)
344+
if self._need_rescale_keypoints:
345+
self._process_keypoints_predictions(predictions)
290346
for k, v in six.iteritems(predictions):
291347
if k not in self._predictions:
292348
self._predictions[k] = [v]
@@ -305,6 +361,20 @@ def update(self, predictions, groundtruths=None):
305361
else:
306362
self._groundtruths[k].append(v)
307363

364+
def merge(self, other):
  """Merges the states from the other CocoEvaluator.

  The accumulated predictions and groundtruths of `other` are appended to
  this evaluator's, key by key. `other` is left unmodified, and the two
  evaluators do not share list objects afterwards.

  Args:
    other: the evaluator whose `_predictions` and `_groundtruths` dicts
      (key -> list of accumulated values) are merged into this one.
  """
  # Always extend a list owned by this evaluator. Storing `other`'s list
  # directly for a new key would alias it, so later updates or merges on
  # this evaluator would silently modify `other` as well.
  for k, v in other._predictions.items():  # pylint: disable=protected-access
    self._predictions.setdefault(k, []).extend(v)

  for k, v in other._groundtruths.items():  # pylint: disable=protected-access
    self._groundtruths.setdefault(k, []).extend(v)
377+
308378

309379
class ShapeMaskCOCOEvaluator(COCOEvaluator):
310380
"""COCO evaluation metric class for ShapeMask."""
@@ -463,6 +533,7 @@ def __init__(
463533
self._metric_names.extend(mask_metric_names)
464534
self._required_prediction_fields.extend(['detection_masks'])
465535
self._required_groundtruth_fields.extend(['masks'])
536+
self._need_rescale_keypoints = False
466537

467538
self.reset()
468539

models/official/detection/evaluation/coco_utils.py

+36
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ def convert_predictions_to_coco_annotations(
138138
Optional fields:
139139
- detection_masks: a list of numpy arrays of float of shape
140140
[batch_size, K, mask_height, mask_width].
141+
- detection_keypoints: a list of numpy arrays of float of shape
142+
[batch_size, K, num_keypoints, 2]
141143
remove_invalid_boxes: A boolean indicating whether to remove invalid box
142144
during evaluation.
143145
@@ -160,6 +162,19 @@ def convert_predictions_to_coco_annotations(
160162

161163
# NOTE: Batch size may differ between chunks.
162164
batch_size = predictions['source_id'][i].shape[0]
165+
if 'detection_keypoints' in predictions:
166+
# Adds extra ones to indicate the visibility for each keypoint as is
167+
# recommended by MSCOCO. Also, convert keypoint from [y, x] to [x, y]
168+
# as mandated by COCO.
169+
num_keypoints = predictions['detection_keypoints'][i].shape[2]
170+
coco_keypoints = np.concatenate(
171+
[
172+
predictions['detection_keypoints'][i][..., 1:],
173+
predictions['detection_keypoints'][i][..., :1],
174+
np.ones([batch_size, max_num_detections, num_keypoints, 1]),
175+
],
176+
axis=-1,
177+
).astype(int)
163178
for j in range(batch_size):
164179
if 'detection_masks' in predictions:
165180
image_masks = mask_utils.paste_instance_masks(
@@ -185,6 +200,8 @@ def convert_predictions_to_coco_annotations(
185200
ann['score'] = predictions['detection_scores'][i][j, k]
186201
if 'detection_masks' in predictions:
187202
ann['segmentation'] = encoded_masks[k]
203+
if 'detection_keypoints' in predictions:
204+
ann['keypoints'] = coco_keypoints[j, k].flatten().tolist()
188205
coco_predictions.append(ann)
189206

190207
for i, ann in enumerate(coco_predictions):
@@ -272,6 +289,25 @@ def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None):
272289
ann['segmentation'] = encoded_mask
273290
if 'areas' not in groundtruths:
274291
ann['area'] = mask_api.area(encoded_mask)
292+
if 'keypoints' in groundtruths:
293+
keypoints = groundtruths['keypoints'][i]
294+
coco_keypoints = []
295+
num_valid_keypoints = 0
296+
for z in range(len(keypoints[j, k, :, 1])):
297+
# Convert from [y, x] to [x, y] as mandated by COCO.
298+
x = float(keypoints[j, k, z, 1])
299+
y = float(keypoints[j, k, z, 0])
300+
coco_keypoints.append(x)
301+
coco_keypoints.append(y)
302+
if tf.math.is_nan(x) or tf.math.is_nan(y) or (
303+
x == 0 and y == 0):
304+
visibility = 0
305+
else:
306+
visibility = 2
307+
num_valid_keypoints = num_valid_keypoints + 1
308+
coco_keypoints.append(visibility)
309+
ann['keypoints'] = coco_keypoints
310+
ann['num_keypoints'] = num_valid_keypoints
275311
gt_annotations.append(ann)
276312

277313
for i, ann in enumerate(gt_annotations):

0 commit comments

Comments
 (0)