diff --git a/requirements.txt b/requirements.txt
index 0de7ebd..26cd97f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 torch>=2.0.1
 torchvision>=0.15.2
-pycocotools
+faster-coco-eval>=1.6.5
 PyYAML
 tensorboard
 scipy
diff --git a/src/data/dataset/coco_dataset.py b/src/data/dataset/coco_dataset.py
index cef53fb..f8ce50a 100644
--- a/src/data/dataset/coco_dataset.py
+++ b/src/data/dataset/coco_dataset.py
@@ -11,12 +11,14 @@
 import torchvision
 
 from PIL import Image
 
-from pycocotools import mask as coco_mask
-
+import faster_coco_eval
+import faster_coco_eval.core.mask as coco_mask
 from ._dataset import DetDataset
 from .._misc import convert_to_tv_tensor
 from ...core import register
+
 torchvision.disable_beta_transforms_warning()
+faster_coco_eval.init_as_pycocotools()
 
 __all__ = ['CocoDetection']
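The `faster_coco_eval.init_as_pycocotools()` call added above is what makes the dependency swap transparent to code that was not migrated. A minimal sketch of the intended effect, assuming the patch covers the `coco`, `cocoeval`, and `mask` submodules (its documented purpose); the polygon values are invented for illustration:

```python
import faster_coco_eval
faster_coco_eval.init_as_pycocotools()

# Legacy pycocotools imports now resolve to faster-coco-eval's C++-backed
# implementations, so unmigrated call sites keep working unchanged.
from pycocotools import mask as coco_mask

rles = coco_mask.frPyObjects([[10.0, 10.0, 50.0, 10.0, 50.0, 50.0, 10.0, 50.0]], 64, 64)
print(coco_mask.area(coco_mask.merge(rles)))  # area of the rasterized square, roughly 1.6k px
```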
diff --git a/src/data/dataset/coco_eval.py b/src/data/dataset/coco_eval.py
index d02c6bc..75f6bd8 100644
--- a/src/data/dataset/coco_eval.py
+++ b/src/data/dataset/coco_eval.py
@@ -11,13 +11,10 @@
 import numpy as np
 import torch
 
-from pycocotools.cocoeval import COCOeval
-from pycocotools.coco import COCO
-import pycocotools.mask as mask_util
-
-from ...misc import dist_utils
+from faster_coco_eval import COCO, COCOeval_faster
+import faster_coco_eval.core.mask as mask_util
 from ...core import register
-
+from ...misc import dist_utils
 
 __all__ = ['CocoEvaluator',]
 
@@ -26,12 +23,12 @@ class CocoEvaluator(object):
     def __init__(self, coco_gt, iou_types):
         assert isinstance(iou_types, (list, tuple))
         coco_gt = copy.deepcopy(coco_gt)
-        self.coco_gt = coco_gt
+        self.coco_gt : COCO = coco_gt
         self.iou_types = iou_types
 
         self.coco_eval = {}
         for iou_type in iou_types:
-            self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
+            self.coco_eval[iou_type] = COCOeval_faster(coco_gt, iouType=iou_type, print_function=print, separate_eval=True)
 
         self.img_ids = []
         self.eval_imgs = {k: [] for k in iou_types}
@@ -39,7 +36,7 @@ def __init__(self, coco_gt, iou_types):
 
     def cleanup(self):
         self.coco_eval = {}
         for iou_type in self.iou_types:
-            self.coco_eval[iou_type] = COCOeval(self.coco_gt, iouType=iou_type)
+            self.coco_eval[iou_type] = COCOeval_faster(self.coco_gt, iouType=iou_type, print_function=print, separate_eval=True)
         self.img_ids = []
         self.eval_imgs = {k: [] for k in self.iou_types}
 
@@ -50,23 +47,26 @@ def update(self, predictions):
 
         for iou_type in self.iou_types:
             results = self.prepare(predictions, iou_type)
+            coco_eval = self.coco_eval[iou_type]
 
             # suppress pycocotools prints
             with open(os.devnull, 'w') as devnull:
                 with contextlib.redirect_stdout(devnull):
-                    coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
-                    coco_eval = self.coco_eval[iou_type]
+                    coco_dt = self.coco_gt.loadRes(results) if results else COCO()
+                    coco_eval.cocoDt = coco_dt
+                    coco_eval.params.imgIds = list(img_ids)
+                    coco_eval.evaluate()
 
-            coco_eval.cocoDt = coco_dt
-            coco_eval.params.imgIds = list(img_ids)
-            img_ids, eval_imgs = evaluate(coco_eval)
-
-            self.eval_imgs[iou_type].append(eval_imgs)
+            self.eval_imgs[iou_type].append(np.array(coco_eval._evalImgs_cpp).reshape(len(coco_eval.params.catIds), len(coco_eval.params.areaRng), len(coco_eval.params.imgIds)))
 
     def synchronize_between_processes(self):
         for iou_type in self.iou_types:
-            self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
-            create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
+            img_ids, eval_imgs = merge(self.img_ids, self.eval_imgs[iou_type])
+
+            coco_eval = self.coco_eval[iou_type]
+            coco_eval.params.imgIds = img_ids
+            coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
+            coco_eval._evalImgs_cpp = eval_imgs
 
     def accumulate(self):
         for coco_eval in self.coco_eval.values():
@@ -177,7 +177,6 @@ def convert_to_xywh(boxes):
     xmin, ymin, xmax, ymax = boxes.unbind(1)
     return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
 
-
 def merge(img_ids, eval_imgs):
     all_img_ids = dist_utils.all_gather(img_ids)
     all_eval_imgs = dist_utils.all_gather(eval_imgs)
@@ -188,89 +187,14 @@ def merge(img_ids, eval_imgs):
     merged_eval_imgs = []
     for p in all_eval_imgs:
-        merged_eval_imgs.append(p)
+        merged_eval_imgs.extend(p)
+
     merged_img_ids = np.array(merged_img_ids)
-    merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
+    merged_eval_imgs = np.concatenate(merged_eval_imgs, axis=2).ravel()
+    # merged_eval_imgs = np.array(merged_eval_imgs).T.ravel()
 
     # keep only unique (and in sorted order) images
     merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
-    merged_eval_imgs = merged_eval_imgs[..., idx]
-
-    return merged_img_ids, merged_eval_imgs
-
-
-def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
-    img_ids, eval_imgs = merge(img_ids, eval_imgs)
-    img_ids = list(img_ids)
-    eval_imgs = list(eval_imgs.flatten())
-
-    coco_eval.evalImgs = eval_imgs
-    coco_eval.params.imgIds = img_ids
-    coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
-
-
-#################################################################
-# From pycocotools, just removed the prints and fixed
-# a Python3 bug about unicode not defined
-#################################################################
-
-
-# import io
-# from contextlib import redirect_stdout
-# def evaluate(imgs):
-#     with redirect_stdout(io.StringIO()):
-#         imgs.evaluate()
-#     return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))
-
-
-def evaluate(self):
-    """
-    Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
-    :return: None
-    """
-    # tic = time.time()
-    # print('Running per image evaluation...')
-    p = self.params
-    # add backward compatibility if useSegm is specified in params
-    if p.useSegm is not None:
-        p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
-        print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
-    # print('Evaluate annotation type *{}*'.format(p.iouType))
-    p.imgIds = list(np.unique(p.imgIds))
-    if p.useCats:
-        p.catIds = list(np.unique(p.catIds))
-    p.maxDets = sorted(p.maxDets)
-    self.params = p
-
-    self._prepare()
-    # loop through images, area range, max detection number
-    catIds = p.catIds if p.useCats else [-1]
-
-    if p.iouType == 'segm' or p.iouType == 'bbox':
-        computeIoU = self.computeIoU
-    elif p.iouType == 'keypoints':
-        computeIoU = self.computeOks
-    self.ious = {
-        (imgId, catId): computeIoU(imgId, catId)
-        for imgId in p.imgIds
-        for catId in catIds}
-
-    evaluateImg = self.evaluateImg
-    maxDet = p.maxDets[-1]
-    evalImgs = [
-        evaluateImg(imgId, catId, areaRng, maxDet)
-        for catId in catIds
-        for areaRng in p.areaRng
-        for imgId in p.imgIds
-    ]
-    # this is NOT in the pycocotools code, but could be done outside
-    evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
-    self._paramsEval = copy.deepcopy(self.params)
-    # toc = time.time()
-    # print('DONE (t={:0.2f}s).'.format(toc-tic))
-    return p.imgIds, evalImgs
-
-#################################################################
-# end of straight copy from pycocotools, just removing the prints
-#################################################################
+
+    return merged_img_ids.tolist(), merged_eval_imgs.tolist()
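For orientation, the rewritten evaluator is driven exactly as before; only the internals changed (`COCOeval_faster` plus the raw `_evalImgs_cpp` buffer that `merge()` now gathers across ranks). A rough usage sketch — the model, loader, postprocessor, and prediction-dict layout are placeholders, not part of this diff:

```python
from faster_coco_eval import COCO

coco_gt = COCO('annotations/instances_val2017.json')   # assumed path
evaluator = CocoEvaluator(coco_gt, iou_types=['bbox'])

for samples, targets in data_loader:                   # hypothetical loader/model
    outputs = postprocess(model(samples))
    predictions = {t['image_id'].item(): o for t, o in zip(targets, outputs)}
    evaluator.update(predictions)          # runs COCOeval_faster.evaluate() per batch

evaluator.synchronize_between_processes()  # merge() gathers img_ids/_evalImgs_cpp across ranks
evaluator.accumulate()                     # C++ accumulate over the merged buffer
```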
diff --git a/src/data/dataset/coco_fasteval.py b/src/data/dataset/coco_fasteval.py
deleted file mode 100644
index 4cd9202..0000000
--- a/src/data/dataset/coco_fasteval.py
+++ /dev/null
@@ -1,139 +0,0 @@
-"""
-Copyright (c) Facebook, Inc. and its affiliates.
-Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-
-# The code is based on
-https://github.com/facebookresearch/detectron2/blob/main/detectron2/evaluation/fast_eval_api.py
-https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/metrics/fast_cocoeval
-"""
-
-import copy
-import time
-
-import numpy as np
-from cocoeval_ext import (
-    InstanceAnnotation,
-    COCOevalEvaluateImages,
-    COCOevalAccumulate
-)
-from pycocotools.cocoeval import COCOeval
-
-__all__ = ['FastCOCOeval']
-
-
-class FastCOCOeval(COCOeval):
-    """
-    This is a slightly modified version of the original COCO API, where the functions evaluateImg()
-    and accumulate() are implemented in C++ to speedup evaluation
-    """
-
-    def evaluate(self):
-        """
-        Run per image evaluation on given images and store results in self.evalImgs_cpp, a
-        datastructure that isn't readable from Python but is used by a c++ implementation of
-        accumulate(). Unlike the original COCO PythonAPI, we don't populate the datastructure
-        self.evalImgs because this datastructure is a computational bottleneck.
-        :return: None
-        """
-        tic = time.time()
-        print('Running per image evaluation...')
-        p = self.params
-        # add backward compatibility if useSegm is specified in params
-        if p.useSegm is not None:
-            p.iouType = "segm" if p.useSegm == 1 else "bbox"
-            print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
-        print('Evaluate annotation type *{}*'.format(p.iouType))
-        p.imgIds = list(np.unique(p.imgIds))
-        if p.useCats:
-            p.catIds = list(np.unique(p.catIds))
-        p.maxDets = sorted(p.maxDets)
-        self.params = p
-
-        self._prepare()  # bottleneck
-
-        # loop through images, area range, max detection number
-        catIds = p.catIds if p.useCats else [-1]
-
-        if p.iouType == "segm" or p.iouType == "bbox":
-            computeIoU = self.computeIoU
-        elif p.iouType == "keypoints":
-            computeIoU = self.computeOks
-        self.ious = {
-            (imgId, catId): computeIoU(imgId, catId)
-            for imgId in p.imgIds for catId in catIds
-        }  # bottleneck
-
-        maxDet = p.maxDets[-1]
-
-        # <<<< Beginning of code differences with original COCO API
-        def convert_instances_to_cpp(instances, is_det=False) -> None:
-            # Convert annotations for a list of instances in an image to a format that's fast
-            # to access in C++
-            instances_cpp = []
-            for instance in instances:
-                instance_cpp = InstanceAnnotation(
-                    int(instance["id"]),
-                    instance["score"] if is_det else instance.get("score", 0.0),
-                    instance["area"],
-                    bool(instance.get("iscrowd", 0)),
-                    bool(instance.get("ignore", 0)),
-                )
-                instances_cpp.append(instance_cpp)
-            return instances_cpp
-
-        # Convert GT annotations, detections, and IOUs to a format that's fast to access in C++
-        ground_truth_instances = [
-            [convert_instances_to_cpp(self._gts[imgId, catId]) for catId in p.catIds]
-            for imgId in p.imgIds
-        ]
-        detected_instances = [
-            [convert_instances_to_cpp(self._dts[imgId, catId], is_det=True) for catId in p.catIds]
-            for imgId in p.imgIds
-        ]
-        ious = [[self.ious[imgId, catId] for catId in catIds] for imgId in p.imgIds]
-
-        if not p.useCats:
-            # For each image, flatten per-category lists into a single list
-            ground_truth_instances = [[[o for c in i for o in c]] for i in ground_truth_instances]
-            detected_instances = [[[o for c in i for o in c]] for i in detected_instances]
-
-        # Call C++ implementation of self.evaluateImgs()
-        self._evalImgs_cpp = COCOevalEvaluateImages(
-            p.areaRng, maxDet, p.iouThrs, ious, ground_truth_instances, detected_instances
-        )
-        self._evalImgs = None
-
-        self._paramsEval = copy.deepcopy(self.params)
-        toc = time.time()
-        print('DONE (t={:0.2f}s).'.format(toc - tic))
-        # >>>> End of code differences with original COCO API
-
-        # this is NOT in the pycocotools code, but could be done outside
-        # evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
-        # return p.imgIds, evalImgs
-
-    def accumulate(self, p=None):
-        """
-        Accumulate per image evaluation results and store the result in self.eval. Does not
-        support changing parameter settings from those used by self.evaluate()
-        """
-        print('Accumulating evaluation results...')
-        tic = time.time()
-        assert hasattr(
-            self, "_evalImgs_cpp"
-        ), "evaluate() must be called before accmulate() is called."
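Deleting `FastCOCOeval` is safe because `COCOeval_faster` ships the same detectron2-derived C++ evaluate/accumulate path prebuilt. A standalone sketch of the replacement, assuming COCO-format ground-truth and detection JSONs at the invented paths below:

```python
from faster_coco_eval import COCO, COCOeval_faster

coco_gt = COCO('annotations/instances_val2017.json')       # assumed paths
coco_dt = coco_gt.loadRes('results/detections_bbox.json')

coco_eval = COCOeval_faster(coco_gt, coco_dt, iouType='bbox',
                            print_function=print, separate_eval=True)
coco_eval.evaluate()    # C++-backed, analogous to the deleted FastCOCOeval.evaluate
coco_eval.accumulate()
coco_eval.summarize()
print(coco_eval.stats[0])  # AP @ IoU=0.50:0.95
```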
-
-        self.eval = COCOevalAccumulate(self._paramsEval, self._evalImgs_cpp)
-
-        # recall is num_iou_thresholds X num_categories X num_area_ranges X num_max_detections
-        self.eval["recall"] = np.array(self.eval["recall"]).reshape(
-            self.eval["counts"][:1] + self.eval["counts"][2:]
-        )
-
-        # precision and scores are num_iou_thresholds X num_recall_thresholds X num_categories X
-        # num_area_ranges X num_max_detections
-        self.eval["precision"] = np.array(self.eval["precision"]).reshape(self.eval["counts"])
-        self.eval["scores"] = np.array(self.eval["scores"]).reshape(self.eval["counts"])
-        toc = time.time()
-        print('DONE (t={:0.2f}s).'.format(toc - tic))
diff --git a/src/data/dataset/coco_utils.py b/src/data/dataset/coco_utils.py
index b8dd287..e8f38d4 100644
--- a/src/data/dataset/coco_utils.py
+++ b/src/data/dataset/coco_utils.py
@@ -9,8 +9,8 @@ import torch.utils.data
 import torchvision
 import torchvision.transforms.functional as TVF
 
-from pycocotools import mask as coco_mask
-from pycocotools.coco import COCO
+import faster_coco_eval.core.mask as coco_mask
+from faster_coco_eval import COCO
 
 
 def convert_coco_poly_to_mask(segmentations, height, width):
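`convert_coco_poly_to_mask` itself is untouched; it only needs `frPyObjects`/`decode`, which `faster_coco_eval.core.mask` exposes with pycocotools-compatible signatures. A quick equivalence check, with an invented square polygon:

```python
import faster_coco_eval.core.mask as coco_mask
import torch

polygons = [[10.0, 10.0, 50.0, 10.0, 50.0, 50.0, 10.0, 50.0]]  # one instance
rles = coco_mask.frPyObjects(polygons, 64, 64)
mask = torch.as_tensor(coco_mask.decode(rles), dtype=torch.uint8)
print(mask.shape, int(mask.sum()))  # torch.Size([64, 64, 1]) and roughly 1.6k foreground pixels
```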
diff --git a/src/data/dataset/csrc/fast_cocoeval/cocoeval.cc b/src/data/dataset/csrc/fast_cocoeval/cocoeval.cc
deleted file mode 100644
index 5403062..0000000
--- a/src/data/dataset/csrc/fast_cocoeval/cocoeval.cc
+++ /dev/null
@@ -1,512 +0,0 @@
-
-// The code is based on
-// https://github.com/facebookresearch/detectron2/tree/main/detectron2/layers/csrc/cocoeval/
-// https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/metrics/fast_cocoeval/ext/
-
-#include "cocoeval.h"
-#include <time.h>
-#include <algorithm>
-#include <cstdint>
-#include <numeric>
-
-using namespace pybind11::literals;
-
-// Sort detections from highest score to lowest, such that
-// detection_instances[detection_sorted_indices[t]] >=
-// detection_instances[detection_sorted_indices[t+1]]. Use stable_sort to match
-// original COCO API
-void SortInstancesByDetectionScore(
-    const std::vector<InstanceAnnotation>& detection_instances,
-    std::vector<uint64_t>* detection_sorted_indices) {
-  detection_sorted_indices->resize(detection_instances.size());
-  std::iota(
-      detection_sorted_indices->begin(), detection_sorted_indices->end(), 0);
-  std::stable_sort(
-      detection_sorted_indices->begin(),
-      detection_sorted_indices->end(),
-      [&detection_instances](size_t j1, size_t j2) {
-        return detection_instances[j1].score > detection_instances[j2].score;
-      });
-}
-
-// Partition the ground truth objects based on whether or not to ignore them
-// based on area
-void SortInstancesByIgnore(
-    const std::array<double, 2>& area_range,
-    const std::vector<InstanceAnnotation>& ground_truth_instances,
-    std::vector<uint64_t>* ground_truth_sorted_indices,
-    std::vector<bool>* ignores) {
-  ignores->clear();
-  ignores->reserve(ground_truth_instances.size());
-  for (auto o : ground_truth_instances) {
-    ignores->push_back(
-        o.ignore || o.area < area_range[0] || o.area > area_range[1]);
-  }
-
-  ground_truth_sorted_indices->resize(ground_truth_instances.size());
-  std::iota(
-      ground_truth_sorted_indices->begin(),
-      ground_truth_sorted_indices->end(),
-      0);
-  std::stable_sort(
-      ground_truth_sorted_indices->begin(),
-      ground_truth_sorted_indices->end(),
-      [&ignores](size_t j1, size_t j2) {
-        return (int)(*ignores)[j1] < (int)(*ignores)[j2];
-      });
-}
-
-// For each IOU threshold, greedily match each detected instance to a ground
-// truth instance (if possible) and store the results
-void MatchDetectionsToGroundTruth(
-    const std::vector<InstanceAnnotation>& detection_instances,
-    const std::vector<uint64_t>& detection_sorted_indices,
-    const std::vector<InstanceAnnotation>& ground_truth_instances,
-    const std::vector<uint64_t>& ground_truth_sorted_indices,
-    const std::vector<bool>& ignores,
-    const std::vector<std::vector<double>>& ious,
-    const std::vector<double>& iou_thresholds,
-    const std::array<double, 2>& area_range,
-    ImageEvaluation* results) {
-  // Initialize memory to store return data matches and ignore
-  const int num_iou_thresholds = iou_thresholds.size();
-  const int num_ground_truth = ground_truth_sorted_indices.size();
-  const int num_detections = detection_sorted_indices.size();
-  std::vector<uint64_t> ground_truth_matches(
-      num_iou_thresholds * num_ground_truth, 0);
-  std::vector<uint64_t>& detection_matches = results->detection_matches;
-  std::vector<bool>& detection_ignores = results->detection_ignores;
-  std::vector<bool>& ground_truth_ignores = results->ground_truth_ignores;
-  detection_matches.resize(num_iou_thresholds * num_detections, 0);
-  detection_ignores.resize(num_iou_thresholds * num_detections, false);
-  ground_truth_ignores.resize(num_ground_truth);
-  for (auto g = 0; g < num_ground_truth; ++g) {
-    ground_truth_ignores[g] = ignores[ground_truth_sorted_indices[g]];
-  }
-
-  for (auto t = 0; t < num_iou_thresholds; ++t) {
-    for (auto d = 0; d < num_detections; ++d) {
-      // information about best match so far (match=-1 -> unmatched)
-      double best_iou = std::min(iou_thresholds[t], 1 - 1e-10);
-      int match = -1;
-      for (auto g = 0; g < num_ground_truth; ++g) {
-        // if this ground truth instance is already matched and not a
-        // crowd, it cannot be matched to another detection
-        if (ground_truth_matches[t * num_ground_truth + g] > 0 &&
-            !ground_truth_instances[ground_truth_sorted_indices[g]].is_crowd) {
-          continue;
-        }
-
-        // if detected instance matched to a regular ground truth
-        // instance, we can break on the first ground truth instance
-        // tagged as ignore (because they are sorted by the ignore tag)
-        if (match >= 0 && !ground_truth_ignores[match] &&
-            ground_truth_ignores[g]) {
-          break;
-        }
-
-        // if IOU overlap is the best so far, store the match appropriately
-        if (ious[d][ground_truth_sorted_indices[g]] >= best_iou) {
-          best_iou = ious[d][ground_truth_sorted_indices[g]];
-          match = g;
-        }
-      }
-      // if match was made, store id of match for both detection and
-      // ground truth
-      if (match >= 0) {
-        detection_ignores[t * num_detections + d] = ground_truth_ignores[match];
-        detection_matches[t * num_detections + d] =
-            ground_truth_instances[ground_truth_sorted_indices[match]].id;
-        ground_truth_matches[t * num_ground_truth + match] =
-            detection_instances[detection_sorted_indices[d]].id;
-      }
-
-      // set unmatched detections outside of area range to ignore
-      const InstanceAnnotation& detection =
-          detection_instances[detection_sorted_indices[d]];
-      detection_ignores[t * num_detections + d] =
-          detection_ignores[t * num_detections + d] ||
-          (detection_matches[t * num_detections + d] == 0 &&
-           (detection.area < area_range[0] || detection.area > area_range[1]));
-    }
-  }
-
-  // store detection score results
-  results->detection_scores.resize(detection_sorted_indices.size());
-  for (size_t d = 0; d < detection_sorted_indices.size(); ++d) {
-    results->detection_scores[d] =
-        detection_instances[detection_sorted_indices[d]].score;
-  }
-}
-
-std::vector<ImageEvaluation> EvaluateImages(
-    const std::vector<std::array<double, 2>>& area_ranges,
-    int max_detections,
-    const std::vector<double>& iou_thresholds,
-    const ImageCategoryInstances<std::vector<double>>& image_category_ious,
-    const ImageCategoryInstances<InstanceAnnotation>&
-        image_category_ground_truth_instances,
-    const ImageCategoryInstances<InstanceAnnotation>&
-        image_category_detection_instances) {
-  const int num_area_ranges = area_ranges.size();
-  const int num_images = image_category_ground_truth_instances.size();
-  const int num_categories =
-      image_category_ious.size() > 0 ? image_category_ious[0].size() : 0;
-  std::vector<uint64_t> detection_sorted_indices;
-  std::vector<uint64_t> ground_truth_sorted_indices;
-  std::vector<bool> ignores;
-  std::vector<ImageEvaluation> results_all(
-      num_images * num_area_ranges * num_categories);
-
-  // Store results for each image, category, and area range combination. Results
-  // for each IOU threshold are packed into the same ImageEvaluation object
-  for (auto i = 0; i < num_images; ++i) {
-    for (auto c = 0; c < num_categories; ++c) {
-      const std::vector<InstanceAnnotation>& ground_truth_instances =
-          image_category_ground_truth_instances[i][c];
-      const std::vector<InstanceAnnotation>& detection_instances =
-          image_category_detection_instances[i][c];
-
-      SortInstancesByDetectionScore(
-          detection_instances, &detection_sorted_indices);
-      if ((int)detection_sorted_indices.size() > max_detections) {
-        detection_sorted_indices.resize(max_detections);
-      }
-
-      for (size_t a = 0; a < area_ranges.size(); ++a) {
-        SortInstancesByIgnore(
-            area_ranges[a],
-            ground_truth_instances,
-            &ground_truth_sorted_indices,
-            &ignores);
-
-        MatchDetectionsToGroundTruth(
-            detection_instances,
-            detection_sorted_indices,
-            ground_truth_instances,
-            ground_truth_sorted_indices,
-            ignores,
-            image_category_ious[i][c],
-            iou_thresholds,
-            area_ranges[a],
-            &results_all
-                [c * num_area_ranges * num_images + a * num_images + i]);
-      }
-    }
-  }
-
-  return results_all;
-}
-
-// Convert a python list to a vector
-template <typename T>
-std::vector<T> list_to_vec(const py::list& l) {
-  std::vector<T> v(py::len(l));
-  for (int i = 0; i < (int)py::len(l); ++i) {
-    v[i] = l[i].cast<T>();
-  }
-  return v;
-}
-
-// Helper function to Accumulate()
-// Considers the evaluation results applicable to a particular category, area
-// range, and max_detections parameter setting, which begin at
-// evaluations[evaluation_index]. Extracts a sorted list of length n of all
-// applicable detection instances concatenated across all images in the dataset,
-// which are represented by the outputs evaluation_indices, detection_scores,
-// image_detection_indices, and detection_sorted_indices--all of which are
-// length n. evaluation_indices[i] stores the applicable index into
-// evaluations[] for instance i, which has detection score detection_score[i],
-// and is the image_detection_indices[i]'th of the list of detections
-// for the image containing i. detection_sorted_indices[] defines a sorted
-// permutation of the 3 other outputs
-int BuildSortedDetectionList(
-    const std::vector<ImageEvaluation>& evaluations,
-    const int64_t evaluation_index,
-    const int64_t num_images,
-    const int max_detections,
-    std::vector<uint64_t>* evaluation_indices,
-    std::vector<double>* detection_scores,
-    std::vector<uint64_t>* detection_sorted_indices,
-    std::vector<uint64_t>* image_detection_indices) {
-  assert(evaluations.size() >= evaluation_index + num_images);
-
-  // Extract a list of object instances of the applicable category, area
-  // range, and max detections requirements such that they can be sorted
-  image_detection_indices->clear();
-  evaluation_indices->clear();
-  detection_scores->clear();
-  image_detection_indices->reserve(num_images * max_detections);
-  evaluation_indices->reserve(num_images * max_detections);
-  detection_scores->reserve(num_images * max_detections);
-  int num_valid_ground_truth = 0;
-  for (auto i = 0; i < num_images; ++i) {
-    const ImageEvaluation& evaluation = evaluations[evaluation_index + i];
-
-    for (int d = 0;
-         d < (int)evaluation.detection_scores.size() && d < max_detections;
-         ++d) { // detected instances
-      evaluation_indices->push_back(evaluation_index + i);
-      image_detection_indices->push_back(d);
-      detection_scores->push_back(evaluation.detection_scores[d]);
-    }
-    for (auto ground_truth_ignore : evaluation.ground_truth_ignores) {
-      if (!ground_truth_ignore) {
-        ++num_valid_ground_truth;
-      }
-    }
-  }
-
-  // Sort detections by decreasing score, using stable sort to match
-  // python implementation
-  detection_sorted_indices->resize(detection_scores->size());
-  std::iota(
-      detection_sorted_indices->begin(), detection_sorted_indices->end(), 0);
-  std::stable_sort(
-      detection_sorted_indices->begin(),
-      detection_sorted_indices->end(),
-      [&detection_scores](size_t j1, size_t j2) {
-        return (*detection_scores)[j1] > (*detection_scores)[j2];
-      });
-
-  return num_valid_ground_truth;
-}
-
-// Helper function to Accumulate()
-// Compute a precision recall curve given a sorted list of detected instances
-// encoded in evaluations, evaluation_indices, detection_scores,
-// detection_sorted_indices, image_detection_indices (see
-// BuildSortedDetectionList()). Using vectors precisions and recalls
-// and temporary storage, output the results into precisions_out, recalls_out,
-// and scores_out, which are large buffers containing many precion/recall curves
-// for all possible parameter settings, with precisions_out_index and
-// recalls_out_index defining the applicable indices to store results.
-void ComputePrecisionRecallCurve(
-    const int64_t precisions_out_index,
-    const int64_t precisions_out_stride,
-    const int64_t recalls_out_index,
-    const std::vector<double>& recall_thresholds,
-    const int iou_threshold_index,
-    const int num_iou_thresholds,
-    const int num_valid_ground_truth,
-    const std::vector<ImageEvaluation>& evaluations,
-    const std::vector<uint64_t>& evaluation_indices,
-    const std::vector<double>& detection_scores,
-    const std::vector<uint64_t>& detection_sorted_indices,
-    const std::vector<uint64_t>& image_detection_indices,
-    std::vector<double>* precisions,
-    std::vector<double>* recalls,
-    std::vector<double>* precisions_out,
-    std::vector<double>* scores_out,
-    std::vector<double>* recalls_out) {
-  assert(recalls_out->size() > recalls_out_index);
-
-  // Compute precision/recall for each instance in the sorted list of detections
-  int64_t true_positives_sum = 0, false_positives_sum = 0;
-  precisions->clear();
-  recalls->clear();
-  precisions->reserve(detection_sorted_indices.size());
-  recalls->reserve(detection_sorted_indices.size());
-  assert(!evaluations.empty() || detection_sorted_indices.empty());
-  for (auto detection_sorted_index : detection_sorted_indices) {
-    const ImageEvaluation& evaluation =
-        evaluations[evaluation_indices[detection_sorted_index]];
-    const auto num_detections =
-        evaluation.detection_matches.size() / num_iou_thresholds;
-    const auto detection_index = iou_threshold_index * num_detections +
-        image_detection_indices[detection_sorted_index];
-    assert(evaluation.detection_matches.size() > detection_index);
-    assert(evaluation.detection_ignores.size() > detection_index);
-    const int64_t detection_match =
-        evaluation.detection_matches[detection_index];
-    const bool detection_ignores =
-        evaluation.detection_ignores[detection_index];
-    const auto true_positive = detection_match > 0 && !detection_ignores;
-    const auto false_positive = detection_match == 0 && !detection_ignores;
-    if (true_positive) {
-      ++true_positives_sum;
-    }
-    if (false_positive) {
-      ++false_positives_sum;
-    }
-
-    const double recall =
-        static_cast<double>(true_positives_sum) / num_valid_ground_truth;
-    recalls->push_back(recall);
-    const int64_t num_valid_detections =
-        true_positives_sum + false_positives_sum;
-    const double precision = num_valid_detections > 0
-        ? static_cast<double>(true_positives_sum) / num_valid_detections
-        : 0.0;
-    precisions->push_back(precision);
-  }
-
-  (*recalls_out)[recalls_out_index] = !recalls->empty() ? recalls->back() : 0;
-
-  for (int64_t i = static_cast<int64_t>(precisions->size()) - 1; i > 0; --i) {
-    if ((*precisions)[i] > (*precisions)[i - 1]) {
-      (*precisions)[i - 1] = (*precisions)[i];
-    }
-  }
-
-  // Sample the per instance precision/recall list at each recall threshold
-  for (size_t r = 0; r < recall_thresholds.size(); ++r) {
-    // first index in recalls >= recall_thresholds[r]
-    std::vector<double>::iterator low = std::lower_bound(
-        recalls->begin(), recalls->end(), recall_thresholds[r]);
-    size_t precisions_index = low - recalls->begin();
-
-    const auto results_ind = precisions_out_index + r * precisions_out_stride;
-    assert(results_ind < precisions_out->size());
-    assert(results_ind < scores_out->size());
-    if (precisions_index < precisions->size()) {
-      (*precisions_out)[results_ind] = (*precisions)[precisions_index];
-      (*scores_out)[results_ind] =
-          detection_scores[detection_sorted_indices[precisions_index]];
-    } else {
-      (*precisions_out)[results_ind] = 0;
-      (*scores_out)[results_ind] = 0;
-    }
-  }
-}
-
-py::dict Accumulate(
-    const py::object& params,
-    const std::vector<ImageEvaluation>& evaluations) {
-  const std::vector<double> recall_thresholds =
-      list_to_vec<double>(params.attr("recThrs"));
-  const std::vector<int> max_detections =
-      list_to_vec<int>(params.attr("maxDets"));
-  const int num_iou_thresholds = py::len(params.attr("iouThrs"));
-  const int num_recall_thresholds = py::len(params.attr("recThrs"));
-  const int num_categories = params.attr("useCats").cast<int>() == 1
-      ? py::len(params.attr("catIds"))
-      : 1;
-  const int num_area_ranges = py::len(params.attr("areaRng"));
-  const int num_max_detections = py::len(params.attr("maxDets"));
-  const int num_images = py::len(params.attr("imgIds"));
-
-  std::vector<double> precisions_out(
-      num_iou_thresholds * num_recall_thresholds * num_categories *
-          num_area_ranges * num_max_detections,
-      -1);
-  std::vector<double> recalls_out(
-      num_iou_thresholds * num_categories * num_area_ranges *
-          num_max_detections,
-      -1);
-  std::vector<double> scores_out(
-      num_iou_thresholds * num_recall_thresholds * num_categories *
-          num_area_ranges * num_max_detections,
-      -1);
-
-  // Consider the list of all detected instances in the entire dataset in one
-  // large list. evaluation_indices, detection_scores,
-  // image_detection_indices, and detection_sorted_indices all have the same
-  // length as this list, such that each entry corresponds to one detected
-  // instance
-  std::vector<uint64_t> evaluation_indices; // indices into evaluations[]
-  std::vector<double> detection_scores; // detection scores of each instance
-  std::vector<uint64_t> detection_sorted_indices; // sorted indices of all
-                                                  // instances in the dataset
-  std::vector<uint64_t>
-      image_detection_indices; // indices into the list of detected instances in
-                               // the same image as each instance
-  std::vector<double> precisions, recalls;
-
-  for (auto c = 0; c < num_categories; ++c) {
-    for (auto a = 0; a < num_area_ranges; ++a) {
-      for (auto m = 0; m < num_max_detections; ++m) {
-        // The COCO PythonAPI assumes evaluations[] (the return value of
-        // COCOeval::EvaluateImages() is one long list storing results for each
-        // combination of category, area range, and image id, with categories in
-        // the outermost loop and images in the innermost loop.
-        const int64_t evaluations_index =
-            c * num_area_ranges * num_images + a * num_images;
-        int num_valid_ground_truth = BuildSortedDetectionList(
-            evaluations,
-            evaluations_index,
-            num_images,
-            max_detections[m],
-            &evaluation_indices,
-            &detection_scores,
-            &detection_sorted_indices,
-            &image_detection_indices);
-
-        if (num_valid_ground_truth == 0) {
-          continue;
-        }
-
-        for (auto t = 0; t < num_iou_thresholds; ++t) {
-          // recalls_out is a flattened vectors representing a
-          // num_iou_thresholds X num_categories X num_area_ranges X
-          // num_max_detections matrix
-          const int64_t recalls_out_index =
-              t * num_categories * num_area_ranges * num_max_detections +
-              c * num_area_ranges * num_max_detections +
-              a * num_max_detections + m;
-
-          // precisions_out and scores_out are flattened vectors
-          // representing a num_iou_thresholds X num_recall_thresholds X
-          // num_categories X num_area_ranges X num_max_detections matrix
-          const int64_t precisions_out_stride =
-              num_categories * num_area_ranges * num_max_detections;
-          const int64_t precisions_out_index = t * num_recall_thresholds *
-                  num_categories * num_area_ranges * num_max_detections +
-              c * num_area_ranges * num_max_detections +
-              a * num_max_detections + m;
-
-          ComputePrecisionRecallCurve(
-              precisions_out_index,
-              precisions_out_stride,
-              recalls_out_index,
-              recall_thresholds,
-              t,
-              num_iou_thresholds,
-              num_valid_ground_truth,
-              evaluations,
-              evaluation_indices,
-              detection_scores,
-              detection_sorted_indices,
-              image_detection_indices,
-              &precisions,
-              &recalls,
-              &precisions_out,
-              &scores_out,
-              &recalls_out);
-        }
-      }
-    }
-  }
-
-  time_t rawtime;
-  struct tm local_time;
-  std::array<char, 200> buffer;
-  time(&rawtime);
-#ifdef _WIN32
-  localtime_s(&local_time, &rawtime);
-#else
-  localtime_r(&rawtime, &local_time);
-#endif
-  strftime(
-      buffer.data(), 200, "%Y-%m-%d %H:%num_max_detections:%S", &local_time);
-  return py::dict(
-      "params"_a = params,
-      "counts"_a = std::vector<int64_t>(
-          {num_iou_thresholds,
-           num_recall_thresholds,
-           num_categories,
-           num_area_ranges,
-           num_max_detections}),
-      "date"_a = buffer,
-      "precision"_a = precisions_out,
-      "recall"_a = recalls_out,
-      "scores"_a = scores_out);
-}
-
-PYBIND11_MODULE(cocoeval_ext, m) {
-  m.def("COCOevalAccumulate", &Accumulate, "Accumulate");
-  m.def("COCOevalEvaluateImages", &EvaluateImages, "EvaluateImages");
-  py::class_<InstanceAnnotation>(m, "InstanceAnnotation")
-      .def(py::init<uint64_t, double, double, bool, bool>());
-  py::class_<ImageEvaluation>(m, "ImageEvaluation")
-      .def(py::init<>());
-}
diff --git a/src/data/dataset/csrc/fast_cocoeval/cocoeval.h b/src/data/dataset/csrc/fast_cocoeval/cocoeval.h
deleted file mode 100644
index 5fcdd33..0000000
--- a/src/data/dataset/csrc/fast_cocoeval/cocoeval.h
+++ /dev/null
@@ -1,85 +0,0 @@
-
-// The code is based on
-// https://github.com/facebookresearch/detectron2/tree/main/detectron2/layers/csrc/cocoeval/
-// https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/metrics/fast_cocoeval/ext/
-
-#pragma once
-
-#include <pybind11/numpy.h>
-#include <pybind11/pybind11.h>
-#include <pybind11/stl.h>
-#include <pybind11/stl_bind.h>
-#include <vector>
-
-namespace py = pybind11;
-
-// Annotation data for a single object instance in an image
-struct InstanceAnnotation {
-  InstanceAnnotation(
-      uint64_t id,
-      double score,
-      double area,
-      bool is_crowd,
-      bool ignore)
-      : id{id}, score{score}, area{area}, is_crowd{is_crowd}, ignore{ignore} {}
-  uint64_t id;
-  double score = 0.;
-  double area = 0.;
-  bool is_crowd = false;
-  bool ignore = false;
-};
-
-// Stores intermediate results for evaluating detection results for a single
-// image that has D detected instances and G ground truth instances. This stores
-// matches between detected and ground truth instances
-struct ImageEvaluation {
-  // For each of the D detected instances, the id of the matched ground truth
-  // instance, or 0 if unmatched
-  std::vector<uint64_t> detection_matches;
-
-  // The detection score of each of the D detected instances
-  std::vector<double> detection_scores;
-
-  // Marks whether or not each of G instances was ignored from evaluation (e.g.,
-  // because it's outside area_range)
-  std::vector<bool> ground_truth_ignores;
-
-  // Marks whether or not each of D instances was ignored from evaluation (e.g.,
-  // because it's outside aRng)
-  std::vector<bool> detection_ignores;
-};
-
-template <class T>
-using ImageCategoryInstances = std::vector<std::vector<std::vector<T>>>;
-
-// C++ implementation of COCO API cocoeval.py::COCOeval.evaluateImg(). For each
-// combination of image, category, area range settings, and IOU thresholds to
-// evaluate, it matches detected instances to ground truth instances and stores
-// the results into a vector of ImageEvaluation results, which will be
-// interpreted by the COCOeval::Accumulate() function to produce precion-recall
-// curves. The parameters of nested vectors have the following semantics:
-//   image_category_ious[i][c][d][g] is the intersection over union of the d'th
-//     detected instance and g'th ground truth instance of
-//     category category_ids[c] in image image_ids[i]
-//   image_category_ground_truth_instances[i][c] is a vector of ground truth
-//     instances in image image_ids[i] of category category_ids[c]
-//   image_category_detection_instances[i][c] is a vector of detected
-//     instances in image image_ids[i] of category category_ids[c]
-std::vector<ImageEvaluation> EvaluateImages(
-    const std::vector<std::array<double, 2>>& area_ranges, // vector of 2-tuples
-    int max_detections,
-    const std::vector<double>& iou_thresholds,
-    const ImageCategoryInstances<std::vector<double>>& image_category_ious,
-    const ImageCategoryInstances<InstanceAnnotation>&
-        image_category_ground_truth_instances,
-    const ImageCategoryInstances<InstanceAnnotation>&
-        image_category_detection_instances);
-
-// C++ implementation of COCOeval.accumulate(), which generates precision
-// recall curves for each set of category, IOU threshold, detection area range,
-// and max number of detections parameters. It is assumed that the parameter
-// evaluations is the return value of the functon COCOeval::EvaluateImages(),
-// which was called with the same parameter settings params
-py::dict Accumulate(
-    const py::object& params,
-    const std::vector<ImageEvaluation>& evalutations);
diff --git a/src/data/dataset/csrc/fast_cocoeval/setup.py b/src/data/dataset/csrc/fast_cocoeval/setup.py
deleted file mode 100644
index f5663c8..0000000
--- a/src/data/dataset/csrc/fast_cocoeval/setup.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from pybind11.setup_helpers import Pybind11Extension, build_ext
-from setuptools import setup
-
-ext_modules = [Pybind11Extension("cocoeval_ext", ["cocoeval.cc"])]
-
-setup(
-    name="cocoeval_ext",
-    version="0.0.0",
-    ext_modules=ext_modules,
-    cmdclass={"build_ext": build_ext},
-)
diff --git a/src/data/dataset/csrc/readme.md b/src/data/dataset/csrc/readme.md
deleted file mode 100644
index 7f58aff..0000000
--- a/src/data/dataset/csrc/readme.md
+++ /dev/null
@@ -1,8 +0,0 @@
-
-1. fast cocoeval
-```
-# pip install pybind11
-
-cd fast_cocoeval/
-python setup.py install
-```
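With the extension sources and their readme gone, there is no longer a local pybind11 build step; the compiled evaluator ships with the faster-coco-eval wheel. A throwaway post-install check (the `__version__` attribute is an assumption, though most packages define it):

```python
# pip install "faster-coco-eval>=1.6.5"   # replaces the old `python setup.py install` step
import faster_coco_eval
print(faster_coco_eval.__version__)  # should satisfy the pin in requirements.txt
```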