-
Notifications
You must be signed in to change notification settings - Fork 152
Fix output ordering in new Scan API #1799
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
ricardoV94
merged 2 commits into
pymc-devs:main
from
ricardoV94:fix_scan_new_api_output_ordering
Dec 21, 2025
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| import typing | ||
| import warnings | ||
| from functools import reduce | ||
| from itertools import chain | ||
|
|
||
| import numpy as np | ||
|
|
@@ -16,7 +17,6 @@ | |
| from pytensor.graph.utils import MissingInputError, TestValueError | ||
| from pytensor.scan.op import Scan, ScanInfo | ||
| from pytensor.scan.utils import expand_empty, safe_new, until | ||
| from pytensor.tensor.basic import get_underlying_scalar_constant_value | ||
| from pytensor.tensor.exceptions import NotScalarConstantError | ||
| from pytensor.tensor.math import minimum | ||
| from pytensor.tensor.shape import shape_padleft | ||
|
|
@@ -143,31 +143,6 @@ def _filter(x): | |
| raise ValueError(error_msg) | ||
|
|
||
|
|
||
| def isNaN_or_Inf_or_None(x): | ||
| isNone = x is None | ||
| try: | ||
| isNaN = np.isnan(x) | ||
| isInf = np.isinf(x) | ||
| isStr = isinstance(x, str) | ||
| except Exception: | ||
| isNaN = False | ||
| isInf = False | ||
| isStr = False | ||
| if not isNaN and not isInf: | ||
| try: | ||
| val = get_underlying_scalar_constant_value(x) | ||
| isInf = np.isinf(val) | ||
| isNaN = np.isnan(val) | ||
| except Exception: | ||
| isNaN = False | ||
| isInf = False | ||
| if isinstance(x, Constant) and isinstance(x.data, str): | ||
| isStr = True | ||
| else: | ||
| isStr = False | ||
| return isNone or isNaN or isInf or isStr | ||
|
|
||
|
|
||
| def _manage_output_api_change(outputs, updates, return_updates): | ||
| if return_updates: | ||
| warnings.warn( | ||
|
|
@@ -505,7 +480,7 @@ def wrap_into_list(x): | |
|
|
||
| # This helper eagerly skips the Scan if n_steps is known to be 1 | ||
| single_step_requested = False | ||
| if isinstance(n_steps, float | int): | ||
| if isinstance(n_steps, int | float): | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. reordered, just because ints are much more likely than float |
||
| single_step_requested = n_steps == 1 | ||
| else: | ||
| try: | ||
|
|
@@ -676,33 +651,20 @@ def wrap_into_list(x): | |
| if nw_name is not None: | ||
| nw_seq.name = nw_name | ||
|
|
||
| # Since we've added all sequences now we need to level them up based on | ||
| # n_steps or their different shapes | ||
| lengths_vec = [seq.shape[0] for seq in scan_seqs] | ||
|
|
||
| if not isNaN_or_Inf_or_None(n_steps): | ||
| # ^ N_steps should also be considered | ||
| lengths_vec.append(pt.as_tensor(n_steps)) | ||
|
|
||
| if len(lengths_vec) == 0: | ||
| # ^ No information about the number of steps | ||
| raise ValueError( | ||
| "No information about the number of steps " | ||
| "provided. Either provide a value for " | ||
| "n_steps argument of scan or provide an input " | ||
| "sequence" | ||
| ) | ||
|
|
||
| # If the user has provided the number of steps, do that regardless ( and | ||
| # raise an error if the sequences are not long enough ) | ||
| if isNaN_or_Inf_or_None(n_steps): | ||
| actual_n_steps = lengths_vec[0] | ||
| for contestant in lengths_vec[1:]: | ||
| actual_n_steps = minimum(actual_n_steps, contestant) | ||
| if n_steps is None: | ||
| if not scan_seqs: | ||
| raise ValueError( | ||
| "No information about the number of steps provided. " | ||
| "Either provide a value for n_steps argument of scan or provide an input sequence." | ||
| ) | ||
| actual_n_steps = reduce(minimum, [seq.shape[0] for seq in scan_seqs]) | ||
| else: | ||
| actual_n_steps = pt.as_tensor(n_steps) | ||
| actual_n_steps = pt.as_tensor(n_steps, dtype="int64", ndim=0) | ||
|
|
||
| # Since we've added all sequences now we need to level them off based on | ||
| # n_steps or their different shapes | ||
| scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs] | ||
|
|
||
| # Conventions : | ||
| # mit_mot = multiple input taps, multiple output taps ( only provided | ||
| # by the gradient function ) | ||
|
|
@@ -899,10 +861,8 @@ def wrap_into_list(x): | |
| raw_inner_outputs = fn(*args) | ||
|
|
||
| condition, outputs, updates = get_updates_and_outputs(raw_inner_outputs) | ||
| if condition is not None: | ||
| as_while = True | ||
| else: | ||
| as_while = False | ||
| as_while = condition is not None | ||
|
|
||
| ## | ||
| # Step 3. Check if we actually need scan and remove it if we don't | ||
| ## | ||
|
|
@@ -934,7 +894,7 @@ def wrap_into_list(x): | |
|
|
||
| # extract still missing inputs (there still might be so) and add them | ||
| # as non sequences at the end of our args | ||
| if condition is not None: | ||
| if as_while: | ||
| outputs.append(condition) | ||
| fake_nonseqs = [x.type() for x in non_seqs] | ||
| fake_outputs = clone_replace( | ||
|
|
@@ -1252,8 +1212,8 @@ def remove_dimensions(outs, offsets=None): | |
| rightOrder = ( | ||
| mit_sot_rightOrder | ||
| + sit_sot_rightOrder | ||
| + untraced_sit_sot_rightOrder | ||
| + nit_sot_rightOrder | ||
| + untraced_sit_sot_rightOrder | ||
| ) | ||
| scan_out_list = [None] * len(rightOrder) | ||
| for idx, pos in enumerate(rightOrder): | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is silly. If anybody disagrees let me know
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overriding a bad user input by guessing what he probably meant (by falling back to the sequence length and ignoring
n_steps) was always an anti-pattern, so I like removing this.