Skip to content

Commit 9124b72

Browse files
committed
Numba fallback cython missing dtype
1 parent a5fb911 commit 9124b72

File tree

2 files changed

+43
-7
lines changed

2 files changed

+43
-7
lines changed

pytensor/link/numba/dispatch/scalar.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,19 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
6969

7070
cython_func = getattr(scipy.special.cython_special, scalar_func_name, None)
7171
if cython_func is not None:
72-
scalar_func_numba = wrap_cython_function(
73-
cython_func, output_dtype, input_dtypes
74-
)
75-
has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch
76-
input_inner_dtypes = scalar_func_numba.numpy_arg_dtypes()
77-
output_inner_dtype = scalar_func_numba.numpy_output_dtype()
72+
try:
73+
scalar_func_numba = wrap_cython_function(
74+
cython_func, output_dtype, input_dtypes
75+
)
76+
except NotImplementedError:
77+
pass
78+
else:
79+
has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch
80+
input_inner_dtypes = scalar_func_numba.numpy_arg_dtypes()
81+
output_inner_dtype = scalar_func_numba.numpy_output_dtype()
7882

7983
if scalar_func_numba is None:
80-
scalar_func_numba = generate_fallback_impl(op, node, **kwargs)
84+
return generate_fallback_impl(op, node, **kwargs), None
8185

8286
scalar_op_fn_name = get_name_for_object(scalar_func_numba)
8387
prefix = "x" if scalar_func_name != "x" else "y"

tests/link/numba/test_scalar.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import numpy as np
22
import pytest
3+
import scipy
34

45
import pytensor.scalar as ps
56
import pytensor.scalar.basic as psb
67
import pytensor.scalar.math as psm
78
import pytensor.tensor as pt
89
from pytensor import config, function
10+
from pytensor.graph import Apply
11+
from pytensor.scalar import UnaryScalarOp
912
from pytensor.scalar.basic import Composite
1013
from pytensor.tensor import tensor
1114
from pytensor.tensor.elemwise import Elemwise
@@ -184,3 +187,32 @@ def test_Softplus(dtype):
184187
strict=True,
185188
err_msg=f"Failed for value {value}",
186189
)
190+
191+
192+
def test_cython_obj_mode_fallback():
193+
"""Test that unsupported cython signatures fallback to obj-mode"""
194+
195+
# Create a ScalarOp with a non-standard dtype
196+
class IntegerGamma(UnaryScalarOp):
197+
# We'll try to check for scipy cython impl
198+
nfunc_spec = ("scipy.special.gamma", 1, 1)
199+
200+
def make_node(self, x):
201+
x = psb.as_scalar(x)
202+
assert x.dtype == "int64"
203+
out = x.type()
204+
return Apply(self, [x], [out])
205+
206+
def impl(self, x):
207+
return scipy.special.gamma(x).astype("int64")
208+
209+
x = pt.scalar("x", dtype="int64")
210+
g = Elemwise(IntegerGamma())(x)
211+
assert g.type.dtype == "int64"
212+
213+
with pytest.warns(UserWarning, match="Numba will use object mode"):
214+
compare_numba_and_py(
215+
[x],
216+
[g],
217+
[np.array(5, dtype="int64")],
218+
)

0 commit comments

Comments
 (0)