Skip to content

Commit abaf123

Browse files
committed
Numba fallback complex erf
1 parent 9124b72 commit abaf123

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

pytensor/link/numba/dispatch/scalar.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,11 @@ def logp1mexp(x):
278278

279279

280280
@register_funcify_and_cache_key(Erf)
281-
def numba_funcify_Erf(op, **kwargs):
281+
def numba_funcify_Erf(op, node, **kwargs):
282+
if node.inputs[0].type.dtype.startswith("complex"):
283+
# Complex not supported by numba
284+
return numba_funcify_ScalarOp(op, node=node, **kwargs)
285+
282286
@numba_basic.numba_njit
283287
def erf(x):
284288
return math.erf(x)

tests/link/numba/test_scalar.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,14 @@ def impl(self, x):
216216
[g],
217217
[np.array(5, dtype="int64")],
218218
)
219+
220+
221+
def test_erf_complex():
222+
x = pt.scalar("x", dtype="complex128")
223+
g = pt.erf(x)
224+
225+
compare_numba_and_py(
226+
[x],
227+
[g],
228+
[np.array(0.5 + 1j, dtype="complex128")],
229+
)

0 commit comments

Comments
 (0)