You can't mix continuous callbacks and InterpolatingAdjoint in most cases. Is ReverseDiffAdjoint fine here?
Oh I wasn't aware of that... Why is that not a safe combo? What should be done instead?
Adjoint methods, without a special correction, cannot differentiate through discontinuities introduced in the event handling code. Essentially, if you cause a discontinuous change, there's a boundary term in the integral that needs to be added. Automatic differentiation will handle this automatically, but the adjoint methods, which define and solve the adjoint equation, don't do this without extra terms added to them, which we haven't done yet. I have been planning to get to this, but it's just a bit time-consuming and tricky to handle since there's not really a good derivation out there.
The thing that makes this trickier is that the extra term is only required in the case where you introduce discontinuities (integrator.u = ...). IIRC there are a few packages which rely on the fact that this currently works. For example, when computing Poincaré sections you only need the event detection to find the zero, but you don't actually make a modification at the zero, so that's a case where we would be adding an error to working code.
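Roughly, here's a sketch of what that boundary term looks like for the forward sensitivities (my notation, following the usual hybrid-system sensitivity derivations, and assuming the event condition and modification don't depend directly on the parameters; the adjoint case is the transpose of this jump applied to lambda):

```latex
% Event at time \tau: condition c(u(\tau^-), \tau) = 0, modification u(\tau^+) = g(u(\tau^-)).
% s = \partial u / \partial p is the forward sensitivity, f^\pm = f(u(\tau^\pm), \tau).
\frac{d\tau}{dp} = -\,\frac{\frac{\partial c}{\partial u}\, s^-}
                           {\frac{\partial c}{\partial u}\, f^- + \frac{\partial c}{\partial t}},
\qquad
s^+ = \frac{\partial g}{\partial u}\, s^-
      + \Bigl(\frac{\partial g}{\partial u}\, f^- - f^+\Bigr)\,\frac{d\tau}{dp}.
```

Note that when there's no modification (g is the identity), the Jacobian is the identity and f^+ = f^-, so the correction vanishes, which matches the Poincaré case above.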
I hope we can get someone on this in the next month because it's a known wart with weird edge cases. We do mention it in the docs:
https://diffeq.sciml.ai/latest/analysis/sensitivity/#Choosing-a-Sensitivity-Algorithm
The methods which utilize altered differential equation systems only work on ODEs (without events), but work on any ODE solver.
But probably not prominently enough.
Hmm, I could envision how a boundary term could pop out of these discontinuities, but I'll have to think a bit more about exactly what it should be...
If you end up coming across any references that touch on this I'd def be interested in giving them a read.
There's a reference in https://github.com/SciML/DiffEqSensitivity.jl/issues/4#issuecomment-369219280 (and note: I could never find a single thing showing Sundials working with that case, so that seems to be a false lead, but that thesis at least has a decent derivation that an implementation could be done from, though it doesn't seem as clean as I'd expect it to be).
Would DiscreteCallback be safe instead?
Not if you modify integrator.u. There is no discontinuity term w.r.t. t any more, but there still is w.r.t. u, so you need to perturb lambda by the derivative of the callback function itself.
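To make "perturb lambda by the derivative of the callback" concrete, here's a minimal sketch (not an existing API; the affect map a and adjoint_jump are made-up names for illustration):

```julia
using ForwardDiff  # only used here to get the Jacobian of the affect map

# Hypothetical instantaneous modification applied by the callback: u⁺ = a(u⁻)
a(u) = [u[1], -0.8 * u[2]]   # e.g. a lossy bounce on the velocity component

# When the backwards adjoint solve reaches a time where the callback fired in
# the forward pass, the adjoint has to be pulled back through the affect map
# before continuing with the next smooth segment:
adjoint_jump(λ_plus, u_minus) = ForwardDiff.jacobian(a, u_minus)' * λ_plus   # λ⁻ = (∂a/∂u)ᵀ λ⁺
```

As the rest of the thread shows, for a ContinuousCallback this pullback alone isn't the whole story, since the event time itself depends on the state.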
Oh, right right. So (IIUC?) the thing that's missing then is "backprop" through the callback functions themselves, i.e., figure out where the callbacks happen in the forward pass, start from the end doing a callback-less adjoint solve until you hit a known callback time, at which point you pipe the adjoint at that time through the reverse AD of the callback function and continue with the next smooth segment.
So that was a lie. Here's a counterexample: https://gist.github.com/samuela/95dfd752c53e712a984e22832ca72fbe. In this case we try simulating a simple billiards ball hitting a bumper in 1-d (pilfered from the DiffTaichi paper). tl;dr is that just piping the adjoints through is not sufficient to get correct gradients.
This example raised the question: How can I get the times where my callback actually fired in the forwards solve? I wasn't able to figure out any way to recover that information.
You'd have to create a cache array yourself there.
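For example, something along these lines (just a sketch; saved_times and the toy 1-d bouncing-ball setup are made up here):

```julia
using DifferentialEquations

saved_times = Float64[]                       # cache array recording when the callback fired

# Toy 1-d bouncing ball: u[1] = height, u[2] = velocity
f(du, u, p, t) = (du[1] = u[2]; du[2] = -9.81)
condition(u, t, integrator) = u[1]            # event when the height crosses zero
function affect!(integrator)
    push!(saved_times, integrator.t)          # record the hit time
    integrator.u[2] = -integrator.u[2]        # reflect the velocity
end
cb = ContinuousCallback(condition, affect!)

prob = ODEProblem(f, [1.0, 0.0], (0.0, 2.0))
sol = solve(prob, Tsit5(), callback = cb)
saved_times                                   # the times at which the callback actually fired
```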
This one was fixed a while back and we forgot to close it.
I'm running Julia 1.5.1. Reproduction is: https://github.com/samuela/research/commit/e8287f3bd17f4a1dcebdf038c80681210007ee86. You can recreate the error with ] instantiate, ] build, and then julia --project difftaichi/mass_spring.jl. The code is much messier than I'd like, but considering I don't understand this error, and it is specifically requested that I report it, I thought I'd just put this out there.