From 285e93e6acb1772de55fb078116c398134122354 Mon Sep 17 00:00:00 2001
From: Rob Zinkov
Date: Sat, 31 May 2025 17:18:16 +0200
Subject: [PATCH 1/2] Adding workaround so more graphs with Blockwise(Scans)
 can be vectorized

---
 pytensor/scan/op.py      |  4 ++++
 tests/scan/test_basic.py | 12 ++++++++++++
 2 files changed, 16 insertions(+)

diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py
index 2c3f404449..b287bc652d 100644
--- a/pytensor/scan/op.py
+++ b/pytensor/scan/op.py
@@ -1500,6 +1500,10 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
         node_input_storage = [storage_map[r] for r in node.inputs]
         node_output_storage = [storage_map[r] for r in node.outputs]
 
+        # HACK: Here to handle Blockwise Scans
+        if compute_map is None:
+            compute_map = {out: [False] for out in node.outputs}
+
         # Analyse the compile inner function to determine which inputs and
         # outputs are on the gpu and speed up some checks during the execution
         outs_is_tensor = [
diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py
index 351c2e703a..896d131f57 100644
--- a/tests/scan/test_basic.py
+++ b/tests/scan/test_basic.py
@@ -27,6 +27,7 @@
 from pytensor.compile.sharedvalue import shared
 from pytensor.configdefaults import config
 from pytensor.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian
+from pytensor.graph import vectorize_graph
 from pytensor.graph.basic import Apply, ancestors, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
@@ -1178,6 +1179,17 @@ def get_sum_of_grad(input0, input1):
 
         utt.verify_grad(get_sum_of_grad, inputs_test_values, rng=rng)
 
+    def test_blockwise_scan(self):
+        x = pt.tensor("x", shape=())
+        out, _ = scan(lambda x: x + 1, outputs_info=[x], n_steps=10)
+        x_vec = pt.tensor("x_vec", shape=(None,))
+        out_vec = vectorize_graph(out, {x: x_vec})
+
+        fn = function([x_vec], out_vec)
+        o1 = fn([1, 2, 3])
+        o2 = np.arange(2, 12) + np.arange(3).reshape(-1, 1)
+        assert np.allclose(o1, o2)
+
     def test_connection_pattern(self):
         """Test `Scan.connection_pattern` in the presence of recurrent
         outputs with multiple taps."""

From 914e6e4721c539fe11d69d17787a98dced5856f9 Mon Sep 17 00:00:00 2001
From: Rob Zinkov
Date: Sun, 1 Jun 2025 11:35:37 +0200
Subject: [PATCH 2/2] Check for compute_map in rval

---
 pytensor/scan/op.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py
index b287bc652d..c1ae4db04d 100644
--- a/pytensor/scan/op.py
+++ b/pytensor/scan/op.py
@@ -1500,10 +1500,6 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
         node_input_storage = [storage_map[r] for r in node.inputs]
         node_output_storage = [storage_map[r] for r in node.outputs]
 
-        # HACK: Here to handle Blockwise Scans
-        if compute_map is None:
-            compute_map = {out: [False] for out in node.outputs}
-
         # Analyse the compile inner function to determine which inputs and
         # outputs are on the gpu and speed up some checks during the execution
         outs_is_tensor = [
@@ -1651,8 +1647,9 @@ def rval(
             p=p, i=node_input_storage, o=node_output_storage, n=node, allow_gc=allow_gc
         ):
             r = p(n, [x[0] for x in i], o)
-            for o in node.outputs:
-                compute_map[o][0] = True
+            if compute_map is not None:
+                for o in node.outputs:
+                    compute_map[o][0] = True
             if allow_gc:
                 self.fn.free()
             return r
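
Note: the behaviour these patches enable can be checked standalone. Below is a
minimal sketch mirroring the new test_blockwise_scan test above; the import
paths for scan and function are assumptions (the diff itself only shows the
vectorize_graph import), so adjust them to your PyTensor version if needed.

    import numpy as np
    import pytensor.tensor as pt
    from pytensor import function
    from pytensor.graph import vectorize_graph
    # Assumed import path, matching what tests/scan/test_basic.py exercises.
    from pytensor.scan.basic import scan

    # A scalar scan: starting from x, add 1 for 10 steps.
    x = pt.tensor("x", shape=())
    out, _ = scan(lambda x: x + 1, outputs_info=[x], n_steps=10)

    # Vectorizing over a batched input wraps the Scan in a Blockwise.
    # Per the patches, this path reaches Scan.make_thunk with
    # compute_map=None, which previously broke compilation.
    x_vec = pt.tensor("x_vec", shape=(None,))
    out_vec = vectorize_graph(out, {x: x_vec})

    fn = function([x_vec], out_vec)
    # Row i counts from x_vec[i] + 1 up through x_vec[i] + 10.
    print(fn([1, 2, 3]))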