Description
We are encountering incorrect gradients in a specific regime. Specifically, we have:
- A custom solver, where the error estimate depends on a second call to the drift function and/or on the step times
- Adaptive stepping
Below is a simplified example. We essentially take Euler and make a trivial change for the sake of illustration (our actual solver is more complicated, but we have identified this as the root of the issue); crucially, it has a y error that depends on a recomputation of the drift function (with or without the stop gradients makes no difference). The PIDController does not seem to be the problem, since we also implemented a simple controller and the same error shows up. If constant stepping is used, the gradients are accurate. Our finite differences are stable: we have tried epsilon from 1e-10 to 1e-3 and the results are consistent. The primal values are correct, but there is a difference in the gradient.
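For reference, the constant-stepping comparison is just the example below with the controller swapped out. This is a minimal sketch that reuses `solver`, `drift`, `t0`, `t1` and `x0` from the example below; the function names and the fixed `dt0` value here are arbitrary choices for illustration.

def solve_constant(y0):
    # Same diffeqsolve call as in `solve` below, but with a fixed step size
    # instead of the PIDController.
    terms = diffrax.ODETerm(drift)
    return diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0=0.01,  # arbitrary fixed step for this sketch
        y0=y0,
        saveat=diffrax.SaveAt(t1=True),
        max_steps=10_000,
        stepsize_controller=diffrax.ConstantStepSize(),
        adjoint=diffrax.RecursiveCheckpointAdjoint(),
    )

def loss_constant(y):
    s = solve_constant(y)
    return jnp.sqrt(jnp.mean(s.ys ** 2)), s.stats

# With this variant the autodiff gradient agrees with the finite differences;
# the discrepancy only appears with adaptive stepping.
print(eqx.filter_value_and_grad(loss_constant, has_aux=True)(x0))

The full example follows.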
import jax
jax.config.update("jax_enable_x64", True)
from jax import numpy as jnp
import diffrax
from typing import ClassVar
import equinox as eqx
from equinox.internal import ω
class Test(diffrax.AbstractItoSolver):
    term_structure: ClassVar = diffrax.AbstractTerm
    interpolation_cls: ClassVar = diffrax.LocalLinearInterpolation

    def order(self, terms):
        return 1

    def strong_order(self, terms):
        return 0.5

    def init(
        self,
        terms,
        t0,
        t1,
        y0,
        args,
    ):
        return None

    def func(
        self,
        terms,
        t0,
        y0,
        args,
    ):
        return terms.vf(t0, y0, args)

    def step(
        self,
        terms,
        t0,
        t1,
        y0,
        args,
        solver_state,
        made_jump,
    ):
        del made_jump
        control = terms.contr(t0, t1)
        y1 = (y0**ω + terms.vf_prod(t0, y0, args, control) ** ω).ω
        drift = terms
        b = jax.lax.stop_gradient(drift.vf(t0, y0, args))
        y_error = jax.lax.stop_gradient(jnp.linalg.norm(b) * (t1 - t0))
        dense_info = dict(y0=y0, y1=y1)
        return y1, y_error, dense_info, solver_state, diffrax.RESULTS.successful
t0, t1 = 0.0, 3.0
y0 = jnp.array([1.0, 1.0])
tol = 1e-1
solver = Test()
cont = diffrax.PIDController(tol, tol, error_order=1.0)
def drift(t, X, args):
    y1, y2 = X
    dy1 = -273 / 512 * y1
    dy2 = -1 // 160 * y1 - (-785 // 512 + jnp.sqrt(2) / 8) * y2
    return jnp.array([dy1, dy2])

def solve(key, y0):
    terms = diffrax.ODETerm(drift)
    saveat = diffrax.SaveAt(t1=True)
    sol = diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0=0.0001,
        y0=y0,
        saveat=saveat,
        max_steps=1000,
        stepsize_controller=cont,
        adjoint=diffrax.RecursiveCheckpointAdjoint(),
    )
    return sol

def loss(y):
    k = jax.random.key(0)
    s = solve(k, y)
    return jnp.sqrt(jnp.mean(s.ys ** 2)), s.stats
x0 = jnp.array([1.0, 1.0])
print(eqx.filter_value_and_grad(loss, has_aux=True)(x0))
def finite_diff(y):
    eps = 1e-9
    val1 = loss(jnp.array([y[0] + eps / 2, y[1]]))[0]
    val2 = loss(jnp.array([y[0], y[1] + eps / 2]))[0]
    val3 = loss(jnp.array([y[0] - eps / 2, y[1]]))[0]
    val4 = loss(jnp.array([y[0], y[1] - eps / 2]))[0]
    print(val1, val2, val3, val4)
    return jnp.array([val1 - val3, val2 - val4]) / eps
print(finite_diff(x0))
prints
((Array(81.89217529, dtype=float64),
{'max_steps': 1000,
'num_accepted_steps': Array(682, dtype=int64, weak_type=True),
'num_rejected_steps': Array(44, dtype=int64, weak_type=True),
'num_steps': Array(726, dtype=int64, weak_type=True)}),
Array([-60.26947042, 142.16164571], dtype=float64))
81.8921752553513 81.89217537306844 81.89217532865639 81.89217521093927
Array([-73.30508822, 162.12916876], dtype=float64)
We see an accurate primal but inaccurate gradients (by enough that this cannot just be numerical noise; we have tried other problems and see larger differences there as well). The error order is wrong too, but that shouldn't matter: it should only make the solver converge incorrectly, not change its differentiability. Are we violating some requirement by calling the drift again? Everything should be differentiable (we have tried anywhere from zero to many, many stop gradients around all error-related terms and could not get the behaviour to change).
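For completeness, this is the kind of sweep behind the statement that the finite differences are stable from 1e-10 to 1e-3. A sketch only: `finite_diff_eps` is just `finite_diff` above with the step size made an argument, and the particular epsilon grid is an arbitrary choice.

def finite_diff_eps(y, eps):
    # Central differences with a configurable step, reusing `loss` from above.
    e0 = jnp.array([eps / 2, 0.0])
    e1 = jnp.array([0.0, eps / 2])
    g0 = (loss(y + e0)[0] - loss(y - e0)[0]) / eps
    g1 = (loss(y + e1)[0] - loss(y - e1)[0]) / eps
    return jnp.array([g0, g1])

for eps in (1e-10, 1e-8, 1e-6, 1e-4, 1e-3):
    print(eps, finite_diff_eps(x0, eps))

The estimates stay consistent across this range, which is why we do not believe the discrepancy is finite-difference noise.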