- 
          
 - 
                Notifications
    
You must be signed in to change notification settings  - Fork 160
 
Open
Labels
questionUser queriesUser queries
Description
Hi,
Can you help me debug why this SDE would throw errors for SRK solvers, but works and integrates fine with ERK and Milstein?
Here is a simplified version of the code:
import os
import multiprocessing
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
# os.environ["JAX_TRACEBACK_FILTERING"] = "off"
# os.environ["EQX_ON_ERROR"] = "breakpoint"
import jax
import diffrax as dfx
import jax.numpy as jnp
import time
SEED = 0
KEY = jax.random.PRNGKey(SEED)
m = 0.01 # inertia
gamma = 0.1 # viscosity
amplitude = 0.42 # amplitude of the driving force
omega = 1 # frequency of the driving force
drive_period = 2 * jnp.pi / omega 
alpha = -1 # linear spring constant
beta = 1 # cubic spring constant
sigma = 0.123 # noise intensity
x0 = 1.0 # initial position
v0 = 0.0 # initial velocity
state0 = jnp.array([x0, v0])
t_min = 0.0
t_max = 2**(10) * drive_period
dt = 2 **(-8) * drive_period
def functional_duffing(t: float, state: jnp.array,
                    args: list[float])->jnp.array:
    x,v = state
    dx = v
    
    gamma, alpha, beta, amplitude, omega, m = args
    driving = amplitude * jnp.cos(omega * t)
    damping = gamma * v
    spring = alpha * x + beta * x ** 3
    dv = (driving - damping - spring)/m
    dstate = jnp.array([dx, dv])
    return dstate
KEY, noise_key = jax.random.split(KEY)
term = dfx.ODETerm(functional_duffing)
args = [gamma, alpha, beta, amplitude, omega, m]
brownian_noise= dfx.VirtualBrownianTree(t_min, t_max, tol=1e-3, shape=(), key=noise_key)
def noise(t, y, args):
    return jnp.array([0, sigma])
noise_term = dfx.ControlTerm(noise, brownian_noise)
terms = dfx.MultiTerm(term, noise_term)
solver = dfx.ShARK()
saveat = dfx.SaveAt(ts = jnp.arange(t_min, t_max, dt))
begin = time.time()
sol = dfx.diffeqsolve(terms, solver, t_min, t_max, dt, state0, args, saveat=saveat, max_steps= 2**20)
end = time.time()
print(f"Elapsed time: {end-begin:.2f} s")throws this error:
ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with structure diffrax._term.MultiTerm[tuple[diffrax._term.ODETerm, diffrax._term.AbstractTerm[typing.Any, diffrax._custom_types.AbstractSpaceTimeLevyArea]]]
but I am already using the multiTerm(odeTerm, controlTerm) format unless I am misunderstanding something.
Also this same simulation runs much faster in Mathematica(KloedenPlatenSchurz method), any suggestions on how to speed this up would be very helpful
Thanks for this great library!
Metadata
Metadata
Assignees
Labels
questionUser queriesUser queries