Thanks for reporting this! I think the missing element here is allowing Pytree types through the autograph process; arbitrary types would be out of scope.
tzunghanjuang changed the title from "Autograph fails when iterating over non-Jax array types" to "Autograph fails when iterating over non-Jax Pytree types" on Jul 3, 2024
Issue description
Using the @qjit(autograph=True) decorator to loop over Pytree types that are not jax.numpy arrays (such as Python built-in lists, tuples, and NumPy arrays) fails. No conversion will be applied.
Manually converting the non-JAX array types into JAX arrays makes the loop pass.
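The manual conversion mentioned above might look roughly like the sketch below. It is not the reporter's original code: the helper name as_jax_array, the sample values, and the accumulate function are all hypothetical, and the loop runs as plain Python here rather than under @qjit(autograph=True), so it only illustrates the conversion step, not the Catalyst behaviour itself.

```python
import numpy as np
import jax.numpy as jnp


def as_jax_array(values):
    """Hypothetical helper: turn a list, tuple, or NumPy array into a JAX array."""
    return jnp.asarray(values)


# The three container kinds named in the issue (sample values are made up).
candidates = {
    "list": [0.1, 0.2, 0.3],
    "tuple": (0.1, 0.2, 0.3),
    "numpy": np.array([0.1, 0.2, 0.3]),
}

# Convert each container up front so the loop only ever sees a JAX array.
converted = {name: as_jax_array(v) for name, v in candidates.items()}


def accumulate(items):
    # Plain Python loop over a JAX array. In the issue's setting, a loop like
    # this would live inside a function decorated with @qjit(autograph=True);
    # that decorator is deliberately not applied in this sketch.
    total = jnp.zeros((), dtype=items.dtype)
    for x in items:
        total = total + x
    return total


totals = {name: float(accumulate(arr)) for name, arr in converted.items()}
```

Converting once, before the loop, keeps the iteration target a single array type regardless of which container the caller started with.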
Source code and tracebacks
Source code:
Trace:
The error is triggered here: catalyst/frontend/catalyst/autograph/ag_primitives.py, lines 159 to 169 at commit 140cbd3.