74 changes: 17 additions & 57 deletions pytensor/scan/basic.py
@@ -1,5 +1,6 @@
import typing
import warnings
from functools import reduce
from itertools import chain

import numpy as np
@@ -16,7 +17,6 @@
from pytensor.graph.utils import MissingInputError, TestValueError
from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import expand_empty, safe_new, until
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import minimum
from pytensor.tensor.shape import shape_padleft
@@ -143,31 +143,6 @@ def _filter(x):
raise ValueError(error_msg)


def isNaN_or_Inf_or_None(x):
Member Author commented: I think this is silly. If anybody disagrees, let me know.

Member commented: Overriding bad user input by guessing what the user probably meant (falling back to the sequence length and ignoring n_steps) was always an anti-pattern, so I like removing this.

isNone = x is None
try:
isNaN = np.isnan(x)
isInf = np.isinf(x)
isStr = isinstance(x, str)
except Exception:
isNaN = False
isInf = False
isStr = False
if not isNaN and not isInf:
try:
val = get_underlying_scalar_constant_value(x)
isInf = np.isinf(val)
isNaN = np.isnan(val)
except Exception:
isNaN = False
isInf = False
if isinstance(x, Constant) and isinstance(x.data, str):
isStr = True
else:
isStr = False
return isNone or isNaN or isInf or isStr


def _manage_output_api_change(outputs, updates, return_updates):
if return_updates:
warnings.warn(
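
For orientation, here is a small sketch of the behavioral difference the comment above refers to. The variable names are invented for illustration, and the calls assume the return_updates=False convention used in the new tests at the bottom of this PR:

import pytensor.tensor as pt
from pytensor.scan.basic import scan

xs = pt.vector("xs")  # hypothetical input sequence

# After this PR, only n_steps=None means "infer the number of steps from the
# sequences"; here the scan runs for xs.shape[0] iterations.
ys = scan(fn=lambda x: x + 1, sequences=[xs], n_steps=None, return_updates=False)

# Before this PR, n_steps=NaN or inf (or even a string constant) was silently
# treated the same way by isNaN_or_Inf_or_None: the bogus value was ignored
# and the sequence length used instead. Such values are no longer second-guessed.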
@@ -505,7 +480,7 @@ def wrap_into_list(x):

# This helper eagerly skips the Scan if n_steps is known to be 1
single_step_requested = False
if isinstance(n_steps, float | int):
if isinstance(n_steps, int | float):
Member Author commented: Reordered, just because ints are much more likely than floats.

single_step_requested = n_steps == 1
else:
try:
@@ -676,33 +651,20 @@ def wrap_into_list(x):
if nw_name is not None:
nw_seq.name = nw_name

# Since we've added all sequences now we need to level them up based on
# n_steps or their different shapes
lengths_vec = [seq.shape[0] for seq in scan_seqs]

if not isNaN_or_Inf_or_None(n_steps):
# ^ N_steps should also be considered
lengths_vec.append(pt.as_tensor(n_steps))

if len(lengths_vec) == 0:
# ^ No information about the number of steps
raise ValueError(
"No information about the number of steps "
"provided. Either provide a value for "
"n_steps argument of scan or provide an input "
"sequence"
)

# If the user has provided the number of steps, do that regardless ( and
# raise an error if the sequences are not long enough )
if isNaN_or_Inf_or_None(n_steps):
actual_n_steps = lengths_vec[0]
for contestant in lengths_vec[1:]:
actual_n_steps = minimum(actual_n_steps, contestant)
if n_steps is None:
if not scan_seqs:
raise ValueError(
"No information about the number of steps provided. "
"Either provide a value for n_steps argument of scan or provide an input sequence."
)
actual_n_steps = reduce(minimum, [seq.shape[0] for seq in scan_seqs])
else:
actual_n_steps = pt.as_tensor(n_steps)
actual_n_steps = pt.as_tensor(n_steps, dtype="int64", ndim=0)

# Since we've added all sequences now we need to level them off based on
# n_steps or their different shapes
scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]

# Conventions :
# mit_mot = multiple input taps, multiple output taps ( only provided
# by the gradient function )
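
As a usage-level illustration of the branch above (again a sketch with invented sequence names, using the return_updates=False calling convention from the tests added in this PR):

import pytensor.tensor as pt
from pytensor.scan.basic import scan

a = pt.vector("a")
b = pt.vector("b")

# n_steps=None: actual_n_steps becomes minimum(a.shape[0], b.shape[0]) via the
# reduce(minimum, ...) call above, and each sequence is sliced to that length.
zs = scan(fn=lambda x, y: x + y, sequences=[a, b], n_steps=None, return_updates=False)

# Explicit n_steps: converted with pt.as_tensor(n_steps, dtype="int64", ndim=0)
# and used as-is; both sequences are still truncated to seq[:actual_n_steps].
ws = scan(fn=lambda x, y: x + y, sequences=[a, b], n_steps=3, return_updates=False)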
@@ -899,10 +861,8 @@ def wrap_into_list(x):
raw_inner_outputs = fn(*args)

condition, outputs, updates = get_updates_and_outputs(raw_inner_outputs)
if condition is not None:
as_while = True
else:
as_while = False
as_while = condition is not None

##
# Step 3. Check if we actually need scan and remove it if we don't
##
Expand Down Expand Up @@ -934,7 +894,7 @@ def wrap_into_list(x):

# extract still missing inputs (there still might be so) and add them
# as non sequences at the end of our args
if condition is not None:
if as_while:
outputs.append(condition)
fake_nonseqs = [x.type() for x in non_seqs]
fake_outputs = clone_replace(
@@ -1252,8 +1212,8 @@ def remove_dimensions(outs, offsets=None):
rightOrder = (
mit_sot_rightOrder
+ sit_sot_rightOrder
+ untraced_sit_sot_rightOrder
+ nit_sot_rightOrder
+ untraced_sit_sot_rightOrder
)
scan_out_list = [None] * len(rightOrder)
for idx, pos in enumerate(rightOrder):
46 changes: 46 additions & 0 deletions tests/scan/test_basic.py
@@ -4470,3 +4470,49 @@ def onestep(seq, seq_tm4):
f_infershape = function([seq, init], out_seq_tm4[1].shape)
scan_nodes_infershape = scan_nodes_from_fct(f_infershape)
assert len(scan_nodes_infershape) == 0


@pytest.mark.parametrize("single_step", (True, False))
def test_scan_mapped_and_non_traced_output_ordering(single_step):
# Regression test for https://github.com/pymc-devs/pytensor/issues/1796

rng = random_generator_type("rng")

def x_then_rng(rng):
next_rng, x = pt.random.normal(rng=rng).owner.outputs
return x, next_rng

xs, final_rng = scan(
fn=x_then_rng,
outputs_info=[None, rng],
n_steps=1 if single_step else 5,
return_updates=False,
)
assert isinstance(xs.type, TensorType)
assert isinstance(final_rng.type, RandomGeneratorType)

def rng_then_x(rng):
x, next_rng = x_then_rng(rng)
return next_rng, x

final_rng, xs = scan(
fn=rng_then_x,
outputs_info=[rng, None],
n_steps=1 if single_step else 5,
return_updates=False,
)
assert isinstance(xs.type, TensorType)
assert isinstance(final_rng.type, RandomGeneratorType)

def rng_between_xs(rng):
x, next_rng = x_then_rng(rng)
return x, next_rng, x + 1, x + 2

xs1, final_rng, xs2, xs3 = scan(
fn=rng_between_xs,
outputs_info=[None, rng, None, None],
n_steps=1 if single_step else 5,
return_updates=False,
)
assert all(isinstance(xs.type, TensorType) for xs in (xs1, xs2, xs3))
assert isinstance(final_rng.type, RandomGeneratorType)