Autograph fails when iterating over non-Jax Pytree types #896

Open
tzunghanjuang opened this issue Jul 3, 2024 · 1 comment
Labels
bug Something isn't working


@tzunghanjuang
Collaborator

tzunghanjuang commented Jul 3, 2024

Issue description

Using the `@qjit(autograph=True)` decorator to loop over Pytree types that are not JAX arrays (for example, Python built-in lists and tuples, or NumPy arrays) fails: AutoGraph emits a warning and the loop is left unconverted.
Manually converting these inputs into JAX arrays makes the conversion succeed.

Source code and tracebacks

Source code:

```python
from catalyst import qjit

def updateList(x):
    return [x[0]+1, x[1]+2]

@qjit(autograph=True)
def fn(x):
    for i in range(4):
        x = updateList(x)
    return x

fn([1, 2])
```

Trace:

```
[~/catalyst/frontend/catalyst/autograph/ag_primitives.py:347]: UserWarning: Tracing of an AutoGraph converted for loop failed with an exception:
  AutoGraphError:    The variable 'x' was initialized with type <class 'list'>, which is not compatible with JAX. Typically, this is the case for non-numeric values.
    You may still use such a variable as a constant inside a loop, but it cannot be updated from one iteration to the next, or accessed outside the loop scope if it was defined inside of it.

The error ocurred within the body of the following for loop statement:
  File "/tmp/ipykernel_165289/1920264522.py", line 8, in fn
    for i in range(4):

If you intended for the conversion to happen, make sure that the (now dynamic) loop variable is not used in tracing-incompatible ways, for instance by indexing a Python list with it. In that case, the list should be wrapped into an array.
To understand different types of JAX tracing errors, please refer to the guide at: https://jax.readthedocs.io/en/latest/errors.html

If you did not intend for the conversion to happen, you may safely ignore this warning.
  warnings.warn(
```

The error is triggered here:

```python
try:
    jax.api_util.shaped_abstractify(inp)
except TypeError as e:
    raise AutoGraphError(
        f"The variable '{symbol_names[i]}' was initialized with type {type(inp)}, "
        "which is not compatible with JAX. Typically, this is the case for non-numeric "
        "values.\n"
        "You may still use such a variable as a constant inside a loop, but it cannot "
        "be updated from one iteration to the next, or accessed outside the loop scope "
        "if it was defined inside of it."
    ) from e
```
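For comparison, JAX's own loop primitives already accept a Python list as a Pytree loop carry: `jax.lax.fori_loop` flattens the list into its leaves, traces each leaf as a scalar, and rebuilds the list on exit. The snippet below is plain JAX (no Catalyst or AutoGraph involved) and redefines `updateList` from the reproducer; it is only meant to illustrate the behaviour that the check above rejects, not how Catalyst would implement it.

```python
import jax

def updateList(x):
    return [x[0] + 1, x[1] + 2]

# fori_loop's body receives (iteration index, carry); the index is unused here.
# The list carry is handled as a Pytree with two scalar leaves.
result = jax.lax.fori_loop(0, 4, lambda i, x: updateList(x), [1, 2])
print([int(v) for v in result])  # [5, 10]
```

The carry keeps its list structure across iterations because the body returns a list with the same number and types of leaves as the initial value.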

@tzunghanjuang added the bug (Something isn't working) label on Jul 3, 2024
@dime10
Collaborator

dime10 commented Jul 3, 2024

Thanks for reporting this! I think the missing piece here is letting Pytree types through the AutoGraph process; supporting arbitrary types would be out of scope.
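One possible shape for such a change, sketched under assumptions: instead of abstractifying the loop variable as a whole, flatten it with `jax.tree_util.tree_flatten` and validate each leaf, so containers such as lists and tuples pass while unsupported leaves (strings, arbitrary objects) are still rejected. The helper name `check_loop_carry` and its simple `isinstance` leaf test are hypothetical stand-ins for Catalyst's `shaped_abstractify`-based check, not Catalyst code.

```python
import numpy as np
import jax

# Leaf types accepted by this simplified check (an assumption; JAX's real
# abstractification rules are broader and live inside JAX itself).
_VALID_LEAF_TYPES = (bool, int, float, complex, np.ndarray, np.generic, jax.Array)

def check_loop_carry(name, value):
    """Hypothetical Pytree-aware version of the loop-variable check.

    Validates every leaf of `value` and returns its tree structure, so a
    caller could rebuild the container after tracing the flattened leaves.
    """
    leaves, treedef = jax.tree_util.tree_flatten(value)
    for leaf in leaves:
        if not isinstance(leaf, _VALID_LEAF_TYPES):
            raise TypeError(
                f"The variable '{name}' contains a leaf of type {type(leaf)}, "
                "which is not compatible with JAX."
            )
    return treedef

treedef = check_loop_carry("x", [1, 2])
print(treedef.num_leaves)  # 2
```

With this approach the list from the reproducer is accepted as a container of two numeric leaves, while a carry like `["a", 1]` is still rejected because its string leaf is not a numeric type.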

@tzunghanjuang changed the title from "Autograph fails when iterating over non-Jax array types" to "Autograph fails when iterating over non-Jax Pytree types" on Jul 3, 2024