Open
Description
Hi,
Can you help me debug why this SDE would throw errors for SRK solvers, but works and integrates fine with ERK and Milstein?
Here is a simplified version of the code:
import os
import multiprocessing
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
multiprocessing.cpu_count()
)
# os.environ["JAX_TRACEBACK_FILTERING"] = "off"
# os.environ["EQX_ON_ERROR"] = "breakpoint"
import jax
import diffrax as dfx
import jax.numpy as jnp
import time
SEED = 0
KEY = jax.random.PRNGKey(SEED)
m = 0.01 # inertia
gamma = 0.1 # viscosity
amplitude = 0.42 # amplitude of the driving force
omega = 1 # frequency of the driving force
drive_period = 2 * jnp.pi / omega
alpha = -1 # linear spring constant
beta = 1 # cubic spring constant
sigma = 0.123 # noise intensity
x0 = 1.0 # initial position
v0 = 0.0 # initial velocity
state0 = jnp.array([x0, v0])
t_min = 0.0
t_max = 2**(10) * drive_period
dt = 2 **(-8) * drive_period
def functional_duffing(t: float, state: jnp.array,
args: list[float])->jnp.array:
x,v = state
dx = v
gamma, alpha, beta, amplitude, omega, m = args
driving = amplitude * jnp.cos(omega * t)
damping = gamma * v
spring = alpha * x + beta * x ** 3
dv = (driving - damping - spring)/m
dstate = jnp.array([dx, dv])
return dstate
KEY, noise_key = jax.random.split(KEY)
term = dfx.ODETerm(functional_duffing)
args = [gamma, alpha, beta, amplitude, omega, m]
brownian_noise= dfx.VirtualBrownianTree(t_min, t_max, tol=1e-3, shape=(), key=noise_key)
def noise(t, y, args):
return jnp.array([0, sigma])
noise_term = dfx.ControlTerm(noise, brownian_noise)
terms = dfx.MultiTerm(term, noise_term)
solver = dfx.ShARK()
saveat = dfx.SaveAt(ts = jnp.arange(t_min, t_max, dt))
begin = time.time()
sol = dfx.diffeqsolve(terms, solver, t_min, t_max, dt, state0, args, saveat=saveat, max_steps= 2**20)
end = time.time()
print(f"Elapsed time: {end-begin:.2f} s")
throws this error:
ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with structure diffrax._term.MultiTerm[tuple[diffrax._term.ODETerm, diffrax._term.AbstractTerm[typing.Any, diffrax._custom_types.AbstractSpaceTimeLevyArea]]]
but I am already using the multiTerm(odeTerm, controlTerm) format unless I am misunderstanding something.
Also this same simulation runs much faster in Mathematica(KloedenPlatenSchurz method), any suggestions on how to speed this up would be very helpful
Thanks for this great library!