Skip to content

Commit

Permalink
check-and-clip to [0,1] in linear_rescale to protect against floating…
Browse files Browse the repository at this point in the history
… point errors that sometimes occur on GPUs.
  • Loading branch information
mattlevine22 authored and patrick-kidger committed Nov 26, 2024
1 parent 161f2a6 commit 80e3a7e
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions diffrax/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,36 @@ def fill_forward(
return ys


@jax.custom_jvp
def _clip(x, cond):
return cast(jax.Array, jnp.where(cond, jnp.clip(x, 0, 1), x))


@_clip.defjvp
def _clip_jvp(inputs, tangents):
(x, cond) = inputs
(t_x, _) = tangents
return _clip(x, cond), t_x


def linear_rescale(t0, t, t1) -> Array:
"""Calculates (t - t0) / (t1 - t0), assuming t0 <= t <= t1.
"""Calculates (t - t0) / (t1 - t0).
Specially handles the edge case t0 == t1:
- zero is returned;
- gradients through all three arguments are zero.
- output conditionally clipped to be in [0,1] to protect
from floating point errors.
"""

cond = t0 == t1
numerator = cast(Array, jnp.where(cond, 0, t - t0))
denominator = cast(Array, jnp.where(cond, 1, t1 - t0))
return numerator / denominator
out = numerator / denominator
# We need to clip due to GPU computations sometimes giving `(x/x) > 1`
# https://github.com/jax-ml/jax/issues/24807
clip_cond = ((t0 <= t) & (t <= t1)) | ((t1 <= t) & (t <= t0))
return _clip(out, clip_cond)


def adjoint_rms_seminorm(x: tuple[PyTree, PyTree, PyTree, PyTree]) -> RealScalarLike:
Expand Down

0 comments on commit 80e3a7e

Please sign in to comment.