Jaxtyping and PyTrees with mixed array and non-array leaves #8

Open
mlprt opened this issue Feb 16, 2024 · 1 comment

mlprt (Owner) commented Feb 16, 2024

Consider the function `feedbax.tree_set`, defined in `feedbax/feedbax/_tree.py` (lines 103 to 133 at commit 147fb42):

```python
from typing import Any

import equinox as eqx
import jax
from jaxtyping import Array, PyTree, Shaped


def tree_set(
    tree: PyTree[Any | Shaped[Array, "batch *?dims"], "T"],
    items: PyTree[Any | Shaped[Array, "*?dims"], "T"],
    idx: int,
) -> PyTree[Any | Shaped[Array, "batch *?dims"], "T"]:
    """Perform an out-of-place update of each array leaf of a PyTree.

    Non-array leaves are simply replaced by their matching leaves in `items`.

    For example, if `tree` is a PyTree of states over time, whose first dimension
    is the time step, and `items` is a PyTree of states for a single time step,
    this function can be used to insert the latter into the former at a given time index.

    Arguments:
        tree: Any PyTree whose array leaves share a first dimension of the same
            length, for example a batch dimension.
        items: Any PyTree with the same structure as `tree`, and whose array
            leaves have the same shape as the corresponding leaves in `tree`,
            but lacking the first dimension.
        idx: The index along the first dimension of the array leaves of `tree`
            into which to insert the array leaves of `items`.

    Returns:
        A PyTree with the same structure as `tree`, where the array leaves of
        `items` have been inserted as the `idx`-th elements of the
        corresponding array leaves of `tree`.
    """
    # Array leaves of `tree`; every non-array leaf becomes None.
    arrays = eqx.filter(tree, eqx.is_array)
    # Split `items` into the leaves at array positions of `tree` and the rest.
    vals_update, other_update = eqx.partition(
        items, jax.tree_util.tree_map(lambda x: x is not None, arrays)
    )
    # Write each single-step value into index `idx` of the matching array.
    arrays_update = jax.tree_util.tree_map(
        lambda xs, x: xs.at[idx].set(x), arrays, vals_update
    )
    # Recombine the updated arrays with the non-array leaves from `items`.
    return eqx.combine(arrays_update, other_update)
```

It takes (1) a PyTree with both array and non-array leaves, whose array leaves all share a leading batch dimension; (2) a PyTree of the same structure, whose array leaves lack that batch dimension; and (3) an index into the batch dimension. It returns a copy of (1) in which, for every array leaf, the corresponding array from (2) has been inserted at index (3). Non-array leaves of (1) are replaced by the matching leaves of (2).
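To make the described behaviour concrete, here is a minimal sketch of the same semantics in plain NumPy, assuming nested dicts, lists and tuples are the only containers. `tree_map2` and `tree_set_sketch` are hypothetical helpers written for illustration, not part of feedbax, and the copy-then-assign stands in for JAX's `xs.at[idx].set(x)`.

```python
import numpy as np


def tree_map2(f, tree, other):
    """Apply f leaf-wise to two trees with the same structure.

    Hypothetical stand-in for jax.tree_util.tree_map over two trees; only
    plain dicts, lists and tuples are treated as containers.
    """
    if isinstance(tree, dict):
        return {k: tree_map2(f, tree[k], other[k]) for k in tree}
    if type(tree) in (list, tuple):
        return type(tree)(tree_map2(f, a, b) for a, b in zip(tree, other))
    return f(tree, other)


def tree_set_sketch(tree, items, idx):
    """Out-of-place version of the documented tree_set behaviour."""

    def set_leaf(xs, x):
        if isinstance(xs, np.ndarray):
            out = xs.copy()  # copy first, so `tree` itself is not modified
            out[idx] = x     # write the single-step value into row `idx`
            return out
        return x  # non-array leaf: take the matching leaf from `items`

    return tree_map2(set_leaf, tree, items)


# Usage: a history of 3 time steps, updated at time step 1.
history = {"pos": np.zeros((3, 2)), "label": "init"}
step = {"pos": np.array([1.0, 2.0]), "label": "step 1"}
updated = tree_set_sketch(history, step, 1)
print(updated["pos"])    # rows 0 and 2 stay zero; row 1 is [1. 2.]
print(updated["label"])  # prints: step 1
```

Because `set_leaf` copies before writing, `history` is left untouched, mirroring the out-of-place update in `tree_set`.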

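`Any` is included in the annotations to allow for non-array leaves. The problem is that `Any | Array` is equivalent to `Any`: every leaf satisfies the `Any` branch of the union, so the jaxtyping `PyTree`/`Shaped[Array, ...]` annotations can never raise an error, and the shapes of the array leaves are never actually checked.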
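The collapse can be illustrated with a toy model of how a union annotation is checked: a value passes if any one alternative accepts it. This is a sketch under that assumption, not jaxtyping's actual implementation; the predicate names are hypothetical, and `has_batch_dim` only approximates `Shaped[Array, "batch *?dims"]` by requiring at least one dimension.

```python
import numpy as np


def accepts_anything(leaf):
    # Toy stand-in for the `Any` branch of the union.
    return True


def has_batch_dim(leaf):
    # Toy stand-in for Shaped[Array, "batch *?dims"]: an array with at
    # least one (leading, batch) dimension. Real jaxtyping checks more.
    return isinstance(leaf, np.ndarray) and leaf.ndim >= 1


def union_accepts(leaf, *alternatives):
    # A union annotation is satisfied as soon as one alternative accepts.
    return any(alt(leaf) for alt in alternatives)


wrong_leaf = np.array(3.0)  # 0-d array: no batch dimension at all

# Array-only annotation: the shape problem is caught.
print(union_accepts(wrong_leaf, has_batch_dim))  # False
# `Any | Array`: the `Any` branch accepts the same leaf, so no error.
print(union_accepts(wrong_leaf, accepts_anything, has_batch_dim))  # True
```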

Is there a way to do array shape checking with jaxtyping, while allowing for non-array leaves?

mlprt (Owner, Author) commented Feb 29, 2024

One way I could see this working is to require that this function be applied only after `tree` has been partitioned into its array and non-array leaves. Then we could annotate the leaves as `Array | None` (or is the `None` automatic?) instead of `Array | Any`.
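Until there is an annotation-level answer, the intent of `"batch *?dims"` versus `"*?dims"` could be enforced by hand on already-partitioned trees, where non-array positions hold `None`. The following is a minimal sketch in plain NumPy, not a jaxtyping feature: `filter_arrays` mimics `eqx.filter(tree, eqx.is_array)` for dict/list/tuple containers, and `check_batched_arrays` is a hypothetical helper that requires one shared batch size while letting trailing dims differ between leaves (the behaviour the `?` modifier is meant to allow, as far as I understand it).

```python
import numpy as np


def filter_arrays(tree):
    """Keep array leaves and put None at every other leaf position.

    Mimics eqx.filter(tree, eqx.is_array) for dict/list/tuple containers.
    """
    if isinstance(tree, dict):
        return {k: filter_arrays(v) for k, v in tree.items()}
    if type(tree) in (list, tuple):
        return type(tree)(filter_arrays(v) for v in tree)
    return tree if isinstance(tree, np.ndarray) else None


def paired_leaves(a, b):
    """Yield matching leaf pairs from two trees with the same structure."""
    if isinstance(a, dict):
        for k in a:
            yield from paired_leaves(a[k], b[k])
    elif type(a) in (list, tuple):
        for x, y in zip(a, b):
            yield from paired_leaves(x, y)
    else:
        yield a, b


def check_batched_arrays(tree_arrays, item_arrays):
    """Hypothetical check of "batch *?dims" (tree) against "*?dims" (items).

    Returns the shared batch size, or None if there are no array leaves.
    """
    batch = None
    for xs, x in paired_leaves(tree_arrays, item_arrays):
        if xs is None and x is None:
            continue  # non-array position in both trees: nothing to check
        if xs is None or x is None:
            raise TypeError("array leaf in one tree but not the other")
        if xs.ndim == 0:
            raise TypeError("array leaf of `tree` has no batch dimension")
        if xs.shape[1:] != x.shape:
            raise TypeError(f"item shape {x.shape} != {xs.shape[1:]}")
        if batch is None:
            batch = xs.shape[0]  # first array leaf fixes the batch size
        elif xs.shape[0] != batch:
            raise TypeError(f"batch size {xs.shape[0]} != {batch}")
    return batch


# Usage: trailing dims differ between leaves, but the batch size is shared.
tree = {"pos": np.zeros((5, 2)), "vel": np.zeros((5, 2, 3)), "name": "arm"}
items = {"pos": np.ones(2), "vel": np.ones((2, 3)), "name": "arm"}
batch = check_batched_arrays(filter_arrays(tree), filter_arrays(items))

# A wrongly shaped item leaf ("pos" should have shape (2,)) is rejected.
bad_items = {"pos": np.ones(3), "vel": np.ones((2, 3)), "name": "arm"}
```

Because the check runs only on the array partition, the non-array leaves never need an `Any` in the annotation, which is the point of partitioning first.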

mlprt removed the TRANSFER label on Mar 12, 2024.