Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions configs/_base_/datasets/cityscapes_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,8 @@
metric=['bbox', 'segm']),
dict(
type='CityScapesMetric',
ann_file=data_root +
'annotations/instancesonly_filtered_gtFine_val.json',
seg_prefix=data_root + '/gtFine/val',
outfile_prefix='./work_dirs/cityscapes_metric/instance')
seg_prefix=data_root + 'gtFine/val',
classwise=True)
]

test_evaluator = val_evaluator
Expand Down
159 changes: 62 additions & 97 deletions mmdet/evaluation/metrics/cityscapes_metric.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,27 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import shutil
from collections import OrderedDict
from typing import Dict, Optional, Sequence
import warnings
from typing import Optional, Sequence

import mmcv
import numpy as np
from mmengine.dist import is_main_process, master_only
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger
from mmeval import CityScapesDetection

from mmdet.registry import METRICS

try:
import cityscapesscripts
from cityscapesscripts.evaluation import \
evalInstanceLevelSemanticLabeling as CSEval
from cityscapesscripts.helpers import labels as CSLabels
except ImportError:
cityscapesscripts = None
CSLabels = None
CSEval = None


@METRICS.register_module()
class CityScapesMetric(BaseMetric):
"""CityScapes metric for instance segmentation.
class CityScapesMetric(CityScapesDetection):
"""A wrapper of :class:`mmeval.CityScapesDetection`.

Args:
outfile_prefix (str): The prefix of txt and png files. The txt and
png file will be save in a directory whose path is
"outfile_prefix.results/".
outfile_prefix (str): The prefix of txt and png files. It is the
saving path of txt and png file, e.g. "a/b/prefix".
seg_prefix (str, optional): Path to the directory which contains the
cityscapes instance segmentation masks. It's necessary when
training and validation. It could be None when infer on test
Expand All @@ -40,8 +30,8 @@ class CityScapesMetric(BaseMetric):
evaluation. It is useful when you want to format the result
to a specific format and submit it to the test server.
Defaults to False.
keep_results (bool): Whether to keep the results. When ``format_only``
is True, ``keep_results`` must be True. Defaults to False.
classwise (bool): Whether to return the computed results of each
class. Defaults to True.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
Expand All @@ -53,36 +43,37 @@ class CityScapesMetric(BaseMetric):
default_prefix: Optional[str] = 'cityscapes'

def __init__(self,
outfile_prefix: str,
outfile_prefix: Optional[str] = None,
seg_prefix: Optional[str] = None,
format_only: bool = False,
keep_results: bool = False,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
classwise: bool = True,
prefix: Optional[str] = None,
dist_backend: str = 'torch_cuda',
**kwargs) -> None:

if cityscapesscripts is None:
raise RuntimeError('Please run "pip install cityscapesscripts" to '
raise RuntimeError('Please run `pip install cityscapesscripts` to '
'install cityscapesscripts first.')

assert outfile_prefix, 'outfile_prefix must be not None.'

if format_only:
assert keep_results, 'keep_results must be True when '
'format_only is True'
collect_device = kwargs.pop('collect_device', None)
if collect_device is not None:
warnings.warn(
'DeprecationWarning: The `collect_device` parameter of '
'`CityScapesMetric` is deprecated, use `dist_backend` '
'instead.')

super().__init__(collect_device=collect_device, prefix=prefix)
self.format_only = format_only
self.keep_results = keep_results
self.seg_out_dir = osp.abspath(f'{outfile_prefix}.results')
self.seg_prefix = seg_prefix
logger = MMLogger.get_current_instance()

if is_main_process():
os.makedirs(self.seg_out_dir, exist_ok=True)
super().__init__(
outfile_prefix=outfile_prefix,
seg_prefix=seg_prefix,
format_only=format_only,
classwise=classwise,
logger=logger,
dist_backend=dist_backend,
**kwargs)

@master_only
def __del__(self) -> None:
"""Clean up."""
if not self.keep_results:
shutil.rmtree(self.seg_out_dir)
self.prefix = prefix or self.default_prefix

# TODO: data_batch is no longer needed, consider adjusting the
# parameter position
Expand All @@ -96,77 +87,51 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
data_samples (Sequence[dict]): A batch of data samples that
contain annotations and predictions.
"""
predictions, groundtruths = [], []

for data_sample in data_samples:
# parse pred
result = dict()
pred = data_sample['pred_instances']
pred = dict()
pred_instances = data_sample['pred_instances']
filename = data_sample['img_path']
basename = osp.splitext(osp.basename(filename))[0]
pred_txt = osp.join(self.seg_out_dir, basename + '_pred.txt')
result['pred_txt'] = pred_txt
labels = pred['labels'].cpu().numpy()
masks = pred['masks'].cpu().numpy().astype(np.uint8)
if 'mask_scores' in pred:
labels = pred_instances['labels'].cpu().numpy()
masks = pred_instances['masks'].cpu().numpy().astype(np.uint8)
if 'mask_scores' in pred_instances:
# some detectors use different scores for bbox and mask
mask_scores = pred['mask_scores'].cpu().numpy()
mask_scores = pred_instances['mask_scores'].cpu().numpy()
else:
mask_scores = pred['scores'].cpu().numpy()

with open(pred_txt, 'w') as f:
for i, (label, mask, mask_score) in enumerate(
zip(labels, masks, mask_scores)):
class_name = self.dataset_meta['classes'][label]
class_id = CSLabels.name2label[class_name].id
png_filename = osp.join(
self.seg_out_dir, basename + f'_{i}_{class_name}.png')
mmcv.imwrite(mask, png_filename)
f.write(f'{osp.basename(png_filename)} '
f'{class_id} {mask_score}\n')
mask_scores = pred_instances['scores'].cpu().numpy()

pred['labels'] = labels
pred['masks'] = masks
pred['mask_scores'] = mask_scores
pred['basename'] = basename
predictions.append(pred)

# parse gt
gt = dict()
img_path = filename.replace('leftImg8bit.png',
'gtFine_instanceIds.png')
img_path = img_path.replace('leftImg8bit', 'gtFine')
gt['file_name'] = osp.join(self.seg_prefix, img_path)
gt['file_name'] = img_path
groundtruths.append(gt)

self.results.append((gt, result))
self.add(predictions, groundtruths)

def compute_metrics(self, results: list) -> Dict[str, float]:
"""Compute the metrics from processed results.
def evaluate(self, *args, **kwargs) -> dict:
"""Returns metric results and print pretty table of metrics per class.

Args:
results (list): The processed results of each batch.

Returns:
Dict[str, float]: The computed metrics. The keys are the names of
the metrics, and the values are corresponding results.
This method would be invoked by ``mmengine.Evaluator``.
"""
logger: MMLogger = MMLogger.get_current_instance()
metric_results = self.compute(*args, **kwargs)
self.reset()

if self.format_only:
logger.info(
f'results are saved to {osp.dirname(self.seg_out_dir)}')
return OrderedDict()
logger.info('starts to compute metric')

gts, preds = zip(*results)
# set global states in cityscapes evaluation API
CSEval.args.cityscapesPath = osp.join(self.seg_prefix, '../..')
CSEval.args.predictionPath = self.seg_out_dir
CSEval.args.predictionWalk = None
CSEval.args.JSONOutput = False
CSEval.args.colorized = False
CSEval.args.gtInstancesFile = osp.join(self.seg_out_dir,
'gtInstances.json')

groundTruthImgList = [gt['file_name'] for gt in gts]
predictionImgList = [pred['pred_txt'] for pred in preds]
CSEval_results = CSEval.evaluateImgLists(predictionImgList,
groundTruthImgList,
CSEval.args)['averages']
eval_results = OrderedDict()
eval_results['mAP'] = CSEval_results['allAp']
eval_results['AP@50'] = CSEval_results['allAp50%']

return eval_results
return metric_results

evaluate_results = {
f'{self.prefix}/{k}(%)': round(float(v) * 100, 4)
for k, v in metric_results.items()
}
return evaluate_results
Loading