From e2c45c4c486ffbe28ea7e3db922448fae7193ede Mon Sep 17 00:00:00 2001 From: hachoj Date: Sun, 8 Jun 2025 17:49:25 -0400 Subject: [PATCH 1/3] fixed centering of anchors on grid cells Signed-off-by: hachoj --- monai/apps/detection/utils/anchor_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/apps/detection/utils/anchor_utils.py b/monai/apps/detection/utils/anchor_utils.py index cbde3ebae9..c3cbe5f14c 100644 --- a/monai/apps/detection/utils/anchor_utils.py +++ b/monai/apps/detection/utils/anchor_utils.py @@ -253,7 +253,7 @@ def grid_anchors(self, grid_sizes: list[list[int]], strides: list[list[Tensor]]) # compute anchor centers regarding to the image. # shifts_centers is [x_center, y_center] or [x_center, y_center, z_center] shifts_centers = [ - torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis] + torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis] + stride[axis] // 2 for axis in range(self.spatial_dims) ] From 34d73e5e4db2262b8d0081e99b3609c0146f9c78 Mon Sep 17 00:00:00 2001 From: hachoj Date: Sun, 8 Jun 2025 17:52:48 -0400 Subject: [PATCH 2/3] updated unit tests to match new functionality Signed-off-by: hachoj --- tests/apps/detection/utils/test_anchor_box.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/apps/detection/utils/test_anchor_box.py b/tests/apps/detection/utils/test_anchor_box.py index 7543c84ed9..13abac6447 100644 --- a/tests/apps/detection/utils/test_anchor_box.py +++ b/tests/apps/detection/utils/test_anchor_box.py @@ -57,15 +57,18 @@ def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes): grid_sizes = [[2, 2], [1, 1]] strides = [[torch.tensor(1), torch.tensor(2)], [torch.tensor(2), torch.tensor(4)]] - for a, a_f in zip(anchor.grid_anchors(grid_sizes, strides), anchor_ref.grid_anchors(grid_sizes, strides)): - assert_allclose(a, a_f, type_test=True, device_test=False, atol=1e-3) - images = torch.rand(image_shape) - feature_maps = tuple(torch.rand(fs) for fs in feature_maps_shapes) - result = anchor(images, feature_maps) - result_ref = anchor_ref(image_list.ImageList(images, ([123, 122],)), feature_maps) - for a, a_f in zip(result, result_ref): - assert_allclose(a, a_f, type_test=True, device_test=False, atol=0.1) + monai_anchors = anchor.grid_anchors(grid_sizes, strides) + torchvision_anchors = anchor_ref.grid_anchors(grid_sizes, strides) + + for a, a_f, s in zip(monai_anchors, torchvision_anchors, strides): + stride_y, stride_x = s + + offset_x = stride_x // 2 + offset_y = stride_y // 2 + offset = torch.tensor([offset_x, offset_y, offset_x, offset_y], dtype=a_f.dtype, device=a_f.device) + + assert_allclose(a, a_f + offset, type_test=True, device_test=False, atol=1e-3) @parameterized.expand(TEST_CASES_2D) def test_script_2d(self, input_param, image_shape, feature_maps_shapes): From a17ce8b995638e81bc254b6514de73f00c753655 Mon Sep 17 00:00:00 2001 From: hachoj Date: Sun, 8 Jun 2025 22:36:07 -0400 Subject: [PATCH 3/3] further update to unit tests to model test new functionlaity Signed-off-by: hachoj --- tests/apps/detection/utils/test_anchor_box.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/apps/detection/utils/test_anchor_box.py b/tests/apps/detection/utils/test_anchor_box.py index 13abac6447..22bf350f0b 100644 --- a/tests/apps/detection/utils/test_anchor_box.py +++ b/tests/apps/detection/utils/test_anchor_box.py @@ -45,9 +45,9 @@ class TestAnchorGenerator(unittest.TestCase): @parameterized.expand(TEST_CASES_2D) def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes): torch_anchor_utils, _ = optional_import("torchvision.models.detection.anchor_utils") - image_list, _ = optional_import("torchvision.models.detection.image_list") - # test it behaves the same with torchvision for 2d + # test it behaves for new functionality of centered anchors + # pytorch does not follow this functionality anchor = AnchorGenerator(**input_param, indexing="xy") anchor_ref = torch_anchor_utils.AnchorGenerator(**input_param) for a, a_f in zip(anchor.cell_anchors, anchor_ref.cell_anchors):