@@ -923,22 +923,18 @@ def test_infer_static_shape():
923923class TestEye :
924924 # This is slow for the ('int8', 3) version.
925925 def test_basic (self ):
926- def check (dtype , N , M_ = None , k = 0 ):
927- # PyTensor does not accept None as a tensor.
928- # So we must use a real value.
929- M = M_
930- # Currently DebugMode does not support None as inputs even if this is
931- # allowed.
932- if M is None and config .mode in ["DebugMode" , "DEBUG_MODE" ]:
933- M = N
926+ def check (dtype , N , M = None , k = 0 ):
934927 N_symb = iscalar ()
935928 M_symb = iscalar ()
936929 k_symb = iscalar ()
930+ test_inputs = [N , k ] if M is None else [N , M , k ]
931+ inputs = [N_symb , k_symb ] if M is None else [N_symb , M_symb , k_symb ]
937932 f = function (
938- [N_symb , M_symb , k_symb ], eye (N_symb , M_symb , k_symb , dtype = dtype )
933+ inputs ,
934+ eye (N_symb , None if (M is None ) else M_symb , k_symb , dtype = dtype ),
939935 )
940- result = f (N , M , k )
941- assert np .allclose (result , np .eye (N , M_ , k , dtype = dtype ))
936+ result = f (* test_inputs )
937+ assert np .allclose (result , np .eye (N , M , k , dtype = dtype ))
942938 assert result .dtype == np .dtype (dtype )
943939
944940 for dtype in ALL_DTYPES :
@@ -1744,7 +1740,7 @@ def test_join_matrixV_negative_axis(self):
17441740 got = f (- 2 )
17451741 assert np .allclose (got , want )
17461742
1747- with pytest .raises (ValueError ):
1743+ with pytest .raises (( ValueError , IndexError ) ):
17481744 f (- 3 )
17491745
17501746 @pytest .mark .parametrize ("py_impl" , (False , True ))
0 commit comments