Skip to content

Commit

Permalink
Merge pull request #28 from MiXaiLL76/faster-coco-eval
Browse files Browse the repository at this point in the history
feat: Replace pycocotools with faster-coco-eval.
  • Loading branch information
Peterande authored Nov 4, 2024
2 parents a8a15a9 + d5d977e commit 7240fb1
Show file tree
Hide file tree
Showing 9 changed files with 31 additions and 860 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
torch>=2.0.1
torchvision>=0.15.2
pycocotools
faster-coco-eval>=1.6.5
PyYAML
tensorboard
scipy
Expand Down
6 changes: 4 additions & 2 deletions src/data/dataset/coco_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
import torchvision

from PIL import Image
from pycocotools import mask as coco_mask

import faster_coco_eval
import faster_coco_eval.core.mask as coco_mask
from ._dataset import DetDataset
from .._misc import convert_to_tv_tensor
from ...core import register

torchvision.disable_beta_transforms_warning()
faster_coco_eval.init_as_pycocotools()

__all__ = ['CocoDetection']

Expand Down
124 changes: 24 additions & 100 deletions src/data/dataset/coco_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,10 @@
import numpy as np
import torch

from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
import pycocotools.mask as mask_util

from ...misc import dist_utils
from faster_coco_eval import COCO, COCOeval_faster
import faster_coco_eval.core.mask as mask_util
from ...core import register

from ...misc import dist_utils
__all__ = ['CocoEvaluator',]


Expand All @@ -26,20 +23,20 @@ class CocoEvaluator(object):
def __init__(self, coco_gt, iou_types):
assert isinstance(iou_types, (list, tuple))
coco_gt = copy.deepcopy(coco_gt)
self.coco_gt = coco_gt
self.coco_gt : COCO = coco_gt
self.iou_types = iou_types

self.coco_eval = {}
for iou_type in iou_types:
self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
self.coco_eval[iou_type] = COCOeval_faster(coco_gt, iouType=iou_type, print_function=print, separate_eval=True)

self.img_ids = []
self.eval_imgs = {k: [] for k in iou_types}

def cleanup(self):
self.coco_eval = {}
for iou_type in self.iou_types:
self.coco_eval[iou_type] = COCOeval(self.coco_gt, iouType=iou_type)
self.coco_eval[iou_type] = COCOeval_faster(self.coco_gt, iouType=iou_type, print_function=print, separate_eval=True)
self.img_ids = []
self.eval_imgs = {k: [] for k in self.iou_types}

Expand All @@ -50,23 +47,26 @@ def update(self, predictions):

for iou_type in self.iou_types:
results = self.prepare(predictions, iou_type)
coco_eval = self.coco_eval[iou_type]

# suppress pycocotools prints
with open(os.devnull, 'w') as devnull:
with contextlib.redirect_stdout(devnull):
coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
coco_eval = self.coco_eval[iou_type]
coco_dt = self.coco_gt.loadRes(results) if results else COCO()
coco_eval.cocoDt = coco_dt
coco_eval.params.imgIds = list(img_ids)
coco_eval.evaluate()

coco_eval.cocoDt = coco_dt
coco_eval.params.imgIds = list(img_ids)
img_ids, eval_imgs = evaluate(coco_eval)

self.eval_imgs[iou_type].append(eval_imgs)
self.eval_imgs[iou_type].append(np.array(coco_eval._evalImgs_cpp).reshape(len(coco_eval.params.catIds), len(coco_eval.params.areaRng), len(coco_eval.params.imgIds)))

def synchronize_between_processes(self):
for iou_type in self.iou_types:
self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
img_ids, eval_imgs = merge(self.img_ids, self.eval_imgs[iou_type])

coco_eval = self.coco_eval[iou_type]
coco_eval.params.imgIds = img_ids
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
coco_eval._evalImgs_cpp = eval_imgs

def accumulate(self):
for coco_eval in self.coco_eval.values():
Expand Down Expand Up @@ -177,7 +177,6 @@ def convert_to_xywh(boxes):
xmin, ymin, xmax, ymax = boxes.unbind(1)
return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)


def merge(img_ids, eval_imgs):
all_img_ids = dist_utils.all_gather(img_ids)
all_eval_imgs = dist_utils.all_gather(eval_imgs)
Expand All @@ -188,89 +187,14 @@ def merge(img_ids, eval_imgs):

merged_eval_imgs = []
for p in all_eval_imgs:
merged_eval_imgs.append(p)
merged_eval_imgs.extend(p)


merged_img_ids = np.array(merged_img_ids)
merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
merged_eval_imgs = np.concatenate(merged_eval_imgs, axis=2).ravel()
# merged_eval_imgs = np.array(merged_eval_imgs).T.ravel()

# keep only unique (and in sorted order) images
merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
merged_eval_imgs = merged_eval_imgs[..., idx]

return merged_img_ids, merged_eval_imgs


def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
img_ids, eval_imgs = merge(img_ids, eval_imgs)
img_ids = list(img_ids)
eval_imgs = list(eval_imgs.flatten())

coco_eval.evalImgs = eval_imgs
coco_eval.params.imgIds = img_ids
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)


#################################################################
# From pycocotools, just removed the prints and fixed
# a Python3 bug about unicode not defined
#################################################################


# import io
# from contextlib import redirect_stdout
# def evaluate(imgs):
# with redirect_stdout(io.StringIO()):
# imgs.evaluate()
# return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))


def evaluate(self):
"""
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
:return: None
"""
# tic = time.time()
# print('Running per image evaluation...')
p = self.params
# add backward compatibility if useSegm is specified in params
if p.useSegm is not None:
p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
# print('Evaluate annotation type *{}*'.format(p.iouType))
p.imgIds = list(np.unique(p.imgIds))
if p.useCats:
p.catIds = list(np.unique(p.catIds))
p.maxDets = sorted(p.maxDets)
self.params = p

self._prepare()
# loop through images, area range, max detection number
catIds = p.catIds if p.useCats else [-1]

if p.iouType == 'segm' or p.iouType == 'bbox':
computeIoU = self.computeIoU
elif p.iouType == 'keypoints':
computeIoU = self.computeOks
self.ious = {
(imgId, catId): computeIoU(imgId, catId)
for imgId in p.imgIds
for catId in catIds}

evaluateImg = self.evaluateImg
maxDet = p.maxDets[-1]
evalImgs = [
evaluateImg(imgId, catId, areaRng, maxDet)
for catId in catIds
for areaRng in p.areaRng
for imgId in p.imgIds
]
# this is NOT in the pycocotools code, but could be done outside
evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
self._paramsEval = copy.deepcopy(self.params)
# toc = time.time()
# print('DONE (t={:0.2f}s).'.format(toc-tic))
return p.imgIds, evalImgs

#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################

return merged_img_ids.tolist(), merged_eval_imgs.tolist()
139 changes: 0 additions & 139 deletions src/data/dataset/coco_fasteval.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/data/dataset/coco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import torch.utils.data
import torchvision
import torchvision.transforms.functional as TVF
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO
import faster_coco_eval.core.mask as coco_mask
from faster_coco_eval import COCO


def convert_coco_poly_to_mask(segmentations, height, width):
Expand Down
Loading

0 comments on commit 7240fb1

Please sign in to comment.