dflocher opened this issue 7 months ago
Ah, this is a known (although pretty obscure) limitation of using autodifferentiable differential equation solvers with discontinuous vector fields.
TL;DR: the solution

First of all, the solution: explicitly declare the jump time in the stepsize controller, typically by doing `stepsize_controller=PIDController(..., jump_ts=[T])`. (I think you could also use `StepTo` here if you wanted a fixed step size rather than an adaptive one.)
What went wrong?
As for what's going on, we can explain this in a few different ways.

Thinking about the computation graph: `T` is only used to generate some boolean masks (`t <= T`, `t > T`). There's no way for a gradient to flow backwards from `d_y` into `T`; boolean masks never have gradients. So no gradient ever reaches `T`, as it's never used in a differentiable way.

Thinking mathematically: the quantity we want is $\frac{d}{dT} \int_0^t f(s, y(s), T) \, ds$, i.e. a derivative-of-an-integral. (Here $f$ is your vector field, i.e. the RHS of your ODE, and $s$ is the evolving time of the system.) However, because $f$ is discontinuous, we cannot switch the derivative and the integral: there is no meaningful notion of $df/dT$. And having a meaningful notion of $df/dT$ is what we are relying upon when differentiating the solver!

Why does the solution above work?
So how do we fix this? We've seen that writing `diffeqsolve(..., t0=0, t1=t)` doesn't work.
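The mask problem described above can be seen without any ODE solver at all; here's a tiny pure-JAX sketch (the constants are arbitrary):

```python
import jax
import jax.numpy as jnp

def masked_sum(T):
    t = jnp.linspace(0.0, 2.0, 5)
    # k(t) built from a boolean mask, as in the issue: k0 where t <= T.
    k = jnp.where(t <= T, 1.5, 0.0)
    return jnp.sum(k)

# The mask (t <= T) is piecewise constant in T, so autodiff sees zero
# sensitivity: the gradient is exactly 0.0, never a "jump" contribution.
print(jax.grad(masked_sum)(0.7))  # -> 0.0
```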
Our first insight into fixing this is to observe that if we had split this into `diffeqsolve(..., t0=0, t1=T)` and `diffeqsolve(..., t0=jnp.minimum(t, T), t1=t)`, then we would have a well-behaved ODE on each piece. In fact, go ahead and test this, and you'll get the expected gradient! Reasoning in terms of the computation graphs described above, we can see that the reason this works is that `T` now has a differentiable dependence in the computation graph.
So using `PIDController(..., jump_ts=[T])` basically does exactly this: it means we make a numerical step right up to that point. (It will also be faster to compile than the double-`diffeqsolve` approach, which doubles the number of operations JAX has to compile.)
Can we do better?
This is an unfortunate user footgun! But I don't know of an automatic solution to this; so far as I know it may be an open question in the theory of autodifferentiation. (?)
Spitballing, I imagine this could maybe be solved by having the ODE solver try to detect when it thinks a jump has occurred, and if so solve a root-finding problem to locate the jump, and then use that in its step size control.
I think investigating this might be an interesting research question in autodifferentiation, for those curious enough to try :)
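For illustration only, the root-finding part of that idea might look like a simple bisection on the discontinuous coefficient (`find_jump` is a hypothetical helper, not part of diffrax):

```python
import jax.numpy as jnp

def find_jump(k, lo, hi, iters=50):
    # Hypothetical sketch: if a step from `lo` to `hi` straddles a jump
    # in k, bisect until the discontinuity is localised.
    k_lo = k(lo)
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if k(mid) == k_lo:
            lo = mid  # still on the pre-jump side
        else:
            hi = mid  # past the jump
    return 0.5 * (lo + hi)

k = lambda t: jnp.where(t <= 0.7, 1.5, 0.0)  # jump at t = 0.7
print(find_jump(k, 0.0, 2.0))  # localises the jump near 0.7
```

The localised jump time could then be fed into the step size control, just as `jump_ts` does manually above.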
Thank you, Patrick, for your detailed and instructive answer! Your proposed solution works fine. Best, David
Hi, when applying `jax.grad` to a function that uses diffrax to integrate a piecewise-defined ODE, I observe that one partial derivative is unexpectedly zero. The ODE solver returns correct function values; just the gradient is wrong. I'm wondering whether this is a bug, or whether I'm doing something wrong. Thanks in advance! David
Example:
Consider the piecewise defined ODE
$$\frac{dy}{dt} = -k(t) \cdot y, \qquad y(0) = y_0, \qquad \mathrm{with} \ k(t) = \begin{cases} k_0, & t \leq T \\ 0, & t > T \end{cases},$$
to which the solution reads
$$y(t) = y_0 \cdot \begin{cases} e^{-k_0 t}, & t \leq T \\ e^{-k_0 T}, & t > T \end{cases}$$
I'm interested in the partial derivatives w.r.t. $T$ and $k_0$. In the code example below, I compare the gradient obtained from integrating the ODE using diffrax to the analytical solution and to a finite difference calculation.
The code example prints out the following:
I'm using Python 3.11.7, jax 0.4.23, jaxlib 0.4.23.dev20231223, diffrax 0.5.0, on macOS 14.2.1 (x86_64), running on CPU.