lockwo opened this issue 2 months ago
Follow-up: I just tried a trivial implementation of Heun, and it doesn't work either. The mean/square stuff also has no impact; I tested without it.
Second follow-up: the real (diffrax) Heun doesn't work either; that is, the Heun gradients and the finite-difference gradients don't match. Now I am confused. The finite difference is extremely stable: it matches the primal exactly and gives consistent gradients for epsilons from 1e-2 down to 1e-15 and beyond.
If I decrease the tolerance, the two match up. They disagree only at large tolerances.
Looking deeper, I thought it might be a situation like Section 4.1.2.4 of https://arxiv.org/abs/2406.09699, where AD is numerically wrong (see also https://github.com/ODINN-SciML/DiffEqSensitivity-Review/blob/main/code/SensitivityForwardAD/testgradient_python.py), since jacrev and jacfwd both work. However, they argue this holds at any tolerance, whereas I see it only at large tolerances. Maybe the solution is just not to use large tolerances with first-order methods? But what confuses me is that this should be differentiable. Also, the paper says it works with Sensitivity in Julia, so we implemented it in Julia and saw that it's wrong there too (which is extra surprising, because the finite-difference trajectories are basically identical to the reverse-mode trajectories).
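A short chain-rule argument shows how adaptive stepping alone can make two honest derivatives disagree. As a sketch, assume the solver output $Y$ depends on a parameter $\theta$ both directly and through the accepted step sizes $h_1, \dots, h_N$ picked by the controller, and that $N$ and the accept/reject pattern don't change under a small perturbation of $\theta$:

$$
\frac{dY}{d\theta} = \frac{\partial Y}{\partial \theta}\bigg|_{h_1,\dots,h_N\ \text{fixed}} + \sum_{i=1}^{N} \frac{\partial Y}{\partial h_i}\,\frac{dh_i}{d\theta}
$$

A finite difference that reruns the whole adaptive solve estimates the full left-hand side, while a derivative that holds the accepted steps fixed gives only the first term. With constant stepping $dh_i/d\theta = 0$, so the two coincide. Because the steps must still add up to $t_1 - t_0$, the second term only measures how the placement of the steps changes the discretization error, so it typically shrinks as the tolerance is tightened.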
There was some good discussion in https://github.com/SciML/SciMLSensitivity.jl/issues/1094. Given that this clearly isn't a fault of diffrax (or of the Julia SciML ecosystem), the original points in my issue are less relevant. But maybe this could go in the docs somewhere? Or just a reference to numerical vs. algorithmic accuracy considerations? As someone not super knowledgeable about discrete vs. continuous adjoints, I found this a tough nut to crack, so I'd like to spare some future person the amount of work we put into this, if possible lol.
Ah, you're bumping into the esoteric end of the autodiff literature!
An FAQ entry sounds reasonable.
We are encountering incorrect gradients in a specific regime. Specifically, we have:
Below is a simplified example. Basically, we take Euler and make a trivial change for the sake of example (we have a more complicated solver, but have traced the root of the issue to this change). Crucially, its y error depends on a recalculation of the drift function (with or without the stop_gradient calls, it doesn't matter). There doesn't seem to be anything wrong with the PIDController, since we also implemented a simple controller and the same error shows up. If constant stepping is used, the gradients are accurate. Note that our finite difference is stable: we have tried epsilons from 1e-10 to 1e-3 and it gives consistent results. The primal values are correct, but the gradients differ.
prints
We see an accurate primal but inaccurate gradients, off by enough that this cannot just be numerical noise (we have tried other problems and see even larger differences). The order of the error estimate is wrong too, but that shouldn't matter: it should only make the solver converge incorrectly, not change its differentiability. Are we violating some requirement by evaluating the drift again? Everything should be differentiable, and we tried anywhere from zero to many, many stop_gradient calls around all error-related terms without any effect.