Skip to content

Commit

Permalink
Merge pull request #252 from patrick-kidger/nan-memory-size
Browse files Browse the repository at this point in the history
Fixed issue #250.
  • Loading branch information
patrick-kidger authored May 8, 2023
2 parents edd1250 + 0dfe3e5 commit f101e75
Showing 1 changed file with 16 additions and 24 deletions.
40 changes: 16 additions & 24 deletions diffrax/global_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,40 +318,32 @@ def evaluate(
if t1 is not None:
return self.evaluate(t1, left=left) - self.evaluate(t0, left=left)
t = t0 * self.direction
ts_0 = self.ts[0]
ts_1 = self.ts[self.ts_size - 1]
pred = (self.ts_size > 1) & (t >= ts_0) & (t <= ts_1)
eval_fn = ft.partial(self.__class__._evaluate, t=t, left=left)
nan_fn = self.__class__._nan
# Use cond to avoid generating nans unless we have to.
out = lax.cond(pred, eval_fn, nan_fn, self)
t_bounded = self._nan_if_out_of_bounds(t)
out = self._get_local_interpolation(t_bounded, left).evaluate(
t_bounded, left=left
)
keep = ft.partial(jnp.where, (t == self.t0_if_trivial) & (self.ts_size == 1))
return jtu.tree_map(keep, self.y0_if_trivial, out)

@eqx.filter_jit
def derivative(self, t: Scalar, left: bool = True) -> PyTree:
t = t * self.direction
t = self._nan_if_out_of_bounds(t)
out = self._get_local_interpolation(t, left).derivative(t, left=left)
return (self.direction * out**ω).ω

def _nan_if_out_of_bounds(self, t):
# Note that len(self.ts) == max_steps + 1 > 0 so the indexing is always valid,
# even if we throw it away because self.ts_size == 0.
ts_0 = self.ts[0]
ts_1 = self.ts[self.ts_size - 1]
pred = (self.ts_size > 1) & (t >= ts_0) & (t <= ts_1)
deriv_fn = ft.partial(self.__class__._derivative, t=t, left=left)
nan_fn = self.__class__._nan
# Use cond to avoid generating nans unless we have to.
return lax.cond(pred, deriv_fn, nan_fn, self)

def _evaluate(self, t, left):
return self._get_local_interpolation(t, left).evaluate(t, left=left)

def _derivative(self, t, left):
out = self._get_local_interpolation(t, left).derivative(t, left=left)
return (self.direction * out**ω).ω

def _nan(self):
return jtu.tree_map(
ft.partial(jnp.full_like, fill_value=jnp.nan), self.y0_if_trivial
)
out_of_bounds = (self.ts_size <= 1) | (t < ts_0) | (t > ts_1)
make_nans = lambda t: jnp.where(out_of_bounds, jnp.nan, t)
identity = lambda t: t
# Avoid making NaNs unless we have to, by using a cond.
# (For the sake of JAX_DEBUG_NANS.)
t = lax.cond(eqxi.unvmap_any(out_of_bounds), make_nans, identity, t)
return t

@property
def t0(self):
Expand Down

0 comments on commit f101e75

Please sign in to comment.