Additive SDE throws error with SRK style solvers #474

Open
@ParticularlyPythonicBS

Hi,
Can you help me debug why this SDE throws an error with SRK-style solvers such as `ShARK`, but integrates fine with ERK and Milstein solvers?
Here is a simplified version of the code:

import os
import multiprocessing

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
# os.environ["JAX_TRACEBACK_FILTERING"] = "off"
# os.environ["EQX_ON_ERROR"] = "breakpoint"

import jax
import diffrax as dfx
import jax.numpy as jnp
import time

SEED = 0
KEY = jax.random.PRNGKey(SEED)

m = 0.01 # inertia
gamma = 0.1 # viscosity

amplitude = 0.42 # amplitude of the driving force
omega = 1 # frequency of the driving force
drive_period = 2 * jnp.pi / omega 

alpha = -1 # linear spring constant
beta = 1 # cubic spring constant

sigma = 0.123 # noise intensity

x0 = 1.0 # initial position
v0 = 0.0 # initial velocity
state0 = jnp.array([x0, v0])

t_min = 0.0
t_max = 2**10 * drive_period
dt = 2**-8 * drive_period

def functional_duffing(t: float, state: jax.Array,
                       args: list[float]) -> jax.Array:
    x,v = state
    dx = v
    
    gamma, alpha, beta, amplitude, omega, m = args

    driving = amplitude * jnp.cos(omega * t)
    damping = gamma * v
    spring = alpha * x + beta * x ** 3
    dv = (driving - damping - spring)/m

    dstate = jnp.array([dx, dv])
    return dstate

KEY, noise_key = jax.random.split(KEY)
term = dfx.ODETerm(functional_duffing)
args = [gamma, alpha, beta, amplitude, omega, m]

brownian_noise = dfx.VirtualBrownianTree(t_min, t_max, tol=1e-3, shape=(), key=noise_key)
def noise(t, y, args):
    return jnp.array([0, sigma])

noise_term = dfx.ControlTerm(noise, brownian_noise)
terms = dfx.MultiTerm(term, noise_term)
solver = dfx.ShARK()
saveat = dfx.SaveAt(ts=jnp.arange(t_min, t_max, dt))

begin = time.time()
sol = dfx.diffeqsolve(terms, solver, t_min, t_max, dt, state0, args, saveat=saveat, max_steps=2**20)
end = time.time()
print(f"Elapsed time: {end-begin:.2f} s")

This throws the following error:

ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with structure diffrax._term.MultiTerm[tuple[diffrax._term.ODETerm, diffrax._term.AbstractTerm[typing.Any, diffrax._custom_types.AbstractSpaceTimeLevyArea]]]

However, I am already using the `MultiTerm(ODETerm, ControlTerm)` format, unless I am misunderstanding something.

Also, the same simulation runs much faster in Mathematica (with the KloedenPlatenSchurz method). Any suggestions on how to speed this up would be very helpful.

Thanks for this great library!
