[Bug] Dynamic shape arrays and reshaping #908

Open
josh146 opened this issue Jul 5, 2024 · 0 comments
Labels
bug Something isn't working

Comments

josh146 (Member) commented Jul 5, 2024

It turns out that several bugs are involved when attempting to reshape dynamically-shaped arrays. In particular:

  • jnp.reshape of a dynamically-shaped array with a new compile-time shape returns the array with its shape unchanged:

    import jax.numpy as jnp
    from catalyst import qjit

    @qjit(abstracted_axes={0: 'm', 1: 'n'}, keep_intermediate=True)
    def g(a):
        # Target shape is fixed at compile time.
        return jnp.reshape(a, (3, 5))

    a = jnp.ones([1, 3], dtype=float)
    g(a)

    I would expect the output to have shape (3, 5), with junk values in the out-of-bounds elements (the input holds only 3 elements, while the target shape holds 15).

  • jnp.reshape of a dynamically-shaped array with a new dynamic shape results in a segfault during compilation. Lowering to MLIR seems to happen correctly, but we suspect that mhlo.dynamic_reshape doesn't have a lowering rule.

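    import jax.numpy as jnp
    from catalyst import qjit

    @qjit(abstracted_axes={0: 'm', 1: 'n'}, keep_intermediate=True)
    def g(a):
        # Target shape depends on the runtime dimensions of the input.
        return jnp.reshape(a, (a.shape[1], a.shape[0]))

    a = jnp.ones([1, 3], dtype=float)
    g(a)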
    • Separately, we should consider raising a proper exception in this case, rather than letting the segfault kill the kernel.
  • jnp.reshape of a dynamically-shaped array within a loop iteration leads to a cryptic error even before compilation:

    >>> import catalyst
    >>> import jax.numpy as jnp
    >>> from catalyst import qjit
    >>> @qjit(abstracted_axes={0: 'm', 1: 'n'})
    ... def g(x):
    ...     @catalyst.for_loop(0, 10, 1, experimental_preserve_dimensions=False)
    ...     def loop(_, a):
    ...         return jnp.reshape(a, (3, 1))
    ...     return loop(x)
    >>> a = jnp.ones([1, 3], dtype=float)
    >>> g(a)
    ValueError: Too few leaves for PyTreeDef; expected 1, got 0
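
For reference, the element counts in the examples above can be compared with a small size check. This is only a sketch in plain Python: `check_reshape_sizes` is a hypothetical helper, not part of Catalyst or JAX, and it assumes the concrete dimension values are known when it runs. It shows one possible behaviour (raising a readable error on a size mismatch), which is stricter than the junk-value output suggested in the first item, and it does not address the missing lowering suspected in the second item.

```python
from math import prod


def check_reshape_sizes(old_shape, new_shape):
    """Hypothetical helper: raise if a reshape would change the element count."""
    old_size = prod(old_shape)
    new_size = prod(new_shape)
    if old_size != new_size:
        raise ValueError(
            f"cannot reshape array of shape {tuple(old_shape)} (size {old_size}) "
            f"into shape {tuple(new_shape)} (size {new_size})"
        )
    return tuple(new_shape)


# Swapping the axes of a (1, 3) array keeps the element count at 3.
print(check_reshape_sizes((1, 3), (3, 1)))

# Reshaping (1, 3) to (3, 5) asks for 15 elements from a 3-element array.
try:
    check_reshape_sizes((1, 3), (3, 5))
except ValueError as exc:
    print(f"ValueError: {exc}")
```

Under these assumptions, the first example fails the check (3 elements versus 15), while the axis swap in the second example and the (3, 1) target in the loop example both preserve the element count.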

Originally posted by @josh146 in #904 (comment)

josh146 added the bug label Jul 5, 2024