`eval_shape` incompatible with deserializing directly to host due to `__assert_same` #861
Labels: feature (New feature)
Suppose you have a large pytree and want to ensure that the full tree is never on the JAX device (TPU/GPU). You might also want to minimize the allocation of transient arrays by using `eval_shape`. Your ser/de code would then look something like this:

But that errors with:
(The error message is slightly misleading in this case because the comparison we're actually performing, and failing, is between `jaxlib.xla_extension.ArrayImpl` (i.e. `array_impl_type`) and `numpy.ndarray`.)

Note that users can circumvent the issue by monkey-patching out the check, but that's pretty ugly:
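For readers who want to see the underlying type mismatch in isolation, here is a minimal sketch. It is not the snippet originally attached to this issue and it does not call the library's deserialization routine or its `__assert_same` check. It only uses `jax.eval_shape`, `jax.tree_util` and NumPy's `np.save`/`np.load`. The names `make_tree` and `load_to_host`, and the in-memory `io.BytesIO` "checkpoint", are hypothetical stand-ins chosen for illustration.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np


def make_tree():
    # Hypothetical stand-in for a large pytree; a real model would be far bigger.
    return {
        "w": jnp.ones((4, 3), dtype=jnp.float32),
        "b": jnp.zeros((3,), dtype=jnp.float32),
    }


# Build only the skeleton: every leaf is a jax.ShapeDtypeStruct,
# so no array data is materialized for the tree itself.
skeleton = jax.eval_shape(make_tree)

# Simulate a serialized checkpoint: one np.save record per leaf, in tree order.
buffer = io.BytesIO()
for leaf in jax.tree_util.tree_leaves(make_tree()):
    np.save(buffer, np.asarray(leaf))
buffer.seek(0)


def load_to_host(spec):
    # np.load returns a numpy.ndarray in host memory, never a jax.Array.
    array = np.load(buffer)
    # Shape and dtype agree with the skeleton; only the leaf *type* differs.
    assert array.shape == spec.shape and array.dtype == spec.dtype
    return array


# Leaves are visited in the same order in which they were written above.
restored = jax.tree_util.tree_map(load_to_host, skeleton)

for name, leaf in restored.items():
    print(name, type(leaf).__name__, isinstance(leaf, jax.Array))
```

The restored leaves match the skeleton in shape and dtype but are `numpy.ndarray` objects rather than `jax.Array` instances, so any check that maps a `ShapeDtypeStruct` skeleton leaf to the JAX array implementation type and then compares types exactly would reject them, which is the situation described above.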