
Incorrect gradient in toy adaptive ODE #499

Open
@lockwo


We are encountering incorrect gradients in a specific regime, namely when we have:

  • A custom solver whose error estimate depends on a second call to the drift function (or on the step times)
  • Adaptive stepping

Below is a simplified example. We take Euler and make a trivial change for the sake of the example (our real solver is more complicated, but we have traced the root of the issue to this part). Crucially, the y error depends on a recalculation of the drift function; whether or not the stop_gradient calls are present makes no difference. There doesn't seem to be anything wrong with the PIDController, since we also implemented a simple controller and the same error shows up. If constant stepping is used, the gradients are accurate. Our finite difference is stable: we have tried epsilon from 1e-10 to 1e-3 and it gives consistent results. The primal values are correct, but the gradients differ.
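The central-difference check described above can be written for any input dimension and swept over several step sizes. What follows is a generic sketch under our own assumptions. The helper names `central_diff_grad` and `eps_sweep` and the smooth test function `f` are hypothetical. In the setting of this issue, `f` would be `lambda y: loss(y)[0]`, since `loss` returns a `(value, stats)` tuple.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def central_diff_grad(f, x, eps):
    """Central-difference gradient of a scalar function f at x.

    Each coordinate is perturbed by +eps/2 and -eps/2, i.e. total spacing eps,
    the same spacing used by the finite_diff helper in the reproduction.
    """
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    return jnp.array([(f(x + eps / 2 * e) - f(x - eps / 2 * e)) / eps for e in basis])


def eps_sweep(f, x, eps_values):
    """Max abs deviation between finite-difference and autodiff gradients, per eps."""
    g_ad = jax.grad(f)(x)
    return {
        eps: float(jnp.max(jnp.abs(central_diff_grad(f, x, eps) - g_ad)))
        for eps in eps_values
    }


# Hypothetical smooth test function, used only to exercise the helpers.
def f(x):
    return jnp.sqrt(jnp.mean(jnp.exp(x) ** 2))


x0 = jnp.array([1.0, 1.0])
print(eps_sweep(f, x0, [1e-3, 1e-5, 1e-7]))
```

For a smooth function, the deviations stay small across the whole sweep. A deviation that is consistent across eps but clearly nonzero, as reported in this issue, points at the autodiff side rather than at the finite difference.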

```python
import jax

jax.config.update("jax_enable_x64", True)
from jax import numpy as jnp
import diffrax
from typing import ClassVar
import equinox as eqx
from equinox.internal import ω


class Test(diffrax.AbstractItoSolver):
    """Euler-like solver whose error estimate re-evaluates the drift."""

    term_structure: ClassVar = diffrax.AbstractTerm
    interpolation_cls: ClassVar = diffrax.LocalLinearInterpolation

    def order(self, terms):
        return 1

    def strong_order(self, terms):
        return 0.5

    def init(self, terms, t0, t1, y0, args):
        return None

    def func(self, terms, t0, y0, args):
        return terms.vf(t0, y0, args)

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        del made_jump
        # Plain Euler update: y1 = y0 + f(t0, y0) * (t1 - t0).
        control = terms.contr(t0, t1)
        y1 = (y0**ω + terms.vf_prod(t0, y0, args, control) ** ω).ω

        # Error estimate from a second evaluation of the drift.
        # (Removing these stop_gradients makes no difference to the result.)
        drift = terms
        b = jax.lax.stop_gradient(drift.vf(t0, y0, args))
        y_error = jax.lax.stop_gradient(jnp.linalg.norm(b) * (t1 - t0))

        dense_info = dict(y0=y0, y1=y1)
        return y1, y_error, dense_info, solver_state, diffrax.RESULTS.successful


t0, t1 = 0.0, 3.0
y0 = jnp.array([1.0, 1.0])
tol = 1e-1
solver = Test()
cont = diffrax.PIDController(tol, tol, error_order=1.0)


def drift(t, X, args):
    y1, y2 = X
    dy1 = -273 / 512 * y1
    # Note: `//` is floor division, so -1 // 160 == -1 and -785 // 512 == -2.
    dy2 = -1 // 160 * y1 - (-785 // 512 + jnp.sqrt(2) / 8) * y2
    return jnp.array([dy1, dy2])


def solve(key, y0):
    terms = diffrax.ODETerm(drift)
    saveat = diffrax.SaveAt(t1=True)
    sol = diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0=0.0001,
        y0=y0,
        saveat=saveat,
        max_steps=1000,
        stepsize_controller=cont,
        adjoint=diffrax.RecursiveCheckpointAdjoint(),
    )
    return sol


def loss(y):
    k = jax.random.key(0)  # passed through but unused by the ODE solve
    s = solve(k, y)
    return jnp.sqrt(jnp.mean(s.ys ** 2)), s.stats


x0 = jnp.array([1.0, 1.0])
# Reverse-mode gradient through the adaptive solve.
print(eqx.filter_value_and_grad(loss, has_aux=True)(x0))


def finite_diff(y):
    # Central differences with total spacing eps in each coordinate.
    eps = 1e-9
    val1 = loss(jnp.array([y[0] + eps / 2, y[1]]))[0]
    val2 = loss(jnp.array([y[0], y[1] + eps / 2]))[0]
    val3 = loss(jnp.array([y[0] - eps / 2, y[1]]))[0]
    val4 = loss(jnp.array([y[0], y[1] - eps / 2]))[0]
    print(val1, val2, val3, val4)
    return jnp.array([val1 - val3, val2 - val4]) / eps


print(finite_diff(x0))
```

This prints:

```
((Array(81.89217529, dtype=float64),
  {'max_steps': 1000,
   'num_accepted_steps': Array(682, dtype=int64, weak_type=True),
   'num_rejected_steps': Array(44, dtype=int64, weak_type=True),
   'num_steps': Array(726, dtype=int64, weak_type=True)}),
 Array([-60.26947042, 142.16164571], dtype=float64))

81.8921752553513 81.89217537306844 81.89217532865639 81.89217521093927
Array([-73.30508822, 162.12916876], dtype=float64)
```
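We can put a number on the mismatch by computing the relative difference between the two gradients directly from the printed values. The short calculation below uses only the numbers shown above.

```python
import jax.numpy as jnp

# Gradients copied from the printed output above.
grad_ad = jnp.array([-60.26947042, 142.16164571])  # from eqx.filter_value_and_grad
grad_fd = jnp.array([-73.30508822, 162.12916876])  # from finite_diff

# Relative difference, taking the finite-difference gradient as the reference.
rel_err = jnp.abs(grad_ad - grad_fd) / jnp.abs(grad_fd)
print(rel_err)  # roughly 0.178 and 0.123
```

The difference is about 18% in the first component and 12% in the second. This is consistent with the claim that the gap is not floating-point noise.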

We see an accurate primal but inaccurate gradients, and the gap is large enough that it cannot just be numerical noise (we have tried other problems and see larger differences there as well). The error order is also wrong, but that shouldn't matter: it should only make the solver converge to the wrong answer, not change the differentiability of the solution. Are we violating some requirement by calling the drift again? Everything should be differentiable. We tried anything from zero to many, many stop_gradient calls around all error-related terms and couldn't get anything to change.
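One way to probe the differentiability question is a toy model, separate from diffrax, in which the size of the first step depends on y0 through an error estimate. This is an assumption for illustration only; how diffrax's controllers treat step-size gradients is not established in this issue. Suppose the chosen step locations depend on y0 but are not differentiated through. Then autodiff returns the derivative of the discrete solution with the steps held fixed, while a finite difference also sees the steps move. The constants `A`, `T`, `TOL` and the functions `first_step_size` and `two_step_euler` are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical problem: y' = A * y on [0, T], solved with two Euler steps.
A = 1.5
T = 1.0
TOL = 0.3  # hypothetical tolerance for the toy step-size rule


def first_step_size(y0, freeze):
    # Toy "controller": pick h so that the error proxy |f(y0)| * h equals TOL.
    h = jnp.clip(TOL / jnp.abs(A * y0), 0.05, 0.95 * T)
    # freeze=True mimics not differentiating through step-size selection.
    return jax.lax.stop_gradient(h) if freeze else h


def two_step_euler(y0, freeze):
    h = first_step_size(y0, freeze)
    y1 = y0 + h * A * y0          # Euler step over [0, h]
    return y1 + (T - h) * A * y1  # Euler step over [h, T]


y0 = 1.0
eps = 1e-6
grad_frozen = jax.grad(lambda y: two_step_euler(y, True))(y0)
grad_full = jax.grad(lambda y: two_step_euler(y, False))(y0)
fd = (two_step_euler(y0 + eps / 2, False) - two_step_euler(y0 - eps / 2, False)) / eps
print(grad_frozen, grad_full, fd)
```

With these constants the frozen-step gradient is 2.86. The full autodiff gradient and the central difference both come out at about 2.59. With a constant step size, h does not depend on y0, and the two notions coincide. Whether this mechanism explains the mismatch in the reproduction is not shown here. It is one candidate consistent with two observations: constant stepping gives accurate gradients, and adding stop_gradient to y_error changes nothing.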

Labels: documentation (Improvements or additions to documentation)