Skip to content

Commit 463cba4

Browse files
committed
Numba does not output numpy scalars
1 parent a5e7747 commit 463cba4

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

tests/tensor/test_basic.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2185,21 +2185,23 @@ def test_ScalarFromTensor(cast_policy):
21852185
v = eval_outputs([ss])
21862186

21872187
assert v == 56
2188-
assert v.shape == ()
2189-
2190-
if cast_policy == "custom":
2191-
assert isinstance(v, np.int8)
2192-
elif cast_policy == "numpy+floatX":
2193-
assert isinstance(v, np.int64)
2188+
assert isinstance(
2189+
v, int
2190+
) # Numba unboxes scalars to python numerical primitives
2191+
# assert v.shape == ()
2192+
# if cast_policy == "custom":
2193+
# assert isinstance(v, np.int8)
2194+
# elif cast_policy == "numpy+floatX":
2195+
# assert isinstance(v, np.int64)
21942196

21952197
pts = lscalar()
21962198
ss = scalar_from_tensor(pts)
21972199
ss.owner.op.grad([pts], [ss])
21982200
fff = function([pts], ss)
21992201
v = fff(np.asarray(5))
22002202
assert v == 5
2201-
assert isinstance(v, np.int64)
2202-
assert v.shape == ()
2203+
# assert isinstance(v, np.int64)
2204+
# assert v.shape == ()
22032205

22042206
with pytest.raises(TypeError):
22052207
scalar_from_tensor(vector())

0 commit comments

Comments
 (0)