Thanks for the nice package. I have a question about the implementation of diffrax.diffeqsolve. Specifically, I'd like to know more about the ODE solver, in particular the key factors that make diffeqsolve fast. In a nutshell, could you list a few bullet points on the implementation optimizations that have been made to accelerate ODE solving in diffrax?
A lot of JAX-specific tricks have gone into this: for example, knowing when the compiler will want to make a copy of a buffer, and avoiding those cases. Or knowing that vmap-of-cond becomes a jnp.where (when the predicate is batched, both branches get evaluated for every batch element), and likewise knowing to avoid that.
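A small illustration of the vmap-of-cond point, as a sketch rather than anything taken from diffrax itself: the function f below is hypothetical. Unbatched, jax.make_jaxpr shows a real cond primitive, so only one branch runs. Under jax.vmap the predicate differs per element, and JAX lowers the cond to a select_n, which means both jnp.sin and jnp.cos are computed for every element and the result is then selected, exactly like jnp.where.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(x):
    # Hypothetical example: choose a branch based on the sign of x.
    return lax.cond(x > 0, jnp.sin, jnp.cos, x)


x = jnp.array([-1.0, 2.0, 3.0])

# Unbatched: the jaxpr contains a genuine `cond`, so only one branch executes.
print(jax.make_jaxpr(f)(1.0))

# Batched: the predicate is now an array, so the cond becomes a `select_n`
# over the outputs of *both* branches, i.e. both branches are always evaluated.
print(jax.make_jaxpr(jax.vmap(f))(x))
```

This is harmless for cheap branches, but inside a solver loop an expensive "rarely taken" branch would silently run on every step once the solve is vmapped.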
I think one interesting trick here is the way we make sure to be vmap-friendly. That is, if we do vmap(diffeqsolve), we'd like the result to be fast. The key trick is that we don't loop over output times (integrating in between each pair of them), and instead loop over steps (saving outputs as we pass each output time). This avoids a double while loop, in which the inner loop would sit and wait at every output time until every batch element is done.
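To make the loop-over-steps idea concrete, here is a minimal sketch, not diffrax's actual implementation. The function solve_ode is hypothetical; it uses an adaptive Heun method with an embedded Euler error estimate (a solver swapped in for brevity) and linear interpolation for outputs. It assumes y0 is 1-D and ts is sorted with ts[0] >= t0. The point is structural: one lax.while_loop advances step by step, and every output time crossed by an accepted step is filled in with a vectorised jnp.where, with no inner loop over output times.

```python
import jax
import jax.numpy as jnp
from jax import lax


def solve_ode(f, y0, t0, ts, dt0=0.1, rtol=1e-4, atol=1e-6, max_steps=10_000):
    """Hypothetical loop-over-steps solver (adaptive Heun, 1-D y0, sorted ts)."""
    ts = jnp.asarray(ts)
    t_end = ts[-1]
    t0 = jnp.asarray(t0, ts.dtype)
    # Output times equal to t0 are known immediately.
    ys = jnp.where((ts == t0)[:, None], y0,
                   jnp.zeros((ts.shape[0],) + y0.shape, y0.dtype))

    def cond_fn(state):
        t, _, _, _, n = state
        return (t < t_end) & (n < max_steps)

    def body_fn(state):
        t, y, dt, ys, n = state
        # Clip the final step so it lands exactly on t_end.
        remaining = t_end - t
        last = dt >= remaining
        dt = jnp.where(last, remaining, dt)
        t_next = jnp.where(last, t_end, t + dt)

        # Heun step (2nd order) plus Euler step (1st order) for error control.
        k1 = f(t, y)
        y_euler = y + dt * k1
        k2 = f(t_next, y_euler)
        y_next = y + 0.5 * dt * (k1 + k2)
        scale = atol + rtol * jnp.maximum(jnp.abs(y), jnp.abs(y_next))
        err = jnp.sqrt(jnp.mean(((y_next - y_euler) / scale) ** 2))
        accept = err <= 1.0

        # Save all output times in (t, t_next] at once by linear interpolation:
        # a single vectorised where over ts, not an inner loop over output times.
        theta = ((ts - t) / dt)[:, None]
        hit = accept & (ts > t) & (ts <= t_next)
        ys = jnp.where(hit[:, None], y + theta * (y_next - y), ys)

        # Advance only if the step was accepted; adapt the step size either way.
        t = jnp.where(accept, t_next, t)
        y = jnp.where(accept, y_next, y)
        factor = jnp.clip(0.9 * jnp.maximum(err, 1e-10) ** -0.5, 0.2, 5.0)
        return t, y, dt * factor, ys, n + 1

    init = (t0, y0, jnp.asarray(dt0, ts.dtype), ys, jnp.int32(0))
    _, _, _, ys, _ = lax.while_loop(cond_fn, body_fn, init)
    return ys


ts = jnp.array([0.0, 0.5, 1.0, 2.0])
decay = lambda t, y: -y
# Batch over initial conditions: still just one while_loop under vmap.
batch = jax.vmap(lambda y0: solve_ode(decay, y0, 0.0, ts))(jnp.array([[1.0], [2.0]]))
```

Under jax.vmap, the single while_loop runs until the slowest batch element has finished, with finished elements masked out. A loop-over-output-times design would instead nest one inner while_loop per output interval, so each interval would run as long as its slowest element and the faster elements would idle at every output time.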