vmap(jnp.asarray)(numpy_array) does not return a JAX array #25745

Open
shoyer opened this issue Jan 7, 2025 · 1 comment · May be fixed by #25835
Labels
bug Something isn't working

Comments

@shoyer
Collaborator

shoyer commented Jan 7, 2025

Description

I was surprised to discover that vmap does not always convert NumPy arrays into JAX arrays. Instead, in some (rare) cases it returns NumPy arrays when given NumPy arrays as inputs:

import jax
import jax.numpy as jnp
import numpy as np

numpy_array = np.arange(3)
print(type(jax.vmap(lambda x: x)(numpy_array)))  # <class 'numpy.ndarray'>
print(type(jax.vmap(jnp.asarray)(numpy_array)))  # <class 'numpy.ndarray'>
print(type(jax.vmap(jnp.array)(numpy_array)))  # <class 'jaxlib.xla_extension.ArrayImpl'>
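One possible workaround, sketched here under the assumption that you only need a guaranteed `jax.Array` result: convert the input with `jnp.asarray` before calling `vmap`, outside of any trace. This sidesteps the passthrough rather than fixing it.

```python
import jax
import jax.numpy as jnp
import numpy as np

numpy_array = np.arange(3)

# Convert eagerly, outside of vmap: here jnp.asarray sees a real NumPy array
# (not a tracer), so it actually produces a jax.Array.
jax_input = jnp.asarray(numpy_array)

out_identity = jax.vmap(lambda x: x)(jax_input)
out_asarray = jax.vmap(jnp.asarray)(jax_input)

# Even if the batched function passes its input straight through,
# what comes back is now the JAX array we passed in.
print(isinstance(out_identity, jax.Array))  # True
print(isinstance(out_asarray, jax.Array))   # True
```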

The cases that return NumPy arrays appear to be those where the wrapped function reduces to the identity, at least as far as jaxprs are concerned:

print(jax.make_jaxpr(jax.vmap(jnp.asarray))(numpy_array))
# { lambda ; a:i32[3]. let  in (a,) }
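To check whether a batched function is a passthrough in this sense, one can inspect its jaxpr directly. The helper `is_passthrough` below is hypothetical (not part of JAX); it relies only on `jax.make_jaxpr` and the `eqns`, `invars` and `outvars` attributes of the resulting jaxpr, and reports whether no primitives are bound and every output is one of the input variables.

```python
import jax
import jax.numpy as jnp
import numpy as np

def is_passthrough(fn, *args):
    """Hypothetical helper: True if fn's jaxpr binds no primitives and every
    output variable is one of the input variables."""
    jaxpr = jax.make_jaxpr(fn)(*args).jaxpr
    return not jaxpr.eqns and all(
        any(out is inp for inp in jaxpr.invars) for out in jaxpr.outvars
    )

numpy_array = np.arange(3)
print(is_passthrough(jax.vmap(lambda x: x), numpy_array))      # True
print(is_passthrough(jax.vmap(lambda x: x + 1), numpy_array))  # False
# Result for jnp.asarray depends on the JAX version in use.
print(is_passthrough(jax.vmap(jnp.asarray), numpy_array))
```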

This is doubly surprising because jnp.asarray() is supposed to explicitly convert its inputs into JAX arrays.
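A plausible explanation, offered as an assumption rather than a confirmed diagnosis: inside `vmap`, the function never sees the NumPy array itself, only a batching tracer, and tracers already count as `jax.Array` instances. `jnp.asarray` therefore has nothing to convert and binds no primitive (consistent with the empty jaxpr printed in the snippet before this paragraph), so the underlying NumPy value flows straight back out. The hypothetical `record` function below just logs what the traced function receives.

```python
import jax
import jax.numpy as jnp
import numpy as np

seen = []

def record(x):
    # Under vmap, x is a batching tracer rather than the NumPy array itself.
    seen.append((isinstance(x, jax.Array), isinstance(x, np.ndarray)))
    # x is already a JAX value (a tracer), so jnp.asarray has nothing to convert.
    return jnp.asarray(x)

jax.vmap(record)(np.arange(3))
print(seen)  # [(True, False)]
```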

My real use case is a custom pytree type whose constructor (but not tree_unflatten) always converts its inputs into JAX arrays. This type otherwise seems perfectly well behaved and follows one of the patterns suggested in the JAX docs. Yet when vmap is applied to this "identity function", it returns a value that would otherwise be impossible to create:

@jax.tree_util.register_pytree_node_class
class Wrapper:
  def __init__(self, value):
    self.value = jnp.asarray(value)

  def tree_flatten(self):
    return [self.value], None
  
  @classmethod
  def tree_unflatten(cls, _, leaves):
    result = super().__new__(cls)
    [result.value] = leaves
    return result

  def __repr__(self):
    return f"Wrapper({self.value!r})"


print(type(jax.vmap(Wrapper)(np.arange(3)).value))
# <class 'numpy.ndarray'>
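Until this is resolved, here is a sketch of a caller-side workaround, assuming it is acceptable to convert leaves after `vmap` returns: apply `jnp.asarray` to every output leaf with `jax.tree_util.tree_map`. Outside of a trace this performs a real conversion, and because the result is rebuilt through `tree_unflatten`, the `Wrapper` invariant is restored. `vmap_to_jax` is a hypothetical helper name, not a JAX API.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.tree_util.register_pytree_node_class
class Wrapper:
  def __init__(self, value):
    self.value = jnp.asarray(value)

  def tree_flatten(self):
    return [self.value], None

  @classmethod
  def tree_unflatten(cls, _, leaves):
    result = super().__new__(cls)
    [result.value] = leaves
    return result

def vmap_to_jax(fn, **vmap_kwargs):
  """Hypothetical helper: vmap `fn`, then force every output leaf to a JAX array."""
  batched = jax.vmap(fn, **vmap_kwargs)

  def wrapped(*args, **kwargs):
    out = batched(*args, **kwargs)
    # Runs after vmap has returned, so NumPy leaves are really converted.
    return jax.tree_util.tree_map(jnp.asarray, out)

  return wrapped

result = vmap_to_jax(Wrapper)(np.arange(3))
print(isinstance(result.value, jax.Array))  # True
```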

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.38
jaxlib: 0.4.38
numpy: 1.26.4
python: 3.10.12 (main, Nov 6 2024, 20:22:13) [GCC 11.4.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='a953559aba37', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Thu Jun 27 21:05:47 UTC 2024', machine='x86_64')

@shoyer shoyer added the bug Something isn't working label Jan 7, 2025
@jakevdp
Collaborator

jakevdp commented Jan 7, 2025

Thanks for the report! I think this is probably related to #18020
