Skip to content

Commit

Permalink
Serialisation fix for float64 scalars, which otherwise get downcast b…
Browse files Browse the repository at this point in the history
…y JAX
  • Loading branch information
patrick-kidger committed Sep 29, 2023
1 parent 0b670c8 commit 557bf36
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 14 deletions.
11 changes: 6 additions & 5 deletions equinox/_serialisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ class TreePathError(RuntimeError):
path: tuple


TreePathError.__name__ = TreePathError.__qualname__ = "RuntimeError"


def _ordered_tree_map(
f: Callable[..., Any],
tree: Any,
Expand Down Expand Up @@ -133,8 +130,12 @@ def default_deserialise_filter_spec(f: BinaryIO, x: Any) -> Any:
return np.load(f)
elif is_array_like(x):
# np.generic gets deserialised directly as an array, so convert back to a scalar
# type here. Important to use `jnp` here to handle `bfloat16`.
return type(x)(jnp.load(f).item())
# type here.
# See also https://github.com/google/jax/issues/17858
out = np.load(f)
if isinstance(x, jax.dtypes.bfloat16):
out = out.view(jax.dtypes.bfloat16)
return type(x)(out.item())
else:
return x

Expand Down
41 changes: 32 additions & 9 deletions tests/test_serialisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

import equinox as eqx

from .helpers import shaped_allclose


def _example_trees():
jax_array1 = jnp.array(1)
Expand Down Expand Up @@ -120,20 +118,45 @@ def test_generic_dtype_serialisation(getkey, tmp_path):
eqx.tree_serialise_leaves(tmp_path, jax_array)
like_jax_array = jnp.array(bfloat16(2))
loaded_jax_array = eqx.tree_deserialise_leaves(tmp_path, like_jax_array)
assert shaped_allclose(jax_array, loaded_jax_array)
assert jax_array.item() == loaded_jax_array.item()

tree = jnp.array(1), bfloat16(1), np.float32(1), jnp.array(1)
like_tree = jnp.array(2), bfloat16(2), np.float32(2), jnp.array(2)
tree = (
jnp.array(1e-8),
bfloat16(1e-8),
np.float32(1e-8),
jnp.array(1e-8),
np.float64(1e-8),
)
like_tree = (
jnp.array(2.0),
bfloat16(2),
np.float32(2),
jnp.array(2.0),
np.float64(2.0),
)

# Ensure we can round trip when we start with a scalar
eqx.tree_serialise_leaves(tmp_path, tree)
loaded_tree = eqx.tree_deserialise_leaves(tmp_path, like_tree)
assert shaped_allclose(loaded_tree, tree)
assert len(loaded_tree) == len(tree)
for a, b in zip(loaded_tree, tree):
assert type(a) is type(b)
assert a.item() == b.item()

# Ensure we can round trip when we start with a scalar that we've JAX JITed
eqx.tree_serialise_leaves(tmp_path, jax.jit(lambda x: x)(tree))
loaded_tree = eqx.tree_deserialise_leaves(tmp_path, like_tree)
assert shaped_allclose(loaded_tree, tree)
# `[:-1]` to skip float64, as JIT turns it into float32
eqx.tree_serialise_leaves(tmp_path, jax.jit(lambda x: x)(tree[:-1]))
loaded_tree = eqx.tree_deserialise_leaves(tmp_path, like_tree[:-1])
assert len(loaded_tree) == len(tree[:-1])
for a, b in zip(loaded_tree, tree):
assert type(a) is type(b)
assert a.item() == b.item()


def test_python_scalar(tmp_path):
eqx.tree_serialise_leaves(tmp_path, 1e-8)
out = eqx.tree_deserialise_leaves(tmp_path, 0.0)
assert out == 1e-8


def test_custom_leaf_serialisation(getkey, tmp_path):
Expand Down

0 comments on commit 557bf36

Please sign in to comment.