Skip to content

Commit 14a8fbe

Browse files
committed
Dispatch style box_area and box_iou
1 parent 29c5147 commit 14a8fbe

File tree

4 files changed

+126
-50
lines changed

4 files changed

+126
-50
lines changed

docs/source/ops.rst

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,8 @@ These utility functions perform various operations on bounding boxes.
5050
:template: function.rst
5151

5252
box_area
53-
box_area_center
5453
box_convert
5554
box_iou
56-
box_iou_center
5755
clip_boxes_to_image
5856
complete_box_iou
5957
distance_box_iou

test/test_ops.py

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,9 +1418,9 @@ def test_bbox_convert_jit(self):
14181418
torch.testing.assert_close(scripted_cxcywh, box_cxcywh)
14191419

14201420

1421-
class TestBoxArea:
1421+
class TestBoxAreaXYXY:
14221422
def area_check(self, box, expected, atol=1e-4):
1423-
out = ops.box_area(box)
1423+
out = ops.box_area(box, fmt="xyxy")
14241424
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol)
14251425

14261426
@pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
@@ -1445,15 +1445,15 @@ def test_float16_box(self):
14451445

14461446
def test_box_area_jit(self):
14471447
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
1448-
expected = ops.box_area(box_tensor)
1448+
expected = ops.box_area(box_tensor, fmt="xyxy")
14491449
scripted_fn = torch.jit.script(ops.box_area)
14501450
scripted_area = scripted_fn(box_tensor)
14511451
torch.testing.assert_close(scripted_area, expected)
14521452

14531453

1454-
class TestBoxAreaCenter:
1454+
class TestBoxAreaCXCYWH:
14551455
def area_check(self, box, expected, atol=1e-4):
1456-
out = ops.box_area_center(box)
1456+
out = ops.box_area(box, fmt="cxcywh")
14571457
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol)
14581458

14591459
@pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
@@ -1480,9 +1480,9 @@ def test_float16_box(self):
14801480
def test_box_area_jit(self):
14811481
box_tensor = ops.box_convert(torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float),
14821482
in_fmt="xyxy", out_fmt="cxcywh")
1483-
expected = ops.box_area_center(box_tensor)
1484-
scripted_fn = torch.jit.script(ops.box_area_center)
1485-
scripted_area = scripted_fn(box_tensor)
1483+
expected = ops.box_area(box_tensor, fmt="cxcywh")
1484+
scripted_fn = torch.jit.script(ops.box_area)
1485+
scripted_area = scripted_fn(box_tensor, fmt="cxcywh")
14861486
torch.testing.assert_close(scripted_area, expected)
14871487

14881488

@@ -1509,22 +1509,22 @@ def gen_box(size, dtype=torch.float):
15091509
return torch.cat([xy1, xy2], axis=-1)
15101510

15111511

1512-
class TestIouBase:
1512+
class TestIouXYXYBase:
15131513
@staticmethod
15141514
def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
15151515
for dtype in dtypes:
15161516
actual_box1 = torch.tensor(actual_box1, dtype=dtype)
15171517
actual_box2 = torch.tensor(actual_box2, dtype=dtype)
15181518
expected_box = torch.tensor(expected)
1519-
out = target_fn(actual_box1, actual_box2)
1519+
out = target_fn(actual_box1, actual_box2, fmt="xyxy")
15201520
torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)
15211521

15221522
@staticmethod
15231523
def _run_jit_test(target_fn: Callable, actual_box: List):
15241524
box_tensor = torch.tensor(actual_box, dtype=torch.float)
1525-
expected = target_fn(box_tensor, box_tensor)
1525+
expected = target_fn(box_tensor, box_tensor, fmt="xyxy")
15261526
scripted_fn = torch.jit.script(target_fn)
1527-
scripted_out = scripted_fn(box_tensor, box_tensor)
1527+
scripted_out = scripted_fn(box_tensor, box_tensor, fmt="xyxy")
15281528
torch.testing.assert_close(scripted_out, expected)
15291529

15301530
@staticmethod
@@ -1534,19 +1534,19 @@ def _cartesian_product(boxes1, boxes2, target_fn: Callable):
15341534
result = torch.zeros((N, M))
15351535
for i in range(N):
15361536
for j in range(M):
1537-
result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
1537+
result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0), fmt="xyxy")
15381538
return result
15391539

15401540
@staticmethod
15411541
def _run_cartesian_test(target_fn: Callable):
15421542
boxes1 = gen_box(5)
15431543
boxes2 = gen_box(7)
1544-
a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
1545-
b = target_fn(boxes1, boxes2)
1544+
a = TestIouXYXYBase._cartesian_product(boxes1, boxes2, target_fn)
1545+
b = target_fn(boxes1, boxes2, fmt="xyxy")
15461546
torch.testing.assert_close(a, b)
15471547

15481548

1549-
class TestBoxIou(TestIouBase):
1549+
class TestBoxIouXYXY(TestIouXYXYBase):
15501550
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]]
15511551
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
15521552

@@ -1568,22 +1568,22 @@ def test_iou_cartesian(self):
15681568
self._run_cartesian_test(ops.box_iou)
15691569

15701570

1571-
class TestIouCenterBase:
1571+
class TestIouCXCYWHBase:
15721572
@staticmethod
15731573
def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
15741574
for dtype in dtypes:
15751575
actual_box1 = torch.tensor(actual_box1, dtype=dtype)
15761576
actual_box2 = torch.tensor(actual_box2, dtype=dtype)
15771577
expected_box = torch.tensor(expected)
1578-
out = target_fn(actual_box1, actual_box2)
1578+
out = target_fn(actual_box1, actual_box2, fmt="cxcywh")
15791579
torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)
15801580

15811581
@staticmethod
15821582
def _run_jit_test(target_fn: Callable, actual_box: List):
15831583
box_tensor = torch.tensor(actual_box, dtype=torch.float)
1584-
expected = target_fn(box_tensor, box_tensor)
1584+
expected = target_fn(box_tensor, box_tensor, fmt="cxcywh")
15851585
scripted_fn = torch.jit.script(target_fn)
1586-
scripted_out = scripted_fn(box_tensor, box_tensor)
1586+
scripted_out = scripted_fn(box_tensor, box_tensor, fmt="cxcywh")
15871587
torch.testing.assert_close(scripted_out, expected)
15881588

15891589
@staticmethod
@@ -1593,19 +1593,19 @@ def _cartesian_product(boxes1, boxes2, target_fn: Callable):
15931593
result = torch.zeros((N, M))
15941594
for i in range(N):
15951595
for j in range(M):
1596-
result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
1596+
result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0), fmt="cxcywh")
15971597
return result
15981598

15991599
@staticmethod
16001600
def _run_cartesian_test(target_fn: Callable):
16011601
boxes1 = ops.box_convert(gen_box(5), in_fmt="xyxy", out_fmt="cxcywh")
16021602
boxes2 = ops.box_convert(gen_box(7), in_fmt="xyxy", out_fmt="cxcywh")
1603-
a = TestIouCenterBase._cartesian_product(boxes1, boxes2, target_fn)
1604-
b = target_fn(boxes1, boxes2)
1603+
a = TestIouCXCYWHBase._cartesian_product(boxes1, boxes2, target_fn)
1604+
b = target_fn(boxes1, boxes2, fmt="cxcywh")
16051605
torch.testing.assert_close(a, b)
16061606

16071607

1608-
class TestBoxIouCenter(TestIouBase):
1608+
class TestBoxIouCXCYWH(TestIouCXCYWHBase):
16091609
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.04, 0.16, 0.0]]
16101610
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
16111611

@@ -1618,13 +1618,49 @@ class TestBoxIouCenter(TestIouBase):
16181618
],
16191619
)
16201620
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
1621-
self._run_test(ops.box_iou_center, actual_box1, actual_box2, dtypes, atol, expected)
1621+
self._run_test(ops.box_iou, actual_box1, actual_box2, dtypes, atol, expected)
16221622

16231623
def test_iou_jit(self):
1624-
self._run_jit_test(ops.box_iou_center, INT_BOXES_CXCYWH)
1624+
self._run_jit_test(ops.box_iou, INT_BOXES_CXCYWH)
16251625

16261626
def test_iou_cartesian(self):
1627-
self._run_cartesian_test(ops.box_iou_center)
1627+
self._run_cartesian_test(ops.box_iou)
1628+
1629+
class TestIouBase:
    """Shared helpers for testing IoU-style ops that are called as ``target_fn(boxes1, boxes2)``."""

    @staticmethod
    def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
        """Check ``target_fn`` against ``expected`` once for every dtype in ``dtypes``.

        Args:
            target_fn (Callable): op under test, called with two box tensors.
            actual_box1: data for the first set of boxes, passed to ``torch.tensor``.
            actual_box2: data for the second set of boxes, passed to ``torch.tensor``.
            dtypes: iterable of dtypes the inputs are built with.
            atol: absolute tolerance of the comparison (``rtol`` is fixed to 0).
            expected: data for the expected result, passed to ``torch.tensor``.
        """
        # The expected tensor does not depend on the dtype, so build it once.
        expected_box = torch.tensor(expected)
        for dtype in dtypes:
            # Build fresh tensors from the *original* inputs on every iteration.
            # Rebinding the arguments would make later dtypes start from the
            # tensor already cast to the previous dtype (lossy values).
            box1 = torch.tensor(actual_box1, dtype=dtype)
            box2 = torch.tensor(actual_box2, dtype=dtype)
            out = target_fn(box1, box2)
            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

    @staticmethod
    def _run_jit_test(target_fn: Callable, actual_box: List):
        """Check that the scripted ``target_fn`` matches the eager one on ``actual_box``."""
        box_tensor = torch.tensor(actual_box, dtype=torch.float)
        expected = target_fn(box_tensor, box_tensor)
        scripted_fn = torch.jit.script(target_fn)
        scripted_out = scripted_fn(box_tensor, box_tensor)
        torch.testing.assert_close(scripted_out, expected)

    @staticmethod
    def _cartesian_product(boxes1, boxes2, target_fn: Callable):
        """Build the N x M result matrix by calling ``target_fn`` on one pair of boxes at a time."""
        N = boxes1.size(0)
        M = boxes2.size(0)
        result = torch.zeros((N, M))
        for i in range(N):
            for j in range(M):
                # unsqueeze(0) restores the leading box dimension dropped by indexing.
                result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
        return result

    @staticmethod
    def _run_cartesian_test(target_fn: Callable):
        """Check that the batched call agrees with the pair-by-pair reference result."""
        boxes1 = gen_box(5)
        boxes2 = gen_box(7)
        a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
        b = target_fn(boxes1, boxes2)
        torch.testing.assert_close(a, b)
16281664

16291665

16301666
class TestGeneralizedBoxIou(TestIouBase):

torchvision/ops/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22
from .boxes import (
33
batched_nms,
44
box_area,
5-
box_area_center,
65
box_convert,
76
box_iou,
8-
box_iou_center,
97
clip_boxes_to_image,
108
complete_box_iou,
119
distance_box_iou,
@@ -42,9 +40,7 @@
4240
"clip_boxes_to_image",
4341
"box_convert",
4442
"box_area",
45-
"box_area_center",
4643
"box_iou",
47-
"box_iou_center",
4844
"generalized_box_iou",
4945
"distance_box_iou",
5046
"complete_box_iou",

torchvision/ops/boxes.py

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,30 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
272272
return boxes
273273

274274

275-
def box_area(boxes: Tensor) -> Tensor:
275+
def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor:
    """
    Computes the area of a set of bounding boxes given in the format ``fmt``.

    Args:
        boxes (Tensor[N, 4]): boxes for which the area will be computed. Their
            layout must match ``fmt``: (x1, y1, x2, y2) with ``0 <= x1 < x2`` and
            ``0 <= y1 < y2`` for ``"xyxy"``, or (cx, cy, w, h) for ``"cxcywh"``.
        fmt (str): Format of given boxes. Supported formats are ['xyxy', 'cxcywh']. Default: "xyxy"

    Returns:
        Tensor[N]: the area for each box

    Raises:
        ValueError: if ``fmt`` is not one of the supported formats.
    """
    # Dispatch on the declared layout; each helper does the actual computation.
    if fmt == "xyxy":
        area = box_area_xyxy(boxes=boxes)
    elif fmt == "cxcywh":
        area = box_area_cxcywh(boxes=boxes)
    else:
        raise ValueError(f"Unsupported Box Area Calculation for given fmt {fmt}")

    return area
296+
297+
298+
def box_area_xyxy(boxes: Tensor) -> Tensor:
276299
"""
277300
Computes the area of a set of bounding boxes, which are specified by their
278301
(x1, y1, x2, y2) coordinates.
@@ -286,12 +309,12 @@ def box_area(boxes: Tensor) -> Tensor:
286309
Tensor[N]: the area for each box
287310
"""
288311
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
289-
_log_api_usage_once(box_area)
312+
_log_api_usage_once(box_area_xyxy)
290313
boxes = _upcast(boxes)
291314
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
292315

293316

294-
def box_area_center(boxes: Tensor) -> Tensor:
317+
def box_area_cxcywh(boxes: Tensor) -> Tensor:
295318
"""
296319
Computes the area of a set of bounding boxes, which are specified by their
297320
(cx, cy, w, h) coordinates.
@@ -305,16 +328,39 @@ def box_area_center(boxes: Tensor) -> Tensor:
305328
Tensor[N]: the area for each box
306329
"""
307330
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
308-
_log_api_usage_once(box_area_center)
331+
_log_api_usage_once(box_area_cxcywh)
309332
boxes = _upcast(boxes)
310333
return boxes[:, 2] * boxes[:, 3]
311334

312335

336+
def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor:
    """
    Return the pairwise intersection-over-union (Jaccard index) of two sets of
    boxes that are both given in the format ``fmt``.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes
        fmt (str): Format of given boxes. Supported formats are ['xyxy', 'cxcywh']. Default: "xyxy"

    Returns:
        Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2

    Raises:
        ValueError: if ``fmt`` is not one of the supported formats.
    """
    # Dispatch to the implementation for the requested box layout.
    if fmt == "xyxy":
        return box_iou_xyxy(boxes1=boxes1, boxes2=boxes2)
    elif fmt == "cxcywh":
        return box_iou_cxcywh(boxes1=boxes1, boxes2=boxes2)
    else:
        raise ValueError(f"Unsupported Box IoU Calculation for given fmt {fmt}")
357+
358+
313359
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
314360
# with slight modifications
315-
def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
316-
area1 = box_area(boxes1)
317-
area2 = box_area(boxes2)
361+
def _box_inter_union_xyxy(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]:
362+
area1 = box_area(boxes1, fmt="xyxy")
363+
area2 = box_area(boxes2, fmt="xyxy")
318364

319365
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
320366
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
@@ -327,7 +373,7 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
327373
return inter, union
328374

329375

330-
def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
376+
def box_iou_xyxy(boxes1: Tensor, boxes2: Tensor) -> Tensor:
331377
"""
332378
Return intersection-over-union (Jaccard index) between two sets of boxes.
333379
@@ -342,15 +388,15 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
342388
Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
343389
"""
344390
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
345-
_log_api_usage_once(box_iou)
346-
inter, union = _box_inter_union(boxes1, boxes2)
391+
_log_api_usage_once(box_iou_xyxy)
392+
inter, union = _box_inter_union_xyxy(boxes1, boxes2)
347393
iou = inter / union
348394
return iou
349395

350396

351-
def _box_inter_union_center(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
352-
area1 = box_area_center(boxes1)
353-
area2 = box_area_center(boxes2)
397+
def _box_inter_union_cxcywh(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Tensor]:
398+
area1 = box_area(boxes1, fmt="cxcywh")
399+
area2 = box_area(boxes2, fmt="cxcywh")
354400

355401
lt = torch.max(boxes1[:, None, :2] - boxes1[:, None, 2:] / 2, boxes2[:, :2] - boxes2[:, 2:] / 2) # [N,M,2]
356402
rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2]
@@ -363,7 +409,7 @@ def _box_inter_union_center(boxes1: Tensor, boxes2: Tensor) -> Tuple[Tensor, Ten
363409
return inter, union
364410

365411

366-
def box_iou_center(boxes1: Tensor, boxes2: Tensor) -> Tensor:
412+
def box_iou_cxcywh(boxes1: Tensor, boxes2: Tensor) -> Tensor:
367413
"""
368414
Return intersection-over-union (Jaccard index) between two sets of boxes.
369415
@@ -378,8 +424,8 @@ def box_iou_center(boxes1: Tensor, boxes2: Tensor) -> Tensor:
378424
Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
379425
"""
380426
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
381-
_log_api_usage_once(box_iou_center)
382-
inter, union = _box_inter_union_center(boxes1, boxes2)
427+
_log_api_usage_once(box_iou_cxcywh)
428+
inter, union = _box_inter_union_cxcywh(boxes1, boxes2)
383429
iou = inter / union
384430
return iou
385431

@@ -403,7 +449,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
403449
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
404450
_log_api_usage_once(generalized_box_iou)
405451

406-
inter, union = _box_inter_union(boxes1, boxes2)
452+
inter, union = _box_inter_union_xyxy(boxes1, boxes2)
407453
iou = inter / union
408454

409455
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])

0 commit comments

Comments
 (0)