From 80e3a7ea98192a989f4d9463bccb1d7bb89539a6 Mon Sep 17 00:00:00 2001 From: mattlevine22 Date: Wed, 13 Nov 2024 15:15:54 -0500 Subject: [PATCH] check-and-clip to [0,1] in linear_rescale to protect against floating point errors that sometimes occur on GPUs. --- diffrax/_misc.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/diffrax/_misc.py b/diffrax/_misc.py index ac61b813..1548d5b2 100644 --- a/diffrax/_misc.py +++ b/diffrax/_misc.py @@ -73,18 +73,36 @@ def fill_forward( return ys +@jax.custom_jvp +def _clip(x, cond): + return cast(jax.Array, jnp.where(cond, jnp.clip(x, 0, 1), x)) + + +@_clip.defjvp +def _clip_jvp(inputs, tangents): + (x, cond) = inputs + (t_x, _) = tangents + return _clip(x, cond), t_x + + def linear_rescale(t0, t, t1) -> Array: - """Calculates (t - t0) / (t1 - t0), assuming t0 <= t <= t1. + """Calculates (t - t0) / (t1 - t0). Specially handles the edge case t0 == t1: - zero is returned; - gradients through all three arguments are zero. + - output conditionally clipped to be in [0,1] to protect + from floating point errors. """ cond = t0 == t1 numerator = cast(Array, jnp.where(cond, 0, t - t0)) denominator = cast(Array, jnp.where(cond, 1, t1 - t0)) - return numerator / denominator + out = numerator / denominator + # We need to clip due to GPU computations sometimes giving `(x/x) > 1` + # https://github.com/jax-ml/jax/issues/24807 + clip_cond = ((t0 <= t) & (t <= t1)) | ((t1 <= t) & (t <= t0)) + return _clip(out, clip_cond) def adjoint_rms_seminorm(x: tuple[PyTree, PyTree, PyTree, PyTree]) -> RealScalarLike: