eval_shape incompatible with deserializing directly to host due to _assert_same #861

Open
colehaus opened this issue Sep 23, 2024 · 3 comments
Labels
feature New feature


@colehaus
Contributor

Suppose you have a large pytree where you want to ensure that the full tree is never on the JAX device (TPU/GPU). You might also want to minimize the allocation of transient arrays by using eval_shape. Your ser/de code would then look something like this:

```python
from __future__ import annotations

from collections.abc import Callable
from pathlib import Path
from typing import TypeVar, TypeVarTuple

import equinox as eqx
import equinox._serialisation as eqx_ser
import jax
from numpy import ndarray

Shape = TypeVarTuple("Shape")
DType = TypeVar("DType")


def save(path: Path, array: ndarray[*Shape, DType]):
    with path.open("wb") as f:
        # We have to convert to JAX arrays because numpy doesn't handle bfloat16
        # https://github.com/jax-ml/ml_dtypes/issues/41
        eqx.tree_serialise_leaves(f, array)


def load(path: Path, like_fn: Callable[[], ndarray[*Shape, DType]]) -> ndarray[*Shape, DType]:
    with path.open("rb") as f:
        return eqx.tree_deserialise_leaves(
            f,
            # Build `like` from shapes/dtypes only, without allocating the arrays.
            eqx.filter_eval_shape(like_fn),
            # Load each leaf with the default filter, then move it to host memory.
            filter_spec=lambda f, x: jax.device_get(eqx.default_deserialise_filter_spec(f, x)),
        )
```

But that errors with:

```
File …/lib/python3.11/site-packages/equinox/_serialisation.py:172, in _assert_same.<locals>._assert_same_impl(path, new, old)
    170     typeold = array_impl_type
    171 if typenew is not typeold:
--> 172     raise RuntimeError(
    173         f"Deserialised leaf at path '{jtu.keystr(path)}' has changed type from "
    174         f"{type(old)} in `like` to {type(new)} on disk."
    175     )
    176 if isinstance(new, (np.ndarray, jax.Array)):
    177     if new.shape != old.shape:

RuntimeError: Deserialised leaf at path '' has changed type from <class 'jax._src.api.ShapeDtypeStruct'> in `like` to <class 'numpy.ndarray'> on disk.
```

(The error message is slightly misleading here: the comparison that actually fails is between jaxlib.xla_extension.ArrayImpl, i.e. array_impl_type, and numpy.ndarray.)

Note that users can work around the issue by monkey-patching out the check, but that's pretty ugly:

```python
def patched_assert_same(array_impl_type):  # type: ignore
    """Equinox generates a fixed `array_impl_type` that corresponds to a JAX array.
    `_assert_same_impl` then swaps in this type for any `jax.ShapeDtypeStruct` in `like`,
    and compares the resulting `like` types against the deserialized types at the very
    end of deserialization. This means we can't deserialize directly to the host when
    `like` comes from `eval_shape`; we'd instead have to deserialize the whole tree on
    device and then transfer it to the host.
    """

    def _assert_same_impl(path, new, old):  # type: ignore
        # No-op: this disables every check, including the shape check after the type check.
        pass

    return _assert_same_impl


eqx_ser._assert_same = patched_assert_same  # type: ignore
```
@patrick-kidger
Owner

Hmm, I'm a little mystified by this, because I thought we'd already added support for it (#259, c5fc44f).

Indeed, just above the line that raises your error, there's an explicit check that casts away ShapeDtypeStructs:

```python
if typeold is jax.ShapeDtypeStruct:
    typeold = array_impl_type
```

@colehaus
Contributor Author

Ah, yeah, I think the issue arises because we're in a slightly unusual case where we actually want a NumPy/host array returned, while array_impl_type assumes we want a JAX/device array. If I remove the jax.device_get call from the custom filter_spec, it works fine.
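To make this diagnosis concrete, here is a small stand-alone re-creation of the type comparison shown in the traceback, using only JAX and NumPy rather than Equinox itself. It assumes `array_impl_type` can be stood in for by `type(jnp.array(0))`; `types_match` is a hypothetical helper, not part of Equinox. A leaf that stays a JAX array passes, while the same leaf after `jax.device_get` becomes a `numpy.ndarray` and fails.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Stand-in for Equinox's `array_impl_type`: the concrete class of a JAX array.
# Assumption: obtained here via `type(jnp.array(0))`; Equinox's own code may differ.
array_impl_type = type(jnp.array(0))


def types_match(new, old) -> bool:
    """Hypothetical re-creation of the type check from the traceback (not Equinox's code)."""
    typenew, typeold = type(new), type(old)
    if typeold is jax.ShapeDtypeStruct:
        # A ShapeDtypeStruct in `like` is assumed to stand for a JAX (device) array.
        typeold = array_impl_type
    return typenew is typeold


# `like` leaf as produced by eval_shape: only shape and dtype, no buffer.
like = jax.eval_shape(lambda: jnp.zeros((4,), jnp.float32))

# What the default filter spec yields for this leaf: a JAX array.
device_leaf = jnp.zeros((4,), jnp.float32)
# What the custom filter spec yields after `jax.device_get`: a NumPy array.
host_leaf = jax.device_get(device_leaf)

print(types_match(device_leaf, like))  # True
print(types_match(host_leaf, like))    # False, so Equinox raises RuntimeError here
```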

@patrick-kidger
Owner

Right! I think what you're trying to do here is reasonable, and I'd be happy to take a PR adjusting this. (Maybe we should just consider all kinds of JAX and NumPy arrays interchangeable?)
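A minimal sketch of what treating JAX and NumPy arrays as interchangeable could look like, assuming array-like leaves should still be required to agree on shape and dtype. `ARRAY_KINDS` and `leaves_compatible` are hypothetical names for illustration only; they are not Equinox's API, and a real change would live inside Equinox's own `_assert_same`.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical set of leaf types treated as one interchangeable "array" kind.
ARRAY_KINDS = (np.ndarray, jax.Array, jax.ShapeDtypeStruct)


def leaves_compatible(new, old) -> bool:
    """Hypothetical relaxed check: two array-like leaves are compatible if their
    shape and dtype agree, regardless of host/device placement. Any other leaf
    must keep exactly the same type."""
    if isinstance(new, ARRAY_KINDS) and isinstance(old, ARRAY_KINDS):
        return new.shape == old.shape and new.dtype == old.dtype
    return type(new) is type(old)


like = jax.eval_shape(lambda: jnp.zeros((2, 3), jnp.float32))
print(leaves_compatible(np.zeros((2, 3), np.float32), like))    # True: host array accepted
print(leaves_compatible(jnp.zeros((2, 3), jnp.float32), like))  # True: device array accepted
print(leaves_compatible(np.zeros((3, 2), np.float32), like))    # False: shape differs
print(leaves_compatible(np.zeros((2, 3), np.float64), like))    # False: dtype differs
```

Keeping the shape and dtype comparison means a host-side load still catches a mismatched file on disk; only the device-versus-host distinction is relaxed.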

patrick-kidger added the feature (New feature) label on Sep 25, 2024