Diffrax v0.6.0
Features
- Continuous events! It is now possible to specify a condition at which point the differential equation should halt. For example, here's one finding the time at which a dropped ball hits the ground:

```python
import diffrax
import jax.numpy as jnp
import optimistix as optx

def vector_field(t, y, args):
    # y = (height, velocity); constant gravitational acceleration
    _, v = y
    return jnp.array([v, -9.81])

def cond_fn(t, y, args, **kwargs):
    # The solve halts where this crosses zero, i.e. where the height is zero
    x, _ = y
    return x

term = diffrax.ODETerm(vector_field)
solver = diffrax.Tsit5()
t0 = 0
t1 = jnp.inf
dt0 = 0.1
y0 = jnp.array([10.0, 0.0])
root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = diffrax.Event(cond_fn, root_finder)
sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event)
print(f"Event time: {sol.ts[0]}")  # Event time: 1.42...
print(f"Velocity at event time: {sol.ys[0, 1]}")  # Velocity at event time: -14.00...
```
When `cond_fn` hits zero, the solve stops. Once `cond_fn` changes sign over a step, we use Optimistix to do a root find that locates the exact time at which the equation should terminate. Event handling is also fully differentiable.

Getting this in was a huge amount of work from @cholberg (thank you!), and it has been one of our longest-requested features, so I'm really happy to have this in.
(We previously only had 'discrete events', which just terminated at the end of a step, and did not do a root find.)
See the events page in the documentation for more.
- Simulation of space-time-time Lévy area. This is a higher-order statistic of Brownian motion, used in some advanced SDE solvers. We don't have any such solvers yet, but watch this space... ;)
This was a hugely impressive technical effort from @andyElking. Check out our arXiv paper on the topic, which discusses the technical nitty-gritty of how these statistics can be simulated in an efficient manner.
- `ControlTerm` now supports returning a Lineax linear operator. For example, here's how to easily create a diagonal diffusion term:

```python
import jax.numpy as jnp
import lineax
from diffrax import ControlTerm

def vector_field(t, y, args):
    # y is a JAX array of shape (2,)
    y1, y2 = y
    diagonal = jnp.array([y2, y1])
    # corresponds to the matrix [[y2, 0], [0, y1]]
    return lineax.DiagonalLinearOperator(diagonal)

# `...` stands for the control, e.g. a Brownian motion of shape (2,)
diffusion_term = ControlTerm(vector_field, ...)
```
This should make it much easier to express SDEs with particular structure to their diffusion matrices.
This is particularly good for efficiency reasons: the operator's own `.mv` (matrix-vector product) method is used, which typically provides a more efficient implementation than filling in zeros and using a dense matrix-vector product.

Thank you to @lockwo for implementing this one!
See the documentation on `ControlTerm` for more.
Deprecations
Two APIs have now been deprecated.
Both of these APIs now have compatibility layers, so existing code should continue to work. However, they will now emit deprecation warnings, and users are encouraged to upgrade. These APIs may be removed at a later date.
- `diffeqsolve(..., discrete_terminating_event=...)`, along with the corresponding classes `AbstractDiscreteTerminatingEvent`, `DiscreteTerminatingEvent` and `SteadyStateEvent`. These have been superseded by `diffeqsolve(..., event=Event(...))`.
- `WeaklyDiagonalControlTerm` has been superseded by the new behaviour for `ControlTerm`, and its interaction with Lineax, as discussed above.
Other
- Now working around an upstream bug introduced in JAX 0.4.29+, so we should be compatible with modern JAX releases.
- No longer emitting warnings coming from JAX deprecating a few old APIs. (We've migrated to the new ones.)
Full Changelog: v0.5.1...v0.6.0