@@ -45,30 +45,30 @@ def test__prepare_output():
4545    metric  =  MeanAveragePrecision ()
4646
4747    metric ._type  =  "binary" 
48-     scores , y  =  metric ._prepare_output ((torch .rand ((5 , 4 , 3 , 2 )), torch .randint (0 , 2 , (5 , 4 , 3 , 2 )). bool () ))
48+     scores , y  =  metric ._prepare_output ((torch .rand ((5 , 4 , 3 , 2 )), torch .randint (0 , 2 , (5 , 4 , 3 , 2 ))))
4949    assert  scores .shape  ==  y .shape  ==  (1 , 120 )
5050
5151    metric ._type  =  "multiclass" 
5252    scores , y  =  metric ._prepare_output ((torch .rand ((5 , 4 , 3 , 2 )), torch .randint (0 , 4 , (5 , 3 , 2 ))))
5353    assert  scores .shape  ==  (4 , 30 ) and  y .shape  ==  (30 ,)
5454
5555    metric ._type  =  "multilabel" 
56-     scores , y  =  metric ._prepare_output ((torch .rand ((5 , 4 , 3 , 2 )), torch .randint (0 , 2 , (5 , 4 , 3 , 2 )). bool () ))
56+     scores , y  =  metric ._prepare_output ((torch .rand ((5 , 4 , 3 , 2 )), torch .randint (0 , 2 , (5 , 4 , 3 , 2 ))))
5757    assert  scores .shape  ==  y .shape  ==  (4 , 30 )
5858
5959
def test_update():
    """update() should append one (y_pred, y_true) pair to the metric's internal buffers."""
    metric = MeanAveragePrecision()
    # Freshly constructed metric holds no accumulated predictions or targets.
    assert len(metric._y_pred) == len(metric._y_true) == 0
    scores, targets = torch.rand((5, 4)), torch.randint(0, 2, (5, 4))
    metric.update((scores, targets))
    # Exactly one batch should now be buffered on each side.
    assert len(metric._y_pred) == len(metric._y_true) == 1
6565
6666
6767def  test__compute_recall_and_precision ():
6868    m  =  MeanAveragePrecision ()
6969
7070    scores  =  torch .rand ((50 ,))
71-     y_true  =  torch .randint (0 , 2 , (50 ,)). bool () 
71+     y_true  =  torch .randint (0 , 2 , (50 ,))
7272    precision , recall , _  =  precision_recall_curve (y_true .numpy (), scores .numpy ())
7373    P  =  y_true .sum (dim = - 1 )
7474    ignite_recall , ignite_precision  =  m ._compute_recall_and_precision (y_true , scores , P )
@@ -77,7 +77,7 @@ def test__compute_recall_and_precision():
7777
7878    # When there's no actual positive. Numpy expectedly raises warning. 
7979    scores  =  torch .rand ((50 ,))
80-     y_true  =  torch .zeros ((50 ,)). bool () 
80+     y_true  =  torch .zeros ((50 ,))
8181    precision , recall , _  =  precision_recall_curve (y_true .numpy (), scores .numpy ())
8282    P  =  torch .tensor (0 )
8383    ignite_recall , ignite_precision  =  m ._compute_recall_and_precision (y_true , scores , P )
@@ -147,7 +147,7 @@ def test_compute_nonbinary_data(class_mean):
147147
148148    # Multilabel 
149149    m  =  MeanAveragePrecision (is_multilabel = True , class_mean = class_mean )
150-     y_true  =  torch .randint (0 , 2 , (130 , 5 , 2 , 2 )). bool () 
150+     y_true  =  torch .randint (0 , 2 , (130 , 5 , 2 , 2 ))
151151    m .update ((scores [:50 ], y_true [:50 ]))
152152    m .update ((scores [50 :], y_true [50 :]))
153153    ignite_map  =  m .compute ().numpy ()
0 commit comments