diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index b0ecc6e6fb..941f9753d2 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -1,5 +1,6 @@ import typing import warnings +from functools import reduce from itertools import chain import numpy as np @@ -16,7 +17,6 @@ from pytensor.graph.utils import MissingInputError, TestValueError from pytensor.scan.op import Scan, ScanInfo from pytensor.scan.utils import expand_empty, safe_new, until -from pytensor.tensor.basic import get_underlying_scalar_constant_value from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import minimum from pytensor.tensor.shape import shape_padleft @@ -143,31 +143,6 @@ def _filter(x): raise ValueError(error_msg) -def isNaN_or_Inf_or_None(x): - isNone = x is None - try: - isNaN = np.isnan(x) - isInf = np.isinf(x) - isStr = isinstance(x, str) - except Exception: - isNaN = False - isInf = False - isStr = False - if not isNaN and not isInf: - try: - val = get_underlying_scalar_constant_value(x) - isInf = np.isinf(val) - isNaN = np.isnan(val) - except Exception: - isNaN = False - isInf = False - if isinstance(x, Constant) and isinstance(x.data, str): - isStr = True - else: - isStr = False - return isNone or isNaN or isInf or isStr - - def _manage_output_api_change(outputs, updates, return_updates): if return_updates: warnings.warn( @@ -505,7 +480,7 @@ def wrap_into_list(x): # This helper eagerly skips the Scan if n_steps is known to be 1 single_step_requested = False - if isinstance(n_steps, float | int): + if isinstance(n_steps, int | float): single_step_requested = n_steps == 1 else: try: @@ -676,33 +651,20 @@ def wrap_into_list(x): if nw_name is not None: nw_seq.name = nw_name - # Since we've added all sequences now we need to level them up based on - # n_steps or their different shapes - lengths_vec = [seq.shape[0] for seq in scan_seqs] - - if not isNaN_or_Inf_or_None(n_steps): - # ^ N_steps should also be considered - lengths_vec.append(pt.as_tensor(n_steps)) - - if len(lengths_vec) == 0: - # ^ No information about the number of steps - raise ValueError( - "No information about the number of steps " - "provided. Either provide a value for " - "n_steps argument of scan or provide an input " - "sequence" - ) - - # If the user has provided the number of steps, do that regardless ( and - # raise an error if the sequences are not long enough ) - if isNaN_or_Inf_or_None(n_steps): - actual_n_steps = lengths_vec[0] - for contestant in lengths_vec[1:]: - actual_n_steps = minimum(actual_n_steps, contestant) + if n_steps is None: + if not scan_seqs: + raise ValueError( + "No information about the number of steps provided. " + "Either provide a value for n_steps argument of scan or provide an input sequence." + ) + actual_n_steps = reduce(minimum, [seq.shape[0] for seq in scan_seqs]) else: - actual_n_steps = pt.as_tensor(n_steps) + actual_n_steps = pt.as_tensor(n_steps, dtype="int64", ndim=0) + # Since we've added all sequences now we need to level them off based on + # n_steps or their different shapes scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs] + # Conventions : # mit_mot = multiple input taps, multiple output taps ( only provided # by the gradient function ) @@ -899,10 +861,8 @@ def wrap_into_list(x): raw_inner_outputs = fn(*args) condition, outputs, updates = get_updates_and_outputs(raw_inner_outputs) - if condition is not None: - as_while = True - else: - as_while = False + as_while = condition is not None + ## # Step 3. Check if we actually need scan and remove it if we don't ## @@ -934,7 +894,7 @@ def wrap_into_list(x): # extract still missing inputs (there still might be so) and add them # as non sequences at the end of our args - if condition is not None: + if as_while: outputs.append(condition) fake_nonseqs = [x.type() for x in non_seqs] fake_outputs = clone_replace( @@ -1252,8 +1212,8 @@ def remove_dimensions(outs, offsets=None): rightOrder = ( mit_sot_rightOrder + sit_sot_rightOrder - + untraced_sit_sot_rightOrder + nit_sot_rightOrder + + untraced_sit_sot_rightOrder ) scan_out_list = [None] * len(rightOrder) for idx, pos in enumerate(rightOrder): diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index a8847b3cf6..6de6ce8737 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -4470,3 +4470,49 @@ def onestep(seq, seq_tm4): f_infershape = function([seq, init], out_seq_tm4[1].shape) scan_nodes_infershape = scan_nodes_from_fct(f_infershape) assert len(scan_nodes_infershape) == 0 + + +@pytest.mark.parametrize("single_step", (True, False)) +def test_scan_mapped_and_non_traced_output_ordering(single_step): + # Regression test for https://github.com/pymc-devs/pytensor/issues/1796 + + rng = random_generator_type("rng") + + def x_then_rng(rng): + next_rng, x = pt.random.normal(rng=rng).owner.outputs + return x, next_rng + + xs, final_rng = scan( + fn=x_then_rng, + outputs_info=[None, rng], + n_steps=1 if single_step else 5, + return_updates=False, + ) + assert isinstance(xs.type, TensorType) + assert isinstance(final_rng.type, RandomGeneratorType) + + def rng_then_x(rng): + x, next_rng = x_then_rng(rng) + return next_rng, x + + final_rng, xs = scan( + fn=rng_then_x, + outputs_info=[rng, None], + n_steps=1 if single_step else 5, + return_updates=False, + ) + assert isinstance(xs.type, TensorType) + assert isinstance(final_rng.type, RandomGeneratorType) + + def rng_between_xs(rng): + x, next_rng = x_then_rng(rng) + return x, next_rng, x + 1, x + 2 + + xs1, final_rng, xs2, xs3 = scan( + fn=rng_between_xs, + outputs_info=[None, rng, None, None], + n_steps=1 if single_step else 5, + return_updates=False, + ) + assert all(isinstance(xs.type, TensorType) for xs in (xs1, xs2, xs3)) + assert isinstance(final_rng.type, RandomGeneratorType)