Implement numba dispatch for ScalarLoop #1445


Draft: wants to merge 4 commits into base: main
46 changes: 45 additions & 1 deletion pytensor/link/numba/dispatch/scalar.py
@@ -16,6 +16,7 @@
get_name_for_object,
unique_name_generator,
)
from pytensor.scalar import ScalarLoop
from pytensor.scalar.basic import (
Add,
Cast,
@@ -69,7 +70,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
scalar_func_numba = wrap_cython_function(
cython_func, output_dtype, input_dtypes
)
-            has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch
+            has_pyx_skip_dispatch = scalar_func_numba.has_pyx_skip_dispatch()
input_inner_dtypes = scalar_func_numba.numpy_arg_dtypes()
output_inner_dtype = scalar_func_numba.numpy_output_dtype()

@@ -331,3 +332,46 @@ def softplus(x):
return numba_basic.direct_cast(value, out_dtype)

return softplus


@numba_funcify.register(ScalarLoop)
def numba_funcify_ScalarLoop(op, node, **kwargs):
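    # Compile one iteration of the inner fgraph into a numba-jitted callable.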
inner_fn = numba_basic.numba_njit(numba_funcify(op.fgraph))

if op.is_while:
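        # The last inner-graph output is the `until` flag; the rest are the carry.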
n_update = len(op.outputs) - 1

@numba_basic.numba_njit
def while_loop(n_steps, *inputs):
carry, constant = inputs[:n_update], inputs[n_update:]

until = False
for i in range(n_steps):
outputs = inner_fn(*carry, *constant)
carry, until = outputs[:-1], outputs[-1]
if until:
break

return *carry, until

return while_loop

else:
n_update = len(op.outputs)

@numba_basic.numba_njit
def for_loop(n_steps, *inputs):
carry, constant = inputs[:n_update], inputs[n_update:]

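            # Without an `until` condition, a negative n_steps cannot mean
            # "iterate until termination", so reject it.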
if n_steps < 0:
raise ValueError("ScalarLoop does not have a termination condition.")

for i in range(n_steps):
carry = inner_fn(*carry, *constant)

if n_update == 1:
return carry[0]
else:
return carry

return for_loop
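
For reference, a minimal end-to-end sketch of what this dispatch enables, assuming the ScalarLoop constructor signature exercised by the tests below and PyTensor's built-in "NUMBA" mode string:

import numpy as np
import pytensor.scalar as ps
from pytensor import function
from pytensor.scalar import ScalarLoop

# One iteration computes x_{t+1} = x_t + const; `const` is loop-invariant.
n_steps = ps.int64("n_steps")
x0 = ps.float64("x0")
const = ps.float64("const")
op = ScalarLoop(init=[x0], constant=[const], update=[x0 + const])

# Compiling with the Numba backend routes through numba_funcify_ScalarLoop.
fn = function([n_steps, x0, const], op(n_steps, x0, const), mode="NUMBA")
np.testing.assert_allclose(fn(5, 0.0, 1.0), 5.0)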
2 changes: 1 addition & 1 deletion pytensor/scalar/loop.py
@@ -183,7 +183,7 @@ def perform(self, node, inputs, output_storage):
inner_fn = self.py_perform_fn

if self.is_while:
-            until = True
+            until = False
for i in range(n_steps):
*carry, until = inner_fn(*carry, *constant)
if until:
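Why the initialization flips from True to False: when n_steps == 0 the while-loop body never runs, so the carry must come back unchanged with `until` reported as False rather than a spurious True. A minimal sketch mirroring the new test_scalar_while_loop assertion below (default backend, same ScalarLoop API as the tests):

import numpy as np
import pytensor.scalar as ps
from pytensor import function
from pytensor.scalar import ScalarLoop

n_steps = ps.int64("n_steps")
x0 = ps.float64("x0")
x = x0 + 1
op = ScalarLoop(init=[x0], update=[x], until=x >= 10)

fn = function([n_steps, x0], op(n_steps, x0))
# Zero iterations: the carry is returned unchanged and `until` must be False.
np.testing.assert_allclose(fn(n_steps=0, x0=1), [1.0, False])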
48 changes: 38 additions & 10 deletions tests/link/numba/test_elemwise.py
@@ -586,18 +586,46 @@ def test_elemwise_multiple_inplace_outs():


def test_scalar_loop():
a = float64("a")
scalar_loop = pytensor.scalar.ScalarLoop([a], [a + a])
a_scalar = float64("a")
const_scalar = float64("const")
scalar_loop = pytensor.scalar.ScalarLoop(
init=[a_scalar],
update=[a_scalar + a_scalar + const_scalar],
constant=[const_scalar],
)

x = pt.tensor("x", shape=(3,))
elemwise_loop = Elemwise(scalar_loop)(3, x)
a = pt.tensor("a", shape=(3,))
const = pt.tensor("const", shape=(3,))
n_steps = 3
elemwise_loop = Elemwise(scalar_loop)(n_steps, a, const)

with pytest.warns(UserWarning, match="object mode"):
compare_numba_and_py(
[x],
[elemwise_loop],
(np.array([1, 2, 3], dtype="float64"),),
)
compare_numba_and_py(
[a, const],
[elemwise_loop],
[np.array([1, 2, 3], dtype="float64"), np.array([1, 1, 1], dtype="float64")],
)


@pytest.mark.xfail(
reason="Numba fails due to https://github.com/numba/numba/issues/10098"
)
def test_gammainc_wrt_k_grad():
x = pt.vector("x", dtype="float64")
k = pt.vector("k", dtype="float64")

out = pt.gammainc(k, x)
grad_out = grad(out.sum(), k)

compare_numba_and_py(
[x, k],
[grad_out],
# These values of x and k trigger all the branches in the gradient of gammainc
[
np.array([0.0, 29.0, 31.0], dtype="float64"),
np.array([1.0, 13.0, 11.0], dtype="float64"),
],
eval_obj_mode=False,
)


class TestsBenchmark:
73 changes: 73 additions & 0 deletions tests/link/numba/test_scalar.py
@@ -6,6 +6,7 @@
import pytensor.scalar.math as psm
import pytensor.tensor as pt
from pytensor import config, function
from pytensor.scalar import ScalarLoop
from pytensor.scalar.basic import Composite
from pytensor.tensor import tensor
from pytensor.tensor.elemwise import Elemwise
@@ -184,3 +185,75 @@ def test_Softplus(dtype):
strict=True,
err_msg=f"Failed for value {value}",
)


class TestScalarLoop:
def test_scalar_for_loop_single_out(self):
n_steps = ps.int64("n_steps")
x0 = ps.float64("x0")
const = ps.float64("const")
x = x0 + const

op = ScalarLoop(init=[x0], constant=[const], update=[x])
x = op(n_steps, x0, const)

fn = function([n_steps, x0, const], [x], mode=numba_mode)

res_x = fn(n_steps=5, x0=0, const=1)
np.testing.assert_allclose(res_x, 5)

res_x = fn(n_steps=5, x0=0, const=2)
np.testing.assert_allclose(res_x, 10)

res_x = fn(n_steps=4, x0=3, const=-1)
np.testing.assert_allclose(res_x, -1)

def test_scalar_for_loop_multiple_outs(self):
n_steps = ps.int64("n_steps")
x0 = ps.float64("x0")
y0 = ps.int64("y0")
const = ps.float64("const")
x = x0 + const
y = y0 + 1

op = ScalarLoop(init=[x0, y0], constant=[const], update=[x, y])
x, y = op(n_steps, x0, y0, const)

fn = function([n_steps, x0, y0, const], [x, y], mode=numba_mode)

res_x, res_y = fn(n_steps=5, x0=0, y0=0, const=1)
np.testing.assert_allclose(res_x, 5)
np.testing.assert_allclose(res_y, 5)

res_x, res_y = fn(n_steps=5, x0=0, y0=0, const=2)
np.testing.assert_allclose(res_x, 10)
np.testing.assert_allclose(res_y, 5)

res_x, res_y = fn(n_steps=4, x0=3, y0=2, const=-1)
np.testing.assert_allclose(res_x, -1)
np.testing.assert_allclose(res_y, 6)

def test_scalar_while_loop(self):
n_steps = ps.int64("n_steps")
x0 = ps.float64("x0")
x = x0 + 1
until = x >= 10

op = ScalarLoop(init=[x0], update=[x], until=until)
fn = function([n_steps, x0], op(n_steps, x0), mode=numba_mode)
np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True])
np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True])
np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False])
np.testing.assert_allclose(fn(n_steps=0, x0=1), [1, False])

    @pytest.mark.xfail(reason="Fails due to https://github.com/numba/numba/issues/10098")
def test_loop_with_cython_wrapped_op(self):
x = ps.float64("x")
op = ScalarLoop(init=[x], update=[ps.psi(x)])
out = op(1, x)

fn = function([x], out, mode=numba_mode)
x_test = np.float64(0.5)
res = fn(x_test)
expected_res = ps.psi(x).eval({x: x_test})
np.testing.assert_allclose(res, expected_res)
4 changes: 2 additions & 2 deletions tests/tensor/test_math_scipy.py
@@ -381,8 +381,8 @@ def test_gammainc_ddk_tabulated_values():
# https://github.com/stan-dev/math/blob/21333bb70b669a1bd54d444ecbe1258078d33153/test/unit/math/prim/scal/fun/grad_reg_lower_inc_gamma_test.cpp
k, x = pt.scalars("k", "x")
gammainc_out = pt.gammainc(k, x)
-    gammaincc_ddk = pt.grad(gammainc_out, k)
-    f_grad = function([k, x], gammaincc_ddk)
+    gammainc_ddk = pt.grad(gammainc_out, k)
+    f_grad = function([k, x], gammainc_ddk)

rtol = 1e-5 if config.floatX == "float64" else 1e-2
atol = 1e-10 if config.floatX == "float64" else 1e-6