Hi,
Can you help me debug why this SDE would throw errors for SRK solvers, but works and integrates fine with ERK and Milstein?
Here is a simplified version of the code:
import os
import multiprocessing
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
multiprocessing.cpu_count()
)
# os.environ["JAX_TRACEBACK_FILTERING"] = "off"
# os.environ["EQX_ON_ERROR"] = "breakpoint"
import jax
import diffrax as dfx
import jax.numpy as jnp
import time
SEED = 0
KEY = jax.random.PRNGKey(SEED)
m = 0.01 # inertia
gamma = 0.1 # viscosity
amplitude = 0.42 # amplitude of the driving force
omega = 1 # frequency of the driving force
drive_period = 2 * jnp.pi / omega
alpha = -1 # linear spring constant
beta = 1 # cubic spring constant
sigma = 0.123 # noise intensity
x0 = 1.0 # initial position
v0 = 0.0 # initial velocity
state0 = jnp.array([x0, v0])
t_min = 0.0
t_max = 2**(10) * drive_period
dt = 2 **(-8) * drive_period
def functional_duffing(t: float, state: jnp.array,
args: list[float])->jnp.array:
x,v = state
dx = v
gamma, alpha, beta, amplitude, omega, m = args
driving = amplitude * jnp.cos(omega * t)
damping = gamma * v
spring = alpha * x + beta * x ** 3
dv = (driving - damping - spring)/m
dstate = jnp.array([dx, dv])
return dstate
KEY, noise_key = jax.random.split(KEY)
term = dfx.ODETerm(functional_duffing)
args = [gamma, alpha, beta, amplitude, omega, m]
brownian_noise= dfx.VirtualBrownianTree(t_min, t_max, tol=1e-3, shape=(), key=noise_key)
def noise(t, y, args):
return jnp.array([0, sigma])
noise_term = dfx.ControlTerm(noise, brownian_noise)
terms = dfx.MultiTerm(term, noise_term)
solver = dfx.ShARK()
saveat = dfx.SaveAt(ts = jnp.arange(t_min, t_max, dt))
begin = time.time()
sol = dfx.diffeqsolve(terms, solver, t_min, t_max, dt, state0, args, saveat=saveat, max_steps= 2**20)
end = time.time()
print(f"Elapsed time: {end-begin:.2f} s")
throws this error:
ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with structure diffrax._term.MultiTerm[tuple[diffrax._term.ODETerm, diffrax._term.AbstractTerm[typing.Any, diffrax._custom_types.AbstractSpaceTimeLevyArea]]]
but I am already using the multiTerm(odeTerm, controlTerm) format unless I am misunderstanding something.
Also this same simulation runs much faster in Mathematica(KloedenPlatenSchurz method), any suggestions on how to speed this up would be very helpful
Thanks for this great library!
Hi,
Can you help me debug why this SDE would throw errors for SRK solvers, but works and integrates fine with ERK and Milstein?
Here is a simplified version of the code:
throws this error:
but I am already using the multiTerm(odeTerm, controlTerm) format unless I am misunderstanding something.
Also this same simulation runs much faster in Mathematica(KloedenPlatenSchurz method), any suggestions on how to speed this up would be very helpful
Thanks for this great library!