patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/

ODE solver getting stuck for simple term #194


Binbose commented 1 year ago

Hey, I am trying to implement a continuous normalizing flow, where the vector field is parameterized as the gradient of a potential function. Depending on the adjoint method, I either get this error (for the default adjoint method):

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.

or, for adjoint=diffrax.BacksolveAdjoint(), the ODE solver runs for some epochs and then suddenly gets stuck without any error message (digging in with the debugger, it seems that as soon as the solver approaches t=0 it restarts over and over again), and even after letting it run for an hour it didn't make any progress (whereas the earlier epochs took a few seconds at most). I tried different solvers (Dopri5, Kvaerno5(nonlinear_solver=NewtonNonlinearSolver())), plotted the vector field to make sure it is well behaved (it seems to be), and enabled float64, but nothing helped. If I parameterize the vector field directly as a vector-valued output of the NN, it works for adjoint=diffrax.BacksolveAdjoint() (but gives the same error as above for the default adjoint method). I ported the code into a Google Colab here.
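
For reference, the setup is roughly of this form (a minimal sketch, not the actual Colab code; it assumes an Equinox MLP for the scalar potential):

```python
import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp


class GradPotential(eqx.Module):
    potential: eqx.nn.MLP

    def __init__(self, dim, key):
        # Scalar-valued potential network.
        self.potential = eqx.nn.MLP(dim, "scalar", width_size=64, depth=3, key=key)

    def __call__(self, t, y, args):
        # The vector field is the gradient of the potential w.r.t. the state y.
        return jax.grad(self.potential)(y)


vf = GradPotential(dim=2, key=jax.random.PRNGKey(0))
sol = diffrax.diffeqsolve(
    diffrax.ODETerm(vf),
    diffrax.Dopri5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.ones(2),
    stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-5),
    adjoint=diffrax.BacksolveAdjoint(),
)
```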
Do you have any idea what the reasons might be for this behaviour?

patrick-kidger commented 1 year ago

The error you're getting with the default adjoint method is because you're specifying max_steps=None. Set max_steps to some finite integer and you will be able to backpropagate.

Why is this necessary? JAX doesn't perform dynamic memory allocation, so in order to save the forward pass in memory (prior to backpropagation), we need to allocate the memory to store it in advance. This memory has a size proportional to max_steps.

As such, memory usage will grow proportionally to max_steps. (Also, the runtime currently increases logarithmically with respect to max_steps, although that's something that should be fixed in an upcoming release.) So you should look to set this to the smallest number that works for your problem.
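
For example (a minimal sketch with a toy vector field; 4096 is just a placeholder value):

```python
import diffrax
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: -y)  # toy vector field, for illustration only
sol = diffrax.diffeqsolve(
    term,
    diffrax.Dopri5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.ones(2),
    stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-5),
    max_steps=4096,  # finite, so the forward pass can be stored for the backward pass
)
```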

Regarding what you're saying about BacksolveAdjoint: this sounds a little odd. Can you try to construct a smaller MWE for the problem?

Side note, you may find the CNF example in the documentation to be interesting.

Binbose commented 1 year ago

That makes sense! After setting max_steps it does indeed not throw that error anymore. However, now it says `RuntimeError: The maximum number of solver steps was reached. Try increasing max_steps.`, even after setting `max_steps` to something really high, like 10000. (I assume whatever causes this also causes the problem with the other adjoint method, where the solver simply wasn't converging, and since max_steps was None it just hung in an infinite loop.) Do you know what can cause these kinds of problems? Sure, I can shorten the Colab to an MWE.

This code is partially built on the CNF example from the documentation and partially on another CNF example. Generally, it works well for a normal NN vector field; it only becomes a problem when the vector field is defined as the gradient of a potential function. Currently, I omit the dependence on time (so no hypernetwork); maybe this causes the non-convergence? (Even though for the normal NN vector field it doesn't seem to be a problem and the pdf is transformed correctly.)

patrick-kidger commented 1 year ago

A few typical candidates for this kind of problem (see the sketch after this list):

  1. You're using 32-bit rather than 64-bit, and do in fact need the extra precision to make things work.
  2. Your vector field is ill-posed, and is returning a nan output at some point. If this happens then Diffrax assumes that it made too large a timestep (into a region where the vector field isn't defined), and tries to shrink the timestep.
  3. Since your vector field is a grad-NN: make sure your activation functions have a Lipschitz derivative (i.e. are smooth enough). In particular, that they're not ReLU. (If they were ReLU then your vector field wouldn't even be continuous.)
  4. You really are trying to solve something very stiff, and really do need that many timesteps.
    • Typical advice here for ODEs would be to switch to a stiff integrator (such as Kvaerno5), to set a large value for max_steps, and to use a PI rather than an I controller (PIDController(pcoeff=..., icoeff=...)).
    • In the case of a CNF, increasing max_steps and using a PI controller continue to be good advice. Probably Kvaerno5 wouldn't scale well to the large state of the system, though: something like a predictor-corrector method may end up being preferred. Diffrax doesn't implement any predictor-corrector methods yet, but they'd be easy enough to add.
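
Putting a few of those suggestions together (a minimal sketch; the tolerances, controller coefficients and network sizes are placeholder values):

```python
import diffrax
import equinox as eqx
import jax

jax.config.update("jax_enable_x64", True)  # point 1: 64-bit precision

# Point 3: a smooth activation (softplus) rather than ReLU for the potential network.
potential = eqx.nn.MLP(
    2, "scalar", width_size=64, depth=3,
    activation=jax.nn.softplus,
    key=jax.random.PRNGKey(0),
)

# Point 4: a stiff solver plus a PI (rather than I) step size controller.
solver = diffrax.Kvaerno5()
controller = diffrax.PIDController(
    rtol=1e-6, atol=1e-6,
    pcoeff=0.3, icoeff=0.4,
)
```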

Some approaches you can take to debug this (sketched below):

  1. Run sol = diffeqsolve(..., throw=False); print(sol.stats). See if you're getting a lot of rejected steps or something.
  2. Run diffeqsolve(..., saveat=SaveAt(steps=True)) and see where the steps are being placed.
  3. Add jax.debug.{breakpoint,print} statements in the internals of Diffrax, and see what's going on!
  4. Try 32 vs 64 bit.
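
A sketch of points 1 and 2 together (toy vector field, placeholder values):

```python
import diffrax
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: -y)  # toy vector field, for illustration only
sol = diffrax.diffeqsolve(
    term,
    diffrax.Dopri5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.ones(2),
    stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-5),
    saveat=diffrax.SaveAt(steps=True),  # record the location of every step
    throw=False,                        # don't error out if max_steps is hit
)
print(sol.stats)  # e.g. number of accepted/rejected steps
print(sol.ts)     # where the steps were placed
```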

If it ends up being a too-stiff or 64-bit issue, then one solution may be to regularise your dynamics so that they're better behaved.
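
For instance, one common option (an illustration only, not something prescribed here) is a kinetic-energy-style penalty on the vector field along the trajectory; kinetic_penalty below is a hypothetical helper:

```python
import jax
import jax.numpy as jnp

def kinetic_penalty(vector_field, ts, ys, args=None):
    # Hypothetical helper: mean squared magnitude of the vector field evaluated
    # along a trajectory (ts, ys). Adding `lam * kinetic_penalty(...)` to the
    # training loss discourages exploding/stiff dynamics.
    fs = jax.vmap(lambda t, y: vector_field(t, y, args))(ts, ys)
    return jnp.mean(jnp.sum(fs**2, axis=-1))
```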

FWIW I've played with grad-NN CNFs before, and I recall having some issues a bit like what you're describing: that things exploded near the origin. In my case, I think I tracked it down to being a numerical issue: that the gradient structure somehow meant that I got a huge Lipschitz constant when near the origin, and that the solver struggled as a result. (It might be some other issue for you of course, though.)