From 557bf36b782b7ff98541919cee424d5054ec92db Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 29 Sep 2023 12:08:13 -0700 Subject: [PATCH] Serialisation fix for float64 scalars, which otherwise get downcast by JAX --- equinox/_serialisation.py | 11 +++++----- tests/test_serialisation.py | 41 +++++++++++++++++++++++++++++-------- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/equinox/_serialisation.py b/equinox/_serialisation.py index a7ade996..534337d8 100644 --- a/equinox/_serialisation.py +++ b/equinox/_serialisation.py @@ -17,9 +17,6 @@ class TreePathError(RuntimeError): path: tuple -TreePathError.__name__ = TreePathError.__qualname__ = "RuntimeError" - - def _ordered_tree_map( f: Callable[..., Any], tree: Any, @@ -133,8 +130,12 @@ def default_deserialise_filter_spec(f: BinaryIO, x: Any) -> Any: return np.load(f) elif is_array_like(x): # np.generic gets deserialised directly as an array, so convert back to a scalar - # type here. Important to use `jnp` here to handle `bfloat16`. - return type(x)(jnp.load(f).item()) + # type here. + # See also https://github.com/google/jax/issues/17858 + out = np.load(f) + if isinstance(x, jax.dtypes.bfloat16): + out = out.view(jax.dtypes.bfloat16) + return type(x)(out.item()) else: return x diff --git a/tests/test_serialisation.py b/tests/test_serialisation.py index b33021cb..c5158b6c 100644 --- a/tests/test_serialisation.py +++ b/tests/test_serialisation.py @@ -8,8 +8,6 @@ import equinox as eqx -from .helpers import shaped_allclose - def _example_trees(): jax_array1 = jnp.array(1) @@ -120,20 +118,45 @@ def test_generic_dtype_serialisation(getkey, tmp_path): eqx.tree_serialise_leaves(tmp_path, jax_array) like_jax_array = jnp.array(bfloat16(2)) loaded_jax_array = eqx.tree_deserialise_leaves(tmp_path, like_jax_array) - assert shaped_allclose(jax_array, loaded_jax_array) + assert jax_array.item() == loaded_jax_array.item() - tree = jnp.array(1), bfloat16(1), np.float32(1), jnp.array(1) - like_tree = jnp.array(2), bfloat16(2), np.float32(2), jnp.array(2) + tree = ( + jnp.array(1e-8), + bfloat16(1e-8), + np.float32(1e-8), + jnp.array(1e-8), + np.float64(1e-8), + ) + like_tree = ( + jnp.array(2.0), + bfloat16(2), + np.float32(2), + jnp.array(2.0), + np.float64(2.0), + ) # Ensure we can round trip when we start with a scalar eqx.tree_serialise_leaves(tmp_path, tree) loaded_tree = eqx.tree_deserialise_leaves(tmp_path, like_tree) - assert shaped_allclose(loaded_tree, tree) + assert len(loaded_tree) == len(tree) + for a, b in zip(loaded_tree, tree): + assert type(a) is type(b) + assert a.item() == b.item() # Ensure we can round trip when we start with a scalar that we've JAX JITed - eqx.tree_serialise_leaves(tmp_path, jax.jit(lambda x: x)(tree)) - loaded_tree = eqx.tree_deserialise_leaves(tmp_path, like_tree) - assert shaped_allclose(loaded_tree, tree) + # `[:-1]` to skip float64, as JIT turns it into float32 + eqx.tree_serialise_leaves(tmp_path, jax.jit(lambda x: x)(tree[:-1])) + loaded_tree = eqx.tree_deserialise_leaves(tmp_path, like_tree[:-1]) + assert len(loaded_tree) == len(tree[:-1]) + for a, b in zip(loaded_tree, tree): + assert type(a) is type(b) + assert a.item() == b.item() + + +def test_python_scalar(tmp_path): + eqx.tree_serialise_leaves(tmp_path, 1e-8) + out = eqx.tree_deserialise_leaves(tmp_path, 0.0) + assert out == 1e-8 def test_custom_leaf_serialisation(getkey, tmp_path):