diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py
index 7e4703c8df..8d2b4a71d0 100644
--- a/pytensor/link/numba/dispatch/scalar.py
+++ b/pytensor/link/numba/dispatch/scalar.py
@@ -16,6 +16,7 @@
     get_name_for_object,
     unique_name_generator,
 )
+from pytensor.scalar import ScalarLoop
 from pytensor.scalar.basic import (
     Add,
     Cast,
@@ -69,7 +70,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
         scalar_func_numba = wrap_cython_function(
             cython_func, output_dtype, input_dtypes
         )
-        has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch
+        has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch()
         input_inner_dtypes = scalar_func_numba.numpy_arg_dtypes()
         output_inner_dtype = scalar_func_numba.numpy_output_dtype()
 
@@ -331,3 +332,46 @@ def softplus(x):
         return numba_basic.direct_cast(value, out_dtype)
 
     return softplus
+
+
+@numba_funcify.register(ScalarLoop)
+def numba_funcify_ScalarLoop(op, node, **kwargs):
+    inner_fn = numba_basic.numba_njit(numba_funcify(op.fgraph))
+
+    if op.is_while:
+        n_update = len(op.outputs) - 1
+
+        @numba_basic.numba_njit
+        def while_loop(n_steps, *inputs):
+            carry, constant = inputs[:n_update], inputs[n_update:]
+
+            until = False
+            for i in range(n_steps):
+                outputs = inner_fn(*carry, *constant)
+                carry, until = outputs[:-1], outputs[-1]
+                if until:
+                    break
+
+            return *carry, until
+
+        return while_loop
+
+    else:
+        n_update = len(op.outputs)
+
+        @numba_basic.numba_njit
+        def for_loop(n_steps, *inputs):
+            carry, constant = inputs[:n_update], inputs[n_update:]
+
+            if n_steps < 0:
+                raise ValueError("ScalarLoop does not have a termination condition.")
+
+            for i in range(n_steps):
+                carry = inner_fn(*carry, *constant)
+
+            if n_update == 1:
+                return carry[0]
+            else:
+                return carry
+
+        return for_loop
diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py
index 80168fd122..6ca30933ce 100644
--- a/pytensor/scalar/loop.py
+++ b/pytensor/scalar/loop.py
@@ -183,7 +183,7 @@ def perform(self, node, inputs, output_storage):
         inner_fn = self.py_perform_fn
 
         if self.is_while:
-            until = True
+            until = False
             for i in range(n_steps):
                 *carry, until = inner_fn(*carry, *constant)
                 if until:
diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py
index 25efd69a8d..d8aeffdd22 100644
--- a/tests/link/numba/test_elemwise.py
+++ b/tests/link/numba/test_elemwise.py
@@ -586,18 +586,46 @@ def test_elemwise_multiple_inplace_outs():
 
 
 def test_scalar_loop():
-    a = float64("a")
-    scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
+    a_scalar = float64("a")
+    const_scalar = float64("const")
+    scalar_loop = pytensor.scalar.ScalarLoop(
+        init=[a_scalar],
+        update=[a_scalar + a_scalar + const_scalar],
+        constant=[const_scalar],
+    )
 
-    x = pt.tensor("x", shape=(3,))
-    elemwise_loop = Elemwise(scalar_loop)(3, x)
+    a = pt.tensor("a", shape=(3,))
+    const = pt.tensor("const", shape=(3,))
+    n_steps = 3
+    elemwise_loop = Elemwise(scalar_loop)(n_steps, a, const)
 
-    with pytest.warns(UserWarning, match="object mode"):
-        compare_numba_and_py(
-            [x],
-            [elemwise_loop],
-            (np.array([1, 2, 3], dtype="float64"),),
-        )
+    compare_numba_and_py(
+        [a, const],
+        [elemwise_loop],
+        [np.array([1, 2, 3], dtype="float64"), np.array([1, 1, 1], dtype="float64")],
+    )
+
+
+@pytest.mark.xfail(
+    reason="Numba fails due to https://github.com/numba/numba/issues/10098"
+)
+def test_gammainc_wrt_k_grad():
+    x = pt.vector("x", dtype="float64")
+    k = pt.vector("k", dtype="float64")
+
+    out = pt.gammainc(k, x)
+    grad_out = grad(out.sum(), k)
+
+    compare_numba_and_py(
+        [x, k],
+        [grad_out],
+        # These values of x and k trigger all the branches in the gradient of gammainc
+        [
+            np.array([0.0, 29.0, 31.0], dtype="float64"),
+            np.array([1.0, 13.0, 11.0], dtype="float64"),
+        ],
+        eval_obj_mode=False,
+    )
 
 
 class TestsBenchmark:
diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py
index 2125d7cc0e..f1f1618e03 100644
--- a/tests/link/numba/test_scalar.py
+++ b/tests/link/numba/test_scalar.py
@@ -6,6 +6,7 @@
 import pytensor.scalar.math as psm
 import pytensor.tensor as pt
 from pytensor import config, function
+from pytensor.scalar import ScalarLoop
 from pytensor.scalar.basic import Composite
 from pytensor.tensor import tensor
 from pytensor.tensor.elemwise import Elemwise
@@ -184,3 +185,75 @@
             strict=True,
             err_msg=f"Failed for value {value}",
         )
+
+
+class TestScalarLoop:
+    def test_scalar_for_loop_single_out(self):
+        n_steps = ps.int64("n_steps")
+        x0 = ps.float64("x0")
+        const = ps.float64("const")
+        x = x0 + const
+
+        op = ScalarLoop(init=[x0], constant=[const], update=[x])
+        x = op(n_steps, x0, const)
+
+        fn = function([n_steps, x0, const], [x], mode=numba_mode)
+
+        res_x = fn(n_steps=5, x0=0, const=1)
+        np.testing.assert_allclose(res_x, 5)
+
+        res_x = fn(n_steps=5, x0=0, const=2)
+        np.testing.assert_allclose(res_x, 10)
+
+        res_x = fn(n_steps=4, x0=3, const=-1)
+        np.testing.assert_allclose(res_x, -1)
+
+    def test_scalar_for_loop_multiple_outs(self):
+        n_steps = ps.int64("n_steps")
+        x0 = ps.float64("x0")
+        y0 = ps.int64("y0")
+        const = ps.float64("const")
+        x = x0 + const
+        y = y0 + 1
+
+        op = ScalarLoop(init=[x0, y0], constant=[const], update=[x, y])
+        x, y = op(n_steps, x0, y0, const)
+
+        fn = function([n_steps, x0, y0, const], [x, y], mode=numba_mode)
+
+        res_x, res_y = fn(n_steps=5, x0=0, y0=0, const=1)
+        np.testing.assert_allclose(res_x, 5)
+        np.testing.assert_allclose(res_y, 5)
+
+        res_x, res_y = fn(n_steps=5, x0=0, y0=0, const=2)
+        np.testing.assert_allclose(res_x, 10)
+        np.testing.assert_allclose(res_y, 5)
+
+        res_x, res_y = fn(n_steps=4, x0=3, y0=2, const=-1)
+        np.testing.assert_allclose(res_x, -1)
+        np.testing.assert_allclose(res_y, 6)
+
+    def test_scalar_while_loop(self):
+        n_steps = ps.int64("n_steps")
+        x0 = ps.float64("x0")
+        x = x0 + 1
+        until = x >= 10
+
+        op = ScalarLoop(init=[x0], update=[x], until=until)
+        fn = function([n_steps, x0], op(n_steps, x0), mode=numba_mode)
+        np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True])
+        np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
+        np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])
+        np.testing.assert_allclose(fn(n_steps=0, x0=1), [1, False])
+
+    @pytest.mark.xfail(reason="Fails due to https://github.com/numba/numba/issues/10098")
+    def test_loop_with_cython_wrapped_op(self):
+        x = ps.float64("x")
+        op = ScalarLoop(init=[x], update=[ps.psi(x)])
+        out = op(1, x)
+
+        fn = function([x], out, mode=numba_mode)
+        x_test = np.float64(0.5)
+        res = fn(x_test)
+        expected_res = ps.psi(x).eval({x: x_test})
+        np.testing.assert_allclose(res, expected_res)
diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py
index e7579b10ac..7e6976a1bd 100644
--- a/tests/tensor/test_math_scipy.py
+++ b/tests/tensor/test_math_scipy.py
@@ -381,8 +381,8 @@ def test_gammainc_ddk_tabulated_values():
     # https://github.com/stan-dev/math/blob/21333bb70b669a1bd54d444ecbe1258078d33153/test/unit/math/prim/scal/fun/grad_reg_lower_inc_gamma_test.cpp
     k, x = pt.scalars("k", "x")
     gammainc_out = pt.gammainc(k, x)
-    gammaincc_ddk = pt.grad(gammainc_out, k)
-    f_grad = function([k, x], gammaincc_ddk)
+    gammainc_ddk = pt.grad(gammainc_out, k)
+    f_grad = function([k, x], gammainc_ddk)
 
     rtol = 1e-5 if config.floatX == "float64" else 1e-2
     atol = 1e-10 if config.floatX == "float64" else 1e-6