File tree Expand file tree Collapse file tree 1 file changed +10
-8
lines changed Expand file tree Collapse file tree 1 file changed +10
-8
lines changed Original file line number Diff line number Diff line change @@ -2185,21 +2185,23 @@ def test_ScalarFromTensor(cast_policy):
21852185 v = eval_outputs ([ss ])
21862186
21872187 assert v == 56
2188- assert v .shape == ()
2189-
2190- if cast_policy == "custom" :
2191- assert isinstance (v , np .int8 )
2192- elif cast_policy == "numpy+floatX" :
2193- assert isinstance (v , np .int64 )
2188+ assert isinstance (
2189+ v , int
2190+ ) # Numba unboxes scalars to python numerical primitives
2191+ # assert v.shape == ()
2192+ # if cast_policy == "custom":
2193+ # assert isinstance(v, np.int8)
2194+ # elif cast_policy == "numpy+floatX":
2195+ # assert isinstance(v, np.int64)
21942196
21952197 pts = lscalar ()
21962198 ss = scalar_from_tensor (pts )
21972199 ss .owner .op .grad ([pts ], [ss ])
21982200 fff = function ([pts ], ss )
21992201 v = fff (np .asarray (5 ))
22002202 assert v == 5
2201- assert isinstance (v , np .int64 )
2202- assert v .shape == ()
2203+ # assert isinstance(v, np.int64)
2204+ # assert v.shape == ()
22032205
22042206 with pytest .raises (TypeError ):
22052207 scalar_from_tensor (vector ())
You can’t perform that action at this time.
0 commit comments