Skip to content

Commit f49a6c5

Browse files
committed
Numba: fall back to object mode for non-implemented RVs
Closes #1245
1 parent abaf123 commit f49a6c5

File tree

3 files changed

+48
-5
lines changed

3 files changed

+48
-5
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from pytensor.scalar.basic import ScalarType
2222
from pytensor.sparse import SparseTensorType
23+
from pytensor.tensor.random.type import RandomGeneratorType
2324
from pytensor.tensor.type import TensorType
2425
from pytensor.tensor.utils import hash_from_ndarray
2526

@@ -129,8 +130,8 @@ def get_numba_type(
129130
return CSRMatrixType(numba_dtype)
130131
if pytensor_type.format == "csc":
131132
return CSCMatrixType(numba_dtype)
132-
133-
raise NotImplementedError()
133+
elif isinstance(pytensor_type, RandomGeneratorType):
134+
return numba.types.NumPyRandomGeneratorType("NumPyRandomGeneratorType")
134135
else:
135136
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
136137

pytensor/link/numba/dispatch/random.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytensor.link.numba.dispatch import basic as numba_basic
1717
from pytensor.link.numba.dispatch.basic import (
1818
direct_cast,
19+
generate_fallback_impl,
1920
numba_funcify,
2021
register_funcify_and_cache_key,
2122
)
@@ -406,13 +407,24 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
406407

407408
[rv_node] = op.fgraph.apply_nodes
408409
rv_op: RandomVariable = rv_node.op
410+
411+
try:
412+
core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
413+
except NotImplementedError:
414+
py_impl = generate_fallback_impl(rv_op, node=rv_node, **kwargs)
415+
416+
@numba_basic.numba_njit
417+
def fallback_rv(_core_shape, *args):
418+
return py_impl(*args)
419+
420+
return fallback_rv, None
421+
409422
size = rv_op.size_param(rv_node)
410423
dist_params = rv_op.dist_params(rv_node)
411424
size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
412425
core_shape_len = get_vector_length(core_shape)
413426
inplace = rv_op.inplace
414427

415-
core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
416428
nin = 1 + len(dist_params) # rng + params
417429
core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)
418430

tests/link/numba/test_random.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def test_multivariate_normal():
257257
],
258258
pt.as_tensor([3, 2]),
259259
),
260-
pytest.param(
260+
(
261261
ptr.hypergeometric,
262262
[
263263
(
@@ -274,7 +274,6 @@ def test_multivariate_normal():
274274
),
275275
],
276276
pt.as_tensor([3, 2]),
277-
marks=pytest.mark.xfail, # Not implemented
278277
),
279278
(
280279
ptr.wald,
@@ -722,3 +721,34 @@ def test_repeated_args():
722721
final_node = fn.maker.fgraph.outputs[0].owner
723722
assert isinstance(final_node.op, RandomVariableWithCoreShape)
724723
assert final_node.inputs[-2] is final_node.inputs[-1]
724+
725+
726+
def test_rv_fallback():
727+
"""Test that random variables can fallback to object mode."""
728+
729+
class CustomRV(ptr.RandomVariable):
730+
name = "custom"
731+
signature = "()->()"
732+
dtype = "float64"
733+
734+
def rng_fn(self, rng, value, size=None):
735+
# Just return the value plus a random number
736+
return value + rng.standard_normal(size=size)
737+
738+
custom_rv = CustomRV()
739+
740+
rng = shared(np.random.default_rng(123))
741+
size = pt.scalar("size", dtype=int)
742+
next_rng, x = custom_rv(np.pi, size=(size,), rng=rng).owner.outputs
743+
744+
fn = function([size], x, updates={rng: next_rng}, mode="NUMBA")
745+
746+
result1 = fn(1)
747+
result2 = fn(1)
748+
assert result1.shape == (1,)
749+
assert result1 != result2
750+
751+
large_sample = fn(1000)
752+
assert large_sample.shape == (1000,)
753+
np.testing.assert_allclose(large_sample.mean(), np.pi, rtol=1e-2)
754+
np.testing.assert_allclose(large_sample.std(), 1, rtol=1e-2)

0 commit comments

Comments
 (0)