diff --git a/test/test_ops.py b/test/test_ops.py index 9cb0cddedf7..97875de4594 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1445,9 +1445,9 @@ def test_bbox_convert_jit(self): torch.testing.assert_close(scripted_cxcywh, box_cxcywh) -class TestBoxArea: +class TestBoxAreaXYXY: def area_check(self, box, expected, atol=1e-4): - out = ops.box_area(box) + out = ops.box_area(box, fmt="xyxy") torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol) @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64]) @@ -1472,12 +1472,47 @@ def test_float16_box(self): def test_box_area_jit(self): box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float) - expected = ops.box_area(box_tensor) + expected = ops.box_area(box_tensor, fmt="xyxy") scripted_fn = torch.jit.script(ops.box_area) scripted_area = scripted_fn(box_tensor) torch.testing.assert_close(scripted_area, expected) +class TestBoxAreaCXCYWH: + def area_check(self, box, expected, atol=1e-4): + out = ops.box_area(box, fmt="cxcywh") + torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol) + + @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64]) + def test_int_boxes(self, dtype): + box_tensor = ops.box_convert(torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype), + in_fmt="xyxy", out_fmt="cxcywh") + expected = torch.tensor([10000, 0], dtype=torch.int32) + self.area_check(box_tensor, expected) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) + def test_float_boxes(self, dtype): + box_tensor = ops.box_convert(torch.tensor(FLOAT_BOXES, dtype=dtype), in_fmt="xyxy", out_fmt="cxcywh") + expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype) + self.area_check(box_tensor, expected) + + def test_float16_box(self): + box_tensor = ops.box_convert(torch.tensor( + [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16 + ), in_fmt="xyxy", out_fmt="cxcywh") + + expected = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16) + self.area_check(box_tensor, expected, atol=0.01) + + def test_box_area_jit(self): + box_tensor = ops.box_convert(torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float), + in_fmt="xyxy", out_fmt="cxcywh") + expected = ops.box_area(box_tensor, fmt="cxcywh") + scripted_fn = torch.jit.script(ops.box_area) + scripted_area = scripted_fn(box_tensor, fmt="cxcywh") + torch.testing.assert_close(scripted_area, expected) + + INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]] INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] FLOAT_BOXES = [ @@ -1486,6 +1521,14 @@ def test_box_area_jit(self): [279.2440, 197.9812, 1189.4746, 849.2019], ] +INT_BOXES_CXCYWH = [[50, 50, 100, 100], [25, 25, 50, 50], [250, 250, 100, 100], [10, 10, 20, 20]] +INT_BOXES2_CXCYWH = [[50, 50, 100, 100], [25, 25, 50, 50], [250, 250, 100, 100]] +FLOAT_BOXES_CXCYWH = [ + [739.4324, 518.5154, 908.1572, 665.8793], + [738.8228, 519.9021, 907.3512, 662.3295], + [734.3593, 523.5916, 910.2306, 651.2207] +] + def gen_box(size, dtype=torch.float): xy1 = torch.rand((size, 2), dtype=dtype) @@ -1493,22 +1536,22 @@ def gen_box(size, dtype=torch.float): return torch.cat([xy1, xy2], axis=-1) -class TestIouBase: +class TestIouXYXYBase: @staticmethod def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected): for dtype in dtypes: actual_box1 = torch.tensor(actual_box1, dtype=dtype) actual_box2 = torch.tensor(actual_box2, dtype=dtype) expected_box = torch.tensor(expected) - out = target_fn(actual_box1, actual_box2) + out = target_fn(actual_box1, actual_box2, fmt="xyxy") torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) @staticmethod - def _run_jit_test(target_fn: Callable, actual_box: list): + def _run_jit_test(target_fn: Callable, actual_box: List): box_tensor = torch.tensor(actual_box, dtype=torch.float) - expected = target_fn(box_tensor, box_tensor) + expected = target_fn(box_tensor, box_tensor, fmt="xyxy") scripted_fn = torch.jit.script(target_fn) - scripted_out = scripted_fn(box_tensor, box_tensor) + scripted_out = scripted_fn(box_tensor, box_tensor, fmt="xyxy") torch.testing.assert_close(scripted_out, expected) @staticmethod @@ -1518,19 +1561,19 @@ def _cartesian_product(boxes1, boxes2, target_fn: Callable): result = torch.zeros((N, M)) for i in range(N): for j in range(M): - result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0)) + result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0), fmt="xyxy") return result @staticmethod def _run_cartesian_test(target_fn: Callable): boxes1 = gen_box(5) boxes2 = gen_box(7) - a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn) - b = target_fn(boxes1, boxes2) + a = TestIouXYXYBase._cartesian_product(boxes1, boxes2, target_fn) + b = target_fn(boxes1, boxes2, fmt="xyxy") torch.testing.assert_close(a, b) -class TestBoxIou(TestIouBase): +class TestBoxIouXYXY(TestIouXYXYBase): int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]] float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @@ -1552,6 +1595,101 @@ def test_iou_cartesian(self): self._run_cartesian_test(ops.box_iou) +class TestIouCXCYWHBase: + @staticmethod + def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected): + for dtype in dtypes: + actual_box1 = torch.tensor(actual_box1, dtype=dtype) + actual_box2 = torch.tensor(actual_box2, dtype=dtype) + expected_box = torch.tensor(expected) + out = target_fn(actual_box1, actual_box2, fmt="cxcywh") + torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) + + @staticmethod + def _run_jit_test(target_fn: Callable, actual_box: List): + box_tensor = torch.tensor(actual_box, dtype=torch.float) + expected = target_fn(box_tensor, box_tensor, fmt="cxcywh") + scripted_fn = torch.jit.script(target_fn) + scripted_out = scripted_fn(box_tensor, box_tensor, fmt="cxcywh") + torch.testing.assert_close(scripted_out, expected) + + @staticmethod + def _cartesian_product(boxes1, boxes2, target_fn: Callable): + N = boxes1.size(0) + M = boxes2.size(0) + result = torch.zeros((N, M)) + for i in range(N): + for j in range(M): + result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0), fmt="cxcywh") + return result + + @staticmethod + def _run_cartesian_test(target_fn: Callable): + boxes1 = ops.box_convert(gen_box(5), in_fmt="xyxy", out_fmt="cxcywh") + boxes2 = ops.box_convert(gen_box(7), in_fmt="xyxy", out_fmt="cxcywh") + a = TestIouCXCYWHBase._cartesian_product(boxes1, boxes2, target_fn) + b = target_fn(boxes1, boxes2, fmt="cxcywh") + torch.testing.assert_close(a, b) + + +class TestBoxIouCXCYWH(TestIouCXCYWHBase): + int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.04, 0.16, 0.0]] + float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "actual_box1, actual_box2, dtypes, atol, expected", + [ + pytest.param(INT_BOXES_CXCYWH, INT_BOXES2_CXCYWH, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float16], 0.002, float_expected), + pytest.param(FLOAT_BOXES_CXCYWH, FLOAT_BOXES_CXCYWH, [torch.float32, torch.float64], 1e-3, float_expected), + ], + ) + def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected): + self._run_test(ops.box_iou, actual_box1, actual_box2, dtypes, atol, expected) + + def test_iou_jit(self): + self._run_jit_test(ops.box_iou, INT_BOXES_CXCYWH) + + def test_iou_cartesian(self): + self._run_cartesian_test(ops.box_iou) + +class TestIouBase: + @staticmethod + def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected): + for dtype in dtypes: + actual_box1 = torch.tensor(actual_box1, dtype=dtype) + actual_box2 = torch.tensor(actual_box2, dtype=dtype) + expected_box = torch.tensor(expected) + out = target_fn(actual_box1, actual_box2) + torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) + + @staticmethod + def _run_jit_test(target_fn: Callable, actual_box: list): + box_tensor = torch.tensor(actual_box, dtype=torch.float) + expected = target_fn(box_tensor, box_tensor) + scripted_fn = torch.jit.script(target_fn) + scripted_out = scripted_fn(box_tensor, box_tensor) + torch.testing.assert_close(scripted_out, expected) + + @staticmethod + def _cartesian_product(boxes1, boxes2, target_fn: Callable): + N = boxes1.size(0) + M = boxes2.size(0) + result = torch.zeros((N, M)) + for i in range(N): + for j in range(M): + result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0)) + return result + + @staticmethod + def _run_cartesian_test(target_fn: Callable): + boxes1 = gen_box(5) + boxes2 = gen_box(7) + a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn) + b = target_fn(boxes1, boxes2) + torch.testing.assert_close(a, b) + + class TestGeneralizedBoxIou(TestIouBase): int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]] float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 7fb8192e1cd..7a3b43b61cb 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -270,7 +270,30 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor: return boxes -def box_area(boxes: Tensor) -> Tensor: +def box_area(boxes: Tensor, fmt: str = "xyxy") -> Tensor: + """ + Computes the area of a set of bounding boxes from a given format. + + Args: + boxes (Tensor[N, 4]): boxes for which the area will be computed. They + are expected to be in (x1, y1, x2, y2) format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + fmt (str): Format of given boxes. Supported formats are ['xyxy', 'cxcywh']. Default: "xyxy" + + Returns: + Tensor[N]: the area for each box + """ + if fmt == "xyxy": + boxes = box_area_xyxy(boxes=boxes) + elif fmt == "cxcywh": + boxes = box_area_cxcywh(boxes=boxes) + else: + raise ValueError(f"Unsupported Box Area Calculation for given fmt {fmt}") + + return boxes + + +def box_area_xyxy(boxes: Tensor) -> Tensor: """ Computes the area of a set of bounding boxes, which are specified by their (x1, y1, x2, y2) coordinates. @@ -284,16 +307,58 @@ def box_area(boxes: Tensor) -> Tensor: Tensor[N]: the area for each box """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(box_area) + _log_api_usage_once(box_area_xyxy) boxes = _upcast(boxes) return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) +def box_area_cxcywh(boxes: Tensor) -> Tensor: + """ + Computes the area of a set of bounding boxes, which are specified by their + (cx, cy, w, h) coordinates. + + Args: + boxes (Tensor[N, 4]): boxes for which the area will be computed. They + are expected to be in (cx, cy, w, h) format with + ``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``. + + Returns: + Tensor[N]: the area for each box + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(box_area_cxcywh) + boxes = _upcast(boxes) + return boxes[:, 2] * boxes[:, 3] + + +def box_iou(boxes1: Tensor, boxes2: Tensor, fmt: str = "xyxy") -> Tensor: + """ + Return intersection-over-union (Jaccard index) between two sets of boxes from a given format. + + Args: + boxes1 (Tensor[N, 4]): first set of boxes + boxes2 (Tensor[M, 4]): second set of boxes + fmt (str): Format of given boxes. Supported formats are ['xyxy', 'cxcywh']. Default: "xyxy" + + + Returns: + Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 + """ + if fmt == "xyxy": + iou = box_iou_xyxy(boxes1=boxes1, boxes2=boxes2) + elif fmt == "cxcywh": + iou = box_iou_cxcywh(boxes1=boxes1, boxes2=boxes2) + else: + raise ValueError(f"Unsupported Box IoU Calculation for given fmt {fmt}") + + return iou + + # implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py # with slight modifications -def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]: - area1 = box_area(boxes1) - area2 = box_area(boxes2) +def _box_inter_union_xyxy(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]: + area1 = box_area(boxes1, fmt="xyxy") + area2 = box_area(boxes2, fmt="xyxy") lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] @@ -306,7 +371,7 @@ def _box_inter_union(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]: return inter, union -def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: +def box_iou_xyxy(boxes1: Tensor, boxes2: Tensor) -> Tensor: """ Return intersection-over-union (Jaccard index) between two sets of boxes. @@ -321,8 +386,44 @@ def box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - _log_api_usage_once(box_iou) - inter, union = _box_inter_union(boxes1, boxes2) + _log_api_usage_once(box_iou_xyxy) + inter, union = _box_inter_union_xyxy(boxes1, boxes2) + iou = inter / union + return iou + + +def _box_inter_union_cxcywh(boxes1: Tensor, boxes2: Tensor) -> tuple[Tensor, Tensor]: + area1 = box_area(boxes1, fmt="cxcywh") + area2 = box_area(boxes2, fmt="cxcywh") + + lt = torch.max(boxes1[:, None, :2] - boxes1[:, None, 2:] / 2, boxes2[:, :2] - boxes2[:, 2:] / 2) # [N,M,2] + rb = torch.min(boxes1[:, None, :2] + boxes1[:, None, 2:] / 2, boxes2[:, :2] + boxes2[:, 2:] / 2) # [N,M,2] + + wh = _upcast(rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + return inter, union + + +def box_iou_cxcywh(boxes1: Tensor, boxes2: Tensor) -> Tensor: + """ + Return intersection-over-union (Jaccard index) between two sets of boxes. + + Both sets of boxes are expected to be in ``(cx, cy, w, h)`` format with + ``0 <= cx``, ``0 <= cy``, ``0 <= w`` and ``0 <= h``. + + Args: + boxes1 (Tensor[N, 4]): first set of boxes + boxes2 (Tensor[M, 4]): second set of boxes + + Returns: + Tensor[N, M]: the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2 + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(box_iou_cxcywh) + inter, union = _box_inter_union_cxcywh(boxes1, boxes2) iou = inter / union return iou @@ -346,7 +447,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(generalized_box_iou) - inter, union = _box_inter_union(boxes1, boxes2) + inter, union = _box_inter_union_xyxy(boxes1, boxes2) iou = inter / union lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])