"""Perform an out-of-place update of each array leaf of a PyTree.

Non-array leaves are simply replaced by their matching leaves in `items`.

For example, if `tree` is a PyTree of states over time, whose first dimension
is the time step, and `items` is a PyTree of states for a single time step,
this function can be used to insert the latter into the former at a given time index.

Arguments:
    tree: Any PyTree whose array leaves share a first dimension of the same
        length, for example a batch dimension.
    items: Any PyTree with the same structure as `tree`, and whose array
        leaves have the same shape as the corresponding leaves in `tree`,
        but lacking the first dimension.
    idx: The index along the first dimension of the array leaves of `tree`
        into which to insert the array leaves of `items`.

Returns:
    A PyTree with the same structure as `tree`, where the array leaves of
    `items` have been inserted as the `idx`-th elements of the corresponding
    array leaves of `tree`.
"""
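A minimal sketch of the behaviour this docstring describes, using `jax.tree_util.tree_map` together with JAX's out-of-place `.at[...].set(...)` update. It is not the feedbax implementation: `tree_set_sketch` is a hypothetical name, and the sketch assumes that an "array leaf" is anything passing `isinstance(x, jax.Array)`, so NumPy arrays would be treated as non-array leaves.

```python
from typing import Any

import jax
import jax.numpy as jnp


def tree_set_sketch(tree: Any, items: Any, idx: int) -> Any:
    """Hypothetical re-implementation of the documented behaviour (not feedbax's code)."""

    def set_leaf(leaf: Any, item: Any) -> Any:
        if isinstance(leaf, jax.Array):
            # Out-of-place update: returns a new array with entry `idx` of the
            # leading axis replaced by `item`; `leaf` itself is left untouched.
            return leaf.at[idx].set(item)
        # Non-array leaves are simply replaced by the matching leaf of `items`.
        return item

    # `items` must have the same PyTree structure as `tree`.
    return jax.tree_util.tree_map(set_leaf, tree, items)
```

For example, with `states = {"x": jnp.zeros((3, 2)), "name": "old"}` and `item = {"x": jnp.ones(2), "name": "new"}`, `tree_set_sketch(states, item, 1)` returns a dict whose `"x"` has row 1 set to ones and whose `"name"` is `"new"`, while `states` itself is unchanged.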
`feedbax.tree_set` takes (1) a PyTree of array and non-array leaves, where the array leaves all share a batch dimension; (2) a PyTree of the same structure, but where the array leaves all lack the batch dimension; and (3) an index into the batch dimension. It returns a copy of (1) in which the array leaves of (2) have been inserted at index (3) of the corresponding array leaves of (1).
The `Any` in its type annotations is included to allow for non-array leaves. The problem is that `Any | Array` is equivalent to `Any`, so the jaxtyping `PyTree`/`Array` annotations will never lead to errors.
Is there a way to do array shape checking with jaxtyping, while allowing for non-array leaves?
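Whatever annotation ends up being used, the contract being asked for can be made concrete as an explicit runtime check. The sketch below uses only `jax.tree_util`, not jaxtyping, and the helper name `check_tree_set_shapes` is hypothetical. It skips non-array leaves and enforces what a shape annotation with a named leading batch axis would express: the array leaves of `tree` share a leading dimension, and each array leaf of `items` has the same shape minus that axis. It is a workaround to illustrate the intended checks, not a jaxtyping feature.

```python
from typing import Any, Optional

import jax
import jax.numpy as jnp


def check_tree_set_shapes(tree: Any, items: Any) -> Optional[int]:
    """Hypothetical check of the shape contract described in the `tree_set` docstring.

    Returns the shared batch size, or None if `tree` has no array leaves.
    """
    if jax.tree_util.tree_structure(tree) != jax.tree_util.tree_structure(items):
        raise ValueError("`tree` and `items` must have the same PyTree structure")

    batch_size = None
    for leaf, item in zip(jax.tree_util.tree_leaves(tree), jax.tree_util.tree_leaves(items)):
        if not isinstance(leaf, jax.Array):
            continue  # non-array leaves carry no shape constraint
        if leaf.ndim == 0:
            raise ValueError("array leaves of `tree` need a leading batch dimension")
        if batch_size is None:
            batch_size = leaf.shape[0]
        elif leaf.shape[0] != batch_size:
            raise ValueError(f"batch sizes differ: {leaf.shape[0]} vs {batch_size}")
        # The matching leaf of `items` must have the same shape minus the batch axis.
        if jnp.shape(item) != leaf.shape[1:]:
            raise ValueError(f"item shape {jnp.shape(item)} != expected {leaf.shape[1:]}")
    return batch_size
```

Calling this at the top of the function gives errors for mismatched array shapes while still accepting arbitrary non-array leaves, at the cost of living outside the type annotations.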
One way I could see this working is if we insist that this function be applied after `tree` has been partitioned into array and non-array leaves. Then we can type it as `Array | None` (or is the `None` automatic?) instead of `Array | Any`.
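The partitioning idea can be sketched with plain `jax.tree_util`. Non-array positions are replaced by `None`, which JAX treats as an empty subtree rather than a leaf, so every leaf of the array part is an array and `tree_map` never visits the placeholders. This follows the same `None`-placeholder convention as Equinox's `eqx.partition`/`eqx.combine`, but the helpers `partition_arrays`, `combine` and `tree_set_partitioned` are hypothetical, and "array leaf" is assumed to mean `isinstance(x, jax.Array)`. Whether jaxtyping's `PyTree[...]` check likewise skips the `None` placeholders is an assumption worth confirming against its documentation.

```python
from typing import Any, Tuple

import jax
import jax.numpy as jnp


def _is_array(x: Any) -> bool:
    return isinstance(x, jax.Array)


def partition_arrays(tree: Any) -> Tuple[Any, Any]:
    """Split `tree` into (array part, non-array part), using None as the placeholder."""
    arrays = jax.tree_util.tree_map(lambda x: x if _is_array(x) else None, tree)
    others = jax.tree_util.tree_map(lambda x: None if _is_array(x) else x, tree)
    return arrays, others


def combine(arrays: Any, others: Any) -> Any:
    """Inverse of `partition_arrays`: take each leaf from whichever part is not None."""
    # Treat None as a leaf so the two parts line up position by position.
    return jax.tree_util.tree_map(
        lambda a, o: o if a is None else a, arrays, others, is_leaf=lambda x: x is None
    )


def tree_set_partitioned(tree: Any, items: Any, idx: int) -> Any:
    """Hypothetical variant of `tree_set` that only ever touches array leaves."""
    arrays, _ = partition_arrays(tree)
    item_arrays, item_others = partition_arrays(items)
    # Every leaf seen here is an array: the None placeholders are empty subtrees.
    new_arrays = jax.tree_util.tree_map(lambda x, y: x.at[idx].set(y), arrays, item_arrays)
    # Non-array leaves come from `items`, matching the documented behaviour.
    return combine(new_arrays, item_others)
```

Because the array part contains only arrays, it could in principle be annotated with a shape-checked `PyTree` of arrays without a catch-all `Any`; the non-array part is handled separately and needs no shape annotation.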
For reference, the function under discussion is `feedbax.tree_set` in `feedbax/feedbax/_tree.py` (lines 103 to 133 at commit 147fb42).