
[BUG] Some annotated functions fail to be used inline with @qjit #1077

Open
josh146 opened this issue Aug 29, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@josh146
Member

josh146 commented Aug 29, 2024

For example,

```python
>>> qjit(jax.scipy.linalg.expm)(x)
TypeError: Argument 'ArrayLike' of type <class 'str'> is not a valid JAX type
```

The same JAX function works as expected when it is called inside a user-defined function that is decorated with `@qjit`.

This is caused by the function's type annotations, which `qjit` does not expect:

```python
def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 16) -> Array:
    ...  # body omitted; only the signature matters here
```
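The error message shows that `qjit` received the annotation as the string `'ArrayLike'` rather than as a type. One likely cause (an assumption, not confirmed in this thread) is that the module defining `expm` uses `from __future__ import annotations` (PEP 563), under which every annotation is stored as an unevaluated string. The stdlib-only sketch below reproduces the effect with quoted annotations; `ArrayLike` here is a hypothetical stand-in class, not JAX's type.

```python
import inspect
import typing


class ArrayLike:
    """Hypothetical stand-in for jax.typing.ArrayLike."""


# Quoted annotations: this is how every annotation is stored when a module
# starts with `from __future__ import annotations` (PEP 563).
def expm(A: "ArrayLike", *, upper_triangular: "bool" = False,
         max_squarings: "int" = 16) -> "ArrayLike":
    return A  # placeholder body; only the signature matters here


# inspect.signature reports the raw annotation, which is just a string.
raw = inspect.signature(expm).parameters["A"].annotation
print(repr(raw), type(raw))  # 'ArrayLike' <class 'str'>

# The string only becomes a real class once it is evaluated explicitly.
resolved = typing.get_type_hints(expm)["A"]
print(resolved is ArrayLike)  # True
```

A string such as `'ArrayLike'` carries no shape or dtype information, so it cannot be turned into an abstract argument for ahead-of-time tracing.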
josh146 added the bug label Aug 29, 2024
dime10 pushed a commit that referenced this issue Aug 29, 2024
**Context:** It is not entirely clear why, but some [functions have strings
as their function
annotations](https://github.com/python/cpython/blob/3.10/Lib/inspect.py#L2117-L2120).
A string annotation does not provide enough information to perform AOT
compilation, so we should not attempt AOT compilation in these cases.

**Description of the Change:** Only perform AOT compilation when the
function's annotations are instances of `type` or `jax.core.ShapedArray`.

**Benefits:** `expm` (and similar functions) can be compiled directly with
`qjit(expm)`.

**Possible Drawbacks:** None

**Related GitHub Issues:** #1077
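The check described in the commit above can be sketched as an inspection of the signature before deciding whether to compile ahead of time. The helper `aot_annotations` and its `allowed` parameter are hypothetical names, not Catalyst's API. Catalyst's check also accepts `jax.core.ShapedArray` instances, which are left out here so the sketch runs without JAX installed.

```python
import inspect
from typing import Callable, List, Optional, Tuple


def aot_annotations(fn: Callable, allowed: Tuple[type, ...] = (type,)) -> Optional[List[object]]:
    """Hypothetical helper: return the parameter annotations if they can
    drive ahead-of-time compilation, or None to defer compilation.

    In Catalyst, `allowed` would also include jax.core.ShapedArray.
    """
    annotations = [p.annotation for p in inspect.signature(fn).parameters.values()]
    # A missing annotation leaves nothing to compile ahead of time.
    if any(a is inspect.Parameter.empty for a in annotations):
        return None
    # String annotations such as "ArrayLike" (see PEP 563) are not usable types.
    if not all(isinstance(a, allowed) for a in annotations):
        return None
    return annotations


def typed(x: float, n: int):
    return x * n


def stringly(A: "ArrayLike"):  # string annotation, as seen for jax.scipy.linalg.expm
    return A


def wrapper(x):  # unannotated wrapper around the annotated function
    return stringly(x)


print(aot_annotations(typed))     # [<class 'float'>, <class 'int'>]
print(aot_annotations(stringly))  # None
print(aot_annotations(wrapper))   # None
```

Returning `None` for both `stringly` and the unannotated `wrapper` means neither is compiled ahead of time; assuming Catalyst's documented behavior of compiling unannotated functions at their first call, this is consistent with the issue's observation that calling `expm` from inside a `@qjit`-decorated function already works.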
@paul0403
Contributor

Is this still open?
