diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 2c3f404449..c1ae4db04d 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -1647,8 +1647,9 @@ def rval( p=p, i=node_input_storage, o=node_output_storage, n=node, allow_gc=allow_gc ): r = p(n, [x[0] for x in i], o) - for o in node.outputs: - compute_map[o][0] = True + if compute_map is not None: + for o in node.outputs: + compute_map[o][0] = True if allow_gc: self.fn.free() return r diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index 351c2e703a..896d131f57 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -27,6 +27,7 @@ from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian +from pytensor.graph import vectorize_graph from pytensor.graph.basic import Apply, ancestors, equal_computations from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op @@ -1178,6 +1179,17 @@ def get_sum_of_grad(input0, input1): utt.verify_grad(get_sum_of_grad, inputs_test_values, rng=rng) + def test_blockwise_scan(self): + x = pt.tensor("x", shape=()) + out, _ = scan(lambda x: x + 1, outputs_info=[x], n_steps=10) + x_vec = pt.tensor("x_vec", shape=(None,)) + out_vec = vectorize_graph(out, {x: x_vec}) + + fn = function([x_vec], out_vec) + o1 = fn([1, 2, 3]) + o2 = np.arange(2, 12) + np.arange(3).reshape(-1, 1) + assert np.allclose(o1, o2) + def test_connection_pattern(self): """Test `Scan.connection_pattern` in the presence of recurrent outputs with multiple taps."""