@@ -1418,9 +1418,9 @@ def test_bbox_convert_jit(self):
1418
1418
torch .testing .assert_close (scripted_cxcywh , box_cxcywh )
1419
1419
1420
1420
1421
- class TestBoxArea :
1421
+ class TestBoxAreaXYXY :
1422
1422
def area_check (self , box , expected , atol = 1e-4 ):
1423
- out = ops .box_area (box )
1423
+ out = ops .box_area (box , fmt = "xyxy" )
1424
1424
torch .testing .assert_close (out , expected , rtol = 0.0 , check_dtype = False , atol = atol )
1425
1425
1426
1426
@pytest .mark .parametrize ("dtype" , [torch .int8 , torch .int16 , torch .int32 , torch .int64 ])
@@ -1445,15 +1445,15 @@ def test_float16_box(self):
1445
1445
1446
1446
def test_box_area_jit (self ):
1447
1447
box_tensor = torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = torch .float )
1448
- expected = ops .box_area (box_tensor )
1448
+ expected = ops .box_area (box_tensor , fmt = "xyxy" )
1449
1449
scripted_fn = torch .jit .script (ops .box_area )
1450
1450
scripted_area = scripted_fn (box_tensor )
1451
1451
torch .testing .assert_close (scripted_area , expected )
1452
1452
1453
1453
1454
- class TestBoxAreaCenter :
1454
+ class TestBoxAreaCXCYWH :
1455
1455
def area_check (self , box , expected , atol = 1e-4 ):
1456
- out = ops .box_area_center (box )
1456
+ out = ops .box_area (box , fmt = "cxcywh" )
1457
1457
torch .testing .assert_close (out , expected , rtol = 0.0 , check_dtype = False , atol = atol )
1458
1458
1459
1459
@pytest .mark .parametrize ("dtype" , [torch .int8 , torch .int16 , torch .int32 , torch .int64 ])
@@ -1480,9 +1480,9 @@ def test_float16_box(self):
1480
1480
def test_box_area_jit (self ):
1481
1481
box_tensor = ops .box_convert (torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = torch .float ),
1482
1482
in_fmt = "xyxy" , out_fmt = "cxcywh" )
1483
- expected = ops .box_area_center (box_tensor )
1484
- scripted_fn = torch .jit .script (ops .box_area_center )
1485
- scripted_area = scripted_fn (box_tensor )
1483
+ expected = ops .box_area (box_tensor , fmt = "cxcywh" )
1484
+ scripted_fn = torch .jit .script (ops .box_area )
1485
+ scripted_area = scripted_fn (box_tensor , fmt = "cxcywh" )
1486
1486
torch .testing .assert_close (scripted_area , expected )
1487
1487
1488
1488
@@ -1509,22 +1509,22 @@ def gen_box(size, dtype=torch.float):
1509
1509
return torch .cat ([xy1 , xy2 ], axis = - 1 )
1510
1510
1511
1511
1512
- class TestIouBase :
1512
+ class TestIouXYXYBase :
1513
1513
@staticmethod
1514
1514
def _run_test (target_fn : Callable , actual_box1 , actual_box2 , dtypes , atol , expected ):
1515
1515
for dtype in dtypes :
1516
1516
actual_box1 = torch .tensor (actual_box1 , dtype = dtype )
1517
1517
actual_box2 = torch .tensor (actual_box2 , dtype = dtype )
1518
1518
expected_box = torch .tensor (expected )
1519
- out = target_fn (actual_box1 , actual_box2 )
1519
+ out = target_fn (actual_box1 , actual_box2 , fmt = "xyxy" )
1520
1520
torch .testing .assert_close (out , expected_box , rtol = 0.0 , check_dtype = False , atol = atol )
1521
1521
1522
1522
@staticmethod
1523
1523
def _run_jit_test (target_fn : Callable , actual_box : List ):
1524
1524
box_tensor = torch .tensor (actual_box , dtype = torch .float )
1525
- expected = target_fn (box_tensor , box_tensor )
1525
+ expected = target_fn (box_tensor , box_tensor , fmt = "xyxy" )
1526
1526
scripted_fn = torch .jit .script (target_fn )
1527
- scripted_out = scripted_fn (box_tensor , box_tensor )
1527
+ scripted_out = scripted_fn (box_tensor , box_tensor , fmt = "xyxy" )
1528
1528
torch .testing .assert_close (scripted_out , expected )
1529
1529
1530
1530
@staticmethod
@@ -1534,19 +1534,19 @@ def _cartesian_product(boxes1, boxes2, target_fn: Callable):
1534
1534
result = torch .zeros ((N , M ))
1535
1535
for i in range (N ):
1536
1536
for j in range (M ):
1537
- result [i , j ] = target_fn (boxes1 [i ].unsqueeze (0 ), boxes2 [j ].unsqueeze (0 ))
1537
+ result [i , j ] = target_fn (boxes1 [i ].unsqueeze (0 ), boxes2 [j ].unsqueeze (0 ), fmt = "xyxy" )
1538
1538
return result
1539
1539
1540
1540
@staticmethod
1541
1541
def _run_cartesian_test (target_fn : Callable ):
1542
1542
boxes1 = gen_box (5 )
1543
1543
boxes2 = gen_box (7 )
1544
- a = TestIouBase ._cartesian_product (boxes1 , boxes2 , target_fn )
1545
- b = target_fn (boxes1 , boxes2 )
1544
+ a = TestIouXYXYBase ._cartesian_product (boxes1 , boxes2 , target_fn )
1545
+ b = target_fn (boxes1 , boxes2 , fmt = "xyxy" )
1546
1546
torch .testing .assert_close (a , b )
1547
1547
1548
1548
1549
- class TestBoxIou ( TestIouBase ):
1549
+ class TestBoxIouXYXY ( TestIouXYXYBase ):
1550
1550
int_expected = [[1.0 , 0.25 , 0.0 ], [0.25 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ], [0.0625 , 0.25 , 0.0 ]]
1551
1551
float_expected = [[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]]
1552
1552
@@ -1568,22 +1568,22 @@ def test_iou_cartesian(self):
1568
1568
self ._run_cartesian_test (ops .box_iou )
1569
1569
1570
1570
1571
- class TestIouCenterBase :
1571
+ class TestIouCXCYWHBase :
1572
1572
@staticmethod
1573
1573
def _run_test (target_fn : Callable , actual_box1 , actual_box2 , dtypes , atol , expected ):
1574
1574
for dtype in dtypes :
1575
1575
actual_box1 = torch .tensor (actual_box1 , dtype = dtype )
1576
1576
actual_box2 = torch .tensor (actual_box2 , dtype = dtype )
1577
1577
expected_box = torch .tensor (expected )
1578
- out = target_fn (actual_box1 , actual_box2 )
1578
+ out = target_fn (actual_box1 , actual_box2 , fmt = "cxcywh" )
1579
1579
torch .testing .assert_close (out , expected_box , rtol = 0.0 , check_dtype = False , atol = atol )
1580
1580
1581
1581
@staticmethod
1582
1582
def _run_jit_test (target_fn : Callable , actual_box : List ):
1583
1583
box_tensor = torch .tensor (actual_box , dtype = torch .float )
1584
- expected = target_fn (box_tensor , box_tensor )
1584
+ expected = target_fn (box_tensor , box_tensor , fmt = "cxcywh" )
1585
1585
scripted_fn = torch .jit .script (target_fn )
1586
- scripted_out = scripted_fn (box_tensor , box_tensor )
1586
+ scripted_out = scripted_fn (box_tensor , box_tensor , fmt = "cxcywh" )
1587
1587
torch .testing .assert_close (scripted_out , expected )
1588
1588
1589
1589
@staticmethod
@@ -1593,19 +1593,19 @@ def _cartesian_product(boxes1, boxes2, target_fn: Callable):
1593
1593
result = torch .zeros ((N , M ))
1594
1594
for i in range (N ):
1595
1595
for j in range (M ):
1596
- result [i , j ] = target_fn (boxes1 [i ].unsqueeze (0 ), boxes2 [j ].unsqueeze (0 ))
1596
+ result [i , j ] = target_fn (boxes1 [i ].unsqueeze (0 ), boxes2 [j ].unsqueeze (0 ), fmt = "cxcywh" )
1597
1597
return result
1598
1598
1599
1599
@staticmethod
1600
1600
def _run_cartesian_test (target_fn : Callable ):
1601
1601
boxes1 = ops .box_convert (gen_box (5 ), in_fmt = "xyxy" , out_fmt = "cxcywh" )
1602
1602
boxes2 = ops .box_convert (gen_box (7 ), in_fmt = "xyxy" , out_fmt = "cxcywh" )
1603
- a = TestIouCenterBase ._cartesian_product (boxes1 , boxes2 , target_fn )
1604
- b = target_fn (boxes1 , boxes2 )
1603
+ a = TestIouCXCYWHBase ._cartesian_product (boxes1 , boxes2 , target_fn )
1604
+ b = target_fn (boxes1 , boxes2 , fmt = "cxcywh" )
1605
1605
torch .testing .assert_close (a , b )
1606
1606
1607
1607
1608
- class TestBoxIouCenter ( TestIouBase ):
1608
+ class TestBoxIouCXCYWH ( TestIouCXCYWHBase ):
1609
1609
int_expected = [[1.0 , 0.25 , 0.0 ], [0.25 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ], [0.04 , 0.16 , 0.0 ]]
1610
1610
float_expected = [[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]]
1611
1611
@@ -1618,13 +1618,49 @@ class TestBoxIouCenter(TestIouBase):
1618
1618
],
1619
1619
)
1620
1620
def test_iou (self , actual_box1 , actual_box2 , dtypes , atol , expected ):
1621
- self ._run_test (ops .box_iou_center , actual_box1 , actual_box2 , dtypes , atol , expected )
1621
+ self ._run_test (ops .box_iou , actual_box1 , actual_box2 , dtypes , atol , expected )
1622
1622
1623
1623
def test_iou_jit (self ):
1624
- self ._run_jit_test (ops .box_iou_center , INT_BOXES_CXCYWH )
1624
+ self ._run_jit_test (ops .box_iou , INT_BOXES_CXCYWH )
1625
1625
1626
1626
def test_iou_cartesian (self ):
1627
- self ._run_cartesian_test (ops .box_iou_center )
1627
+ self ._run_cartesian_test (ops .box_iou )
1628
+
1629
+ class TestIouBase :
1630
+ @staticmethod
1631
+ def _run_test (target_fn : Callable , actual_box1 , actual_box2 , dtypes , atol , expected ):
1632
+ for dtype in dtypes :
1633
+ actual_box1 = torch .tensor (actual_box1 , dtype = dtype )
1634
+ actual_box2 = torch .tensor (actual_box2 , dtype = dtype )
1635
+ expected_box = torch .tensor (expected )
1636
+ out = target_fn (actual_box1 , actual_box2 )
1637
+ torch .testing .assert_close (out , expected_box , rtol = 0.0 , check_dtype = False , atol = atol )
1638
+
1639
+ @staticmethod
1640
+ def _run_jit_test (target_fn : Callable , actual_box : List ):
1641
+ box_tensor = torch .tensor (actual_box , dtype = torch .float )
1642
+ expected = target_fn (box_tensor , box_tensor )
1643
+ scripted_fn = torch .jit .script (target_fn )
1644
+ scripted_out = scripted_fn (box_tensor , box_tensor )
1645
+ torch .testing .assert_close (scripted_out , expected )
1646
+
1647
+ @staticmethod
1648
+ def _cartesian_product (boxes1 , boxes2 , target_fn : Callable ):
1649
+ N = boxes1 .size (0 )
1650
+ M = boxes2 .size (0 )
1651
+ result = torch .zeros ((N , M ))
1652
+ for i in range (N ):
1653
+ for j in range (M ):
1654
+ result [i , j ] = target_fn (boxes1 [i ].unsqueeze (0 ), boxes2 [j ].unsqueeze (0 ))
1655
+ return result
1656
+
1657
+ @staticmethod
1658
+ def _run_cartesian_test (target_fn : Callable ):
1659
+ boxes1 = gen_box (5 )
1660
+ boxes2 = gen_box (7 )
1661
+ a = TestIouBase ._cartesian_product (boxes1 , boxes2 , target_fn )
1662
+ b = target_fn (boxes1 , boxes2 )
1663
+ torch .testing .assert_close (a , b )
1628
1664
1629
1665
1630
1666
class TestGeneralizedBoxIou (TestIouBase ):
0 commit comments