Diffrax v0.3.0

Released by @github-actions on 21 Feb 03:30
Commit 9280c3a

Highlights

This release is primarily a performance improvement: the default adjoint method now uses an asymptotically more efficient checkpointing implementation.

New features

  • Added diffrax.citation for automatically generating BibTeX references of the numerical methods being used.
  • diffrax.SaveAt can now save different selections of outputs at different times, using diffrax.SubSaveAt.
  • diffrax.SaveAt now supports a fn argument for controlling what to save, e.g. only statistics of the solution. (#113, #221, thanks @joglekara in #220!)
  • Can now use SaveAt(dense=True) in the edge case when t0 == t1.
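
To make the idea behind SaveAt(fn=...) and SubSaveAt concrete, here is a minimal pure-Python sketch. It is not Diffrax code: SubSave and euler_solve are hypothetical names, and a fixed-step Euler loop stands in for a real solver. Each SubSave pairs a set of output times with a function of (t, y) that decides what to store there, so the full state can be saved at one time while only a summary statistic is saved at every step. In Diffrax itself this is expressed by passing SaveAt(subs=[SubSaveAt(...), ...]); check the Diffrax documentation for the exact fields.

```python
import math
from dataclasses import dataclass
from typing import Callable, List, Sequence

State = List[float]


@dataclass
class SubSave:
    """Hypothetical stand-in for a SubSaveAt: when to save, and what to save."""
    ts: Sequence[float]
    fn: Callable[[float, State], object] = lambda t, y: list(y)


def euler_solve(f, t0: float, t1: float, n_steps: int, y0: State,
                subs: Sequence[SubSave]) -> List[list]:
    """Fixed-step explicit Euler; returns one list of saved outputs per SubSave."""
    dt = (t1 - t0) / n_steps
    t, y = t0, list(y0)
    outputs: List[list] = [[] for _ in subs]
    for step in range(n_steps + 1):
        # Each SubSave records fn(t, y) only at its own requested times.
        for out, sub in zip(outputs, subs):
            if any(math.isclose(t, s, abs_tol=1e-12) for s in sub.ts):
                out.append(sub.fn(t, y))
        if step < n_steps:
            dy = f(t, y)
            y = [yi + dt * dyi for yi, dyi in zip(y, dy)]
            t = t0 + (step + 1) * dt
    return outputs


def decay(t: float, y: State) -> State:
    """Exponential decay dy/dt = -y, applied componentwise."""
    return [-yi for yi in y]


full_at_end = SubSave(ts=[1.0])                             # whole state, only at t=1
mean_every_step = SubSave(ts=[0.1 * k for k in range(11)],  # a statistic, every step
                          fn=lambda t, y: sum(y) / len(y))
full, means = euler_solve(decay, 0.0, 1.0, 10, [1.0, 3.0],
                          [full_at_end, mean_every_step])
```

Here full holds a single two-component state (at t=1), while means holds eleven scalars, one per step; the per-selection fn is what keeps the second output small.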

Performance improvements

  • The default adjoint method RecursiveCheckpointAdjoint now uses a dramatically improved implementation for reverse-mode autodifferentiation of while loops. This should be asymptotically faster, and should generally give speed-ups in both runtime and compile time.
    • The previous implementation is available as DirectAdjoint. This is still useful in a handful of less common cases, such as forward-mode autodifferentiation. (Once JAX supports bounded while loops as native operations, this will be tidied up further.)
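
The trade-off that checkpointing makes can be seen in a small, dependency-free sketch. Reverse-mode differentiation of a loop needs the loop's intermediate states in reverse order; a checkpointing scheme stores only some of them and recomputes the rest during the backward pass. This sketch uses the simplest fixed-interval scheme (a checkpoint every `every` iterations) on a scalar loop with a hand-written derivative. It illustrates the general technique only and is not Diffrax's implementation, which uses a more sophisticated recursive scheme; step, dstep and grad_with_checkpoints are hypothetical names.

```python
import math
from typing import Dict, List, Tuple


def step(x: float, h: float = 0.1) -> float:
    """Loop body: one explicit Euler step for dx/dt = -sin(x)."""
    return x - h * math.sin(x)


def dstep(x: float, h: float = 0.1) -> float:
    """Derivative of `step` with respect to its input x."""
    return 1.0 - h * math.cos(x)


def grad_with_checkpoints(x0: float, n: int, every: int) -> Tuple[float, float, int]:
    """Run n loop iterations and return (x_n, dx_n/dx0, peak states held)."""
    # Forward pass: store only every `every`-th state.
    checkpoints: Dict[int, float] = {}
    x = x0
    for k in range(n):
        if k % every == 0:
            checkpoints[k] = x
        x = step(x)
    xn = x

    # Backward pass: visit segments last-to-first, recomputing each one.
    adjoint = 1.0  # dx_n/dx_n
    peak = len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + every, n)
        segment: List[float] = [checkpoints[start]]  # x_start, ..., x_{end-1}
        for _ in range(start, end - 1):
            segment.append(step(segment[-1]))
        peak = max(peak, len(checkpoints) + len(segment))
        # Chain rule, applied from the latest state in the segment backwards.
        for xk in reversed(segment):
            adjoint *= dstep(xk)
    return xn, adjoint, peak


xn, grad, peak = grad_with_checkpoints(1.0, n=100, every=10)
```

With n=100 and every=10, at most 20 states are held at once rather than all 100, at the cost of roughly one extra forward pass of recomputation.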

Backward-incompatible changes

  • Removed NoAdjoint. It existed as a performance improvement when not using autodifferentiation, but RecursiveCheckpointAdjoint (the default) has now incorporated this performance improvement automatically.
  • Removed ConstantStepSize(compile_steps=...) and StepTo(compile_steps=...), as these are now unnecessary when using the new RecursiveCheckpointAdjoint.
  • Removed the undocumented Fehlberg2 solver. (It's just not useful compared to Heun/Midpoint/Ralston.)
  • AbstractSolver.term_structure should now be e.g. (ODETerm, AbstractTerm) rather than jtu.tree_structure((ODETerm, AbstractTerm)), i.e. it now encodes the term type as well.
  • Dropped support for Python 3.7.

Fixes

  • Fixed UnsafeBrownianPath and VirtualBrownianTree, which had been broken by an upstream change in JAX (#225).
  • The sum of Runge–Kutta stages now happens in HIGHEST precision, which should improve numerical stability on some accelerators.
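
For reference, the "sum of Runge–Kutta stages" is the weighted combination y1 = y0 + h * sum_i b_i * k_i, together with the analogous sums that form each stage's input. Below is a minimal scalar, pure-Python sketch of an explicit Runge–Kutta step that performs these sums with math.fsum, which returns an accurately rounded floating-point sum. This is only an analogy: as we understand it, HIGHEST here refers to JAX's jax.lax.Precision.HIGHEST, which stops accelerators from using reduced-precision arithmetic for an operation, whereas fsum guards against cancellation in ordinary double precision. explicit_rk_step is a hypothetical helper, not part of Diffrax.

```python
import math
from typing import Callable, List, Sequence


def explicit_rk_step(f: Callable[[float, float], float], t: float, y: float, h: float,
                     a: Sequence[Sequence[float]], b: Sequence[float],
                     c: Sequence[float]) -> float:
    """One explicit Runge-Kutta step for a scalar ODE dy/dt = f(t, y)."""
    k: List[float] = []
    for i in range(len(b)):
        # Stage input: y + h * sum_{j<i} a[i][j] * k[j], accumulated with fsum.
        y_stage = y + h * math.fsum(a[i][j] * k[j] for j in range(i))
        k.append(f(t + c[i] * h, y_stage))
    # Step output: y + h * sum_i b[i] * k[i], again accumulated with fsum.
    return y + h * math.fsum(b_i * k_i for b_i, k_i in zip(b, k))


# Classical fourth-order Runge-Kutta tableau.
RK4_A = [[], [0.5], [0.0, 0.5], [0.0, 0.0, 1.0]]
RK4_B = [1 / 6, 1 / 3, 1 / 3, 1 / 6]
RK4_C = [0.0, 0.5, 0.5, 1.0]

# One step of dy/dt = y from y(0) = 1 with h = 0.1.
y1 = explicit_rk_step(lambda t, y: y, 0.0, 1.0, 0.1, RK4_A, RK4_B, RK4_C)
```

The reason accumulation precision matters: in double precision, (1e16 + 1.0) - 1e16 evaluates to 0.0, whereas math.fsum([1e16, 1.0, -1e16]) returns 1.0.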

Examples

  • The documentation now has an introductory "coupled ODEs" example.
  • The documentation now has an advanced "nonlinear heat PDE" example.
  • The "symbolic regression" example has been updated to use sympy2jax.

Full Changelog: v0.2.2...v0.3.0