patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

Incorrect gradient in toy adaptive ODE #499

Open lockwo opened 3 months ago

lockwo commented 3 months ago

We are encountering incorrect gradients in a specific regime.

Below is a simplified example. It is Euler with a trivial change for the sake of example (our real solver is more complicated, but we have traced the root of the issue to this part). Crucially, its `y_error` depends on a re-evaluation of the drift function; adding or removing the `stop_gradient` calls makes no difference. The `PIDController` does not seem to be at fault: we also implemented a simple controller and the same error shows up. If constant stepping is used, the gradients are accurate. Our finite difference is stable: we have tried epsilon from 1e-10 to 1e-3 and it gives consistent results. The primal values are correct, but the gradients differ.

```python
import jax

jax.config.update("jax_enable_x64", True)

from typing import ClassVar

import diffrax
import equinox as eqx
from equinox.internal import ω
from jax import numpy as jnp


class Test(diffrax.AbstractItoSolver):
    """Euler-like solver whose error estimate re-evaluates the drift."""

    term_structure: ClassVar = diffrax.AbstractTerm
    interpolation_cls: ClassVar = diffrax.LocalLinearInterpolation

    def order(self, terms):
        return 1

    def strong_order(self, terms):
        return 0.5

    def init(self, terms, t0, t1, y0, args):
        return None  # stateless solver

    def func(self, terms, t0, y0, args):
        return terms.vf(t0, y0, args)

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        del made_jump
        # Explicit Euler step: y1 = y0 + f(t0, y0) * (t1 - t0).
        control = terms.contr(t0, t1)
        y1 = (y0**ω + terms.vf_prod(t0, y0, args, control) ** ω).ω

        # Error estimate from a second evaluation of the drift: |f(t0, y0)| * dt.
        drift = terms
        b = jax.lax.stop_gradient(drift.vf(t0, y0, args))
        y_error = jax.lax.stop_gradient(jnp.linalg.norm(b) * (t1 - t0))

        dense_info = dict(y0=y0, y1=y1)
        return y1, y_error, dense_info, solver_state, diffrax.RESULTS.successful


t0, t1 = 0.0, 3.0
y0 = jnp.array([1.0, 1.0])
tol = 1e-1
solver = Test()
cont = diffrax.PIDController(tol, tol, error_order=1.0)  # rtol = atol = tol


def drift(t, X, args):
    y1, y2 = X
    # Note: `//` is floor division in Python (-1 // 160 == -1, -785 // 512 == -2),
    # whereas the Julia version further down uses true division `/`.
    # The output shown below corresponds to `//` as written here.
    dy1 = -273 / 512 * y1
    dy2 = -1 // 160 * y1 - (-785 // 512 + jnp.sqrt(2) / 8) * y2
    return jnp.array([dy1, dy2])


def solve(key, y0):
    # `key` is unused; kept from the original example.
    terms = diffrax.ODETerm(drift)
    saveat = diffrax.SaveAt(t1=True)
    sol = diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0=0.0001,
        y0=y0,
        saveat=saveat,
        max_steps=1000,
        stepsize_controller=cont,
        adjoint=diffrax.RecursiveCheckpointAdjoint(),
    )
    return sol


def loss(y):
    k = jax.random.key(0)
    s = solve(k, y)
    return jnp.sqrt(jnp.mean(s.ys**2)), s.stats


x0 = jnp.array([1.0, 1.0])
print(eqx.filter_value_and_grad(loss, has_aux=True)(x0))


def finite_diff(y):
    # Central differences per component: (f(y + eps/2) - f(y - eps/2)) / eps.
    eps = 1e-9
    val1 = loss(jnp.array([y[0] + eps / 2, y[1]]))[0]
    val2 = loss(jnp.array([y[0], y[1] + eps / 2]))[0]
    val3 = loss(jnp.array([y[0] - eps / 2, y[1]]))[0]
    val4 = loss(jnp.array([y[0], y[1] - eps / 2]))[0]
    print(val1, val2, val3, val4)
    return jnp.array([val1 - val3, val2 - val4]) / eps


print(finite_diff(x0))
```

This prints:

```
((Array(81.89217529, dtype=float64),
  {'max_steps': 1000,
   'num_accepted_steps': Array(682, dtype=int64, weak_type=True),
   'num_rejected_steps': Array(44, dtype=int64, weak_type=True),
   'num_steps': Array(726, dtype=int64, weak_type=True)}),
 Array([-60.26947042, 142.16164571], dtype=float64))

81.8921752553513 81.89217537306844 81.89217532865639 81.89217521093927
Array([-73.30508822, 162.12916876], dtype=float64)
```

We see an accurate primal but inaccurate gradients, by too much for this to be numerical noise (we have tried other problems and see even larger differences). The `error_order` is also wrong, but that shouldn't matter: it should only make the solve converge poorly, not change its differentiability. Are we violating some requirement by evaluating the drift again? Everything should be differentiable, and adding anywhere from zero to many `stop_gradient` calls around all the error-related terms made no difference.
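For a scalar linear ODE $y' = \lambda y$ solved with explicit Euler steps $h_0, \dots, h_{N-1}$ (a simplified stand-in, not the drift in the example above), the two gradients can be written out, assuming a small perturbation of $y_0$ leaves the number of steps $N$ unchanged. If autodiff treats the step sizes as constants (for instance because the error estimate driving the controller is wrapped in `stop_gradient`), it returns the second line; a finite difference of the whole adaptive solve approximates the third line, which also contains how the controller moves each $h_n$ when $y_0$ changes; the last line is the exact ODE sensitivity.

```math
\begin{aligned}
y_N &= y_0 \prod_{n=0}^{N-1} \left(1 + \lambda h_n\right), \\
\frac{\partial y_N}{\partial y_0}\bigg|_{h_n\ \text{fixed}} &= \prod_{n=0}^{N-1} \left(1 + \lambda h_n\right) = \frac{y_N}{y_0}, \\
\frac{\mathrm{d} y_N}{\mathrm{d} y_0} &= \frac{y_N}{y_0} + \sum_{n=0}^{N-1} \frac{\lambda\, y_N}{1 + \lambda h_n} \, \frac{\mathrm{d} h_n}{\mathrm{d} y_0}, \\
\frac{\partial y(T)}{\partial y_0} &= e^{\lambda T}.
\end{aligned}
```

As the steps shrink, the product tends to $e^{\lambda T}$ and the extra sum is typically small, so the two discrete gradients would only separate noticeably at loose tolerances.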

lockwo commented 3 months ago

Follow-up: I just tried a trivial implementation of Heun and it also fails. The mean/square in the loss has no impact either; I tested without it.

Second follow-up: diffrax's own `Heun` doesn't work either; that is, its gradients and the finite differences don't match. Now I am confused. The finite differences are extremely stable, match the primal exactly, and give consistent gradients for epsilon from 1e-2 down to 1e-15 and beyond.

If I decrease the tolerance, the two match. They disagree only at large tolerances.
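The tolerance dependence can be reproduced without diffrax in a hand-rolled toy. This is a minimal sketch under loud assumptions, not diffrax's implementation and not a diagnosis of it: explicit Euler on the scalar ODE dy/dt = y over [0, 3], a hypothetical per-step controller, and step sizes wrapped in `jax.lax.stop_gradient`, used here as a stand-in for autodiff not seeing how the step sizes depend on `y0`. The names `toy_solve`, `ad_grad` and `fd_grad` are illustrative only. It compares the reverse-mode gradient with a central finite difference of the same discrete function at a loose and a tight tolerance.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

T1 = 3.0          # integrate the toy ODE dy/dt = y over [0, T1]
MAX_STEPS = 5000  # fixed scan length; once t reaches T1 the remaining steps have h == 0


def toy_solve(y0, tol, adaptive=True):
    """Explicit Euler for dy/dt = y, returning y(T1).

    Hypothetical toy controller (not diffrax's PIDController): in adaptive mode
    each step size h is chosen so that the error estimate |f(y)| * h equals
    tol * (1 + |y|). Every h is wrapped in stop_gradient, so autodiff treats the
    step sizes as constants, while a finite difference also sees them move.
    """
    y0 = jnp.asarray(y0, dtype=jnp.float64)

    def step(carry, _):
        t, y = carry
        if adaptive:
            h = tol * (1.0 + jnp.abs(y)) / jnp.maximum(jnp.abs(y), 1e-12)
        else:
            h = jnp.full_like(y, T1 / 100)  # fixed steps of size T1 / 100
        h = jnp.minimum(h, jnp.maximum(T1 - t, 0.0))  # never step past T1
        h = jax.lax.stop_gradient(h)  # freeze step sizes for autodiff
        return (t + h, y + h * y), None  # Euler step with f(y) = y

    (_, y_end), _ = jax.lax.scan(step, (jnp.zeros_like(y0), y0), None, length=MAX_STEPS)
    return y_end


def ad_grad(y0, tol, adaptive=True):
    # Reverse-mode gradient of the discrete solve with the chosen steps frozen.
    return jax.grad(toy_solve)(y0, tol, adaptive)


def fd_grad(y0, tol, adaptive=True, eps=1e-6):
    # Central difference of the same discrete solve; here the steps move with y0.
    plus = toy_solve(y0 + eps / 2, tol, adaptive)
    minus = toy_solve(y0 - eps / 2, tol, adaptive)
    return (plus - minus) / eps


for tol in (1e-1, 1e-3):
    print(f"tol={tol:g}  AD={float(ad_grad(1.0, tol)):.6f}  FD={float(fd_grad(1.0, tol)):.6f}")
```

In this toy the AD value is exactly the product of the frozen factors `1 + h`, i.e. `y_end / y0`, whereas the finite difference also picks up how the controller shifts the steps when `y0` changes. With `adaptive=False` the two agree, the adaptive gap shrinks as `tol` decreases, and the exact ODE gradient is e^3 ≈ 20.09.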

lockwo commented 3 months ago

Looking deeper, I thought it might be a situation like Section 4.1.2.4 of https://arxiv.org/abs/2406.09699, where AD is numerically wrong (see also https://github.com/ODINN-SciML/DiffEqSensitivity-Review/blob/main/code/SensitivityForwardAD/testgradient_python.py), since both jacrev and jacfwd work. However, they argue this happens at any tolerance, whereas I see it only at large tolerances. Maybe the solution is simply not to use large tolerances with first-order methods? But what confuses me is that this should be differentiable. Also, the paper says this works in Julia (SciMLSensitivity), so we implemented it in Julia and saw that it is wrong there too. That is extra surprising, because the time steps taken in the finite-difference runs are basically identical to those in the reverse-mode run.

Julia code + Results
```julia
using OrdinaryDiffEq
using FiniteDiff, ForwardDiff, Statistics, Zygote, ReverseDiff, SciMLSensitivity

function f(X, args, t)
    y1, y2 = X
    dy1 = -273 / 512 * y1
    dy2 = -1 / 160 * y1 - (-785 / 512 + sqrt(2) / 8) * y2
    return [dy1, dy2]
end

u0 = [1.0, 1.0]
args = ones(1)
odeprob = ODEProblem(f, u0, (0.0, 3.0), args)

function loss(u0)
    _prob = remake(odeprob, u0=u0)
    _sol = (solve(_prob, Heun(), dt=0.1, abstol=0.1, reltol=0.1,
        save_everystep=true, save_start = false,
        #adaptive=false
        controller=IController(), #CustomController(),
        sensealg=ReverseDiffAdjoint()
    ))
    @show (_sol.t)
    _sol = _sol[end]
    return sum(abs2, _sol)
end

function finite_diff(u0)
    eps = 1e-5
    v1 = loss([u0[1] + eps / 2, u0[2]])
    v2 = loss([u0[1] - eps / 2, u0[2]])
    v3 = loss([u0[1], u0[2] + eps / 2])
    v4 = loss([u0[1], u0[2] - eps / 2])
    [v1 - v2, v3 - v4] / eps
end

begin
    println("FiniteDiff")
    @show finite_diff(u0)
    #dp1 = FiniteDiff.finite_difference_gradient(loss, u0)  # uncomment to define dp1, used by @show below
    println("Forward")
    dp2 = ForwardDiff.gradient(loss, u0)
    println("Reverse")
    dp3 = Zygote.gradient(loss, u0)[1]
    @show dp1 dp2 dp3
end
```

```
FiniteDiff
_sol.t = [0.1, 0.613929509825019, 1.199780114366179, 1.7608249943840808, 2.28993033847855, 2.797764017021789, 3.0]
_sol.t = [0.1, 0.6139295972860885, 1.1997803015279263, 1.7608252770250412, 2.2899306781168063, 2.7977643468051436, 3.0]
_sol.t = [0.1, 0.6139290287896685, 1.1997791292067386, 1.7608237043621409, 2.289928860088721, 2.797762422397756, 3.0]
_sol.t = [0.1, 0.6139301220531311, 1.1997813360415173, 1.7608266217853026, 2.289932183276105, 2.7977659413541907, 3.0]
finite_diff(u0) = [7.376464645858504, 4823.877383614672]
Forward
_sol.t = [0.1, 0.728028559619219, 1.4979699401846973, 2.275540178765593, 3.0]
Reverse
_sol.t = [0.1, 0.6139295972860885, 1.1997802572988525, 1.760825190439614, 2.2899305515863184, 2.7977642464962393, 3.0]
dp1 = [3.4720645693875793, 4831.212995065322]
dp2 = [-11.629850313906191, 3561.5179283293633]
dp3 = [-15.024234205512382, 4588.183943938828]
```
lockwo commented 2 months ago

There was some good discussion in https://github.com/SciML/SciMLSensitivity.jl/issues/1094. Given that this clearly isn't a fault of diffrax (or of the Julia SciML ecosystem), the original points in my issue are less relevant. But maybe this could go in the docs somewhere, or at least a reference on numerical vs. algorithmic accuracy considerations? As someone not super knowledgeable about discrete vs. continuous adjoints, I found this a tough nut to crack, so I'd like to spare some future person the amount of work we put into this if possible lol.

patrick-kidger commented 2 months ago

Ah, you're bumping into the esoteric end of the autodiff literature!

An FAQ entry sounds reasonable.