The error you're getting with the default adjoint method is because you're specifying `max_steps=None`. Set `max_steps` to some finite integer and you will be able to backpropagate.

Why is this necessary? JAX doesn't perform dynamic memory allocation, so in order to save the forward pass in memory (prior to backpropagation) we need to allocate the memory to store it in advance. This memory has a size proportional to `max_steps`.
As such, memory usage will grow proportionally to `max_steps`. (Right now, runtime also increases logarithmically with `max_steps`, although that should be fixed in an upcoming release.) So you should look to set this to the smallest number that works for your problem.
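For instance, here's a minimal sketch of what that looks like (the vector field is just a toy stand-in):

```python
import diffrax
import jax
import jax.numpy as jnp


def loss(y0):
    term = diffrax.ODETerm(lambda t, y, args: -y)  # toy vector field
    sol = diffrax.diffeqsolve(
        term, diffrax.Dopri5(), t0=0.0, t1=1.0, dt0=0.1, y0=y0,
        stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-5),
        max_steps=4096,  # finite, so memory for the forward pass can be preallocated
    )
    return jnp.sum(sol.ys**2)


# Backpropagating through the solve now works with the default adjoint.
grads = jax.grad(loss)(jnp.array(1.0))
```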
Regarding what you're saying about `BacksolveAdjoint`: this sounds a little odd. Can you try to construct a smaller MWE for the problem?
Side note, you may find the CNF example in the documentation to be interesting.
That makes sense! After setting `max_steps` it does indeed no longer throw that error. However, now it says `RuntimeError: The maximum number of solver steps was reached. Try increasing max_steps`, even after setting `max_steps` to something really high, like 10000. (I assume whatever causes this also causes the problem with the other adjoint method: the solver simply isn't converging, and since `max_steps` was `None`, it just hung in an infinite loop.)

Do you know what can cause these kinds of problems?
Sure, I can shorten the colab to an MWE.
This code is built partially on the CNF example from the documentation and partially on another CNF example. Generally it does work well for a normal NN vector field; it only becomes a problem if the vector field is defined as the gradient of a potential function. Currently I omit the dependence on time (so no hypernetwork); maybe this causes the non-convergence? (Even though for the normal NN vector field it doesn't seem to be a problem, and the PDF is transformed correctly.)
A few typical candidates for this kind of problem:

- A `nan` output at some point. If this happens then Diffrax assumes that it made too large a timestep (into a region where the vector field isn't defined), and tries to shrink the timestep.
- A stiff system. The typical fixes are to use an implicit solver (e.g. `Kvaerno5`), to set a large value for `max_steps`, and to use a PI rather than an I controller (`PIDController(pcoeff=..., icoeff=...)`). `Kvaerno5` wouldn't scale well to the large state of the system, though: something like a predictor-corrector method may end up being preferred. Diffrax doesn't implement any predictor-corrector methods yet, but they'd be easy enough to add.
- Needing 64-bit precision (although you mention you've already enabled this).

Some approaches you can take to debug this:

- Run `sol = diffeqsolve(..., throw=False); print(sol.stats)`. See if you're getting a lot of rejected steps or something.
- Solve with `diffeqsolve(..., saveat=SaveAt(steps=True))` and see where the steps are being placed.
- Drop `jax.debug.{breakpoint,print}` statements in the internals of Diffrax, and see what's going on!

If it ends up being a too-stiff or 64-bit issue, then one solution may be to regularise your dynamics so that they're better behaved.
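Putting a couple of those suggestions together, a minimal sketch (the toy vector field and the particular `pcoeff`/`icoeff` values are just illustrative):

```python
import diffrax
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: -y)  # stand-in for the real vector field

sol = diffrax.diffeqsolve(
    term,
    diffrax.Kvaerno5(),  # implicit solver, in case the system is stiff
    t0=0.0, t1=1.0, dt0=0.01, y0=jnp.array(1.0),
    # PI controller: nonzero pcoeff, rather than the default I controller.
    stepsize_controller=diffrax.PIDController(
        rtol=1e-5, atol=1e-5, pcoeff=0.3, icoeff=0.4
    ),
    saveat=diffrax.SaveAt(steps=True),  # record the location of every step
    max_steps=16**4,
    throw=False,  # don't raise on failure; inspect the result instead
)
print(sol.stats)  # e.g. accepted/rejected step counts
print(sol.ts)     # where the steps were placed (unused entries are inf)
```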
FWIW I've played with grad-NN CNFs before, and I recall having some issues a bit like what you're describing: things exploded near the origin. In my case, I think I tracked it down to a numerical issue: the gradient structure somehow meant that I got a huge Lipschitz constant near the origin, and the solver struggled as a result. (It might of course be some other issue for you.)
Hey, I am trying to implement a continuous normalizing flow where the vector field is parameterized as the gradient of a potential function. Depending on the adjoint method, I either get this error (for the default adjoint method):

```
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.
```

or, for `adjoint=diffrax.BacksolveAdjoint()`, the ODE solver runs for some epochs and then suddenly gets stuck without any error message (digging in with the debugger, it seems that as soon as the solver approaches t=0 it restarts over and over again), and even after letting it run for an hour it made no progress (whereas the epochs before took a few seconds at most). I tried different solvers (`Dopri5`, `Kvaerno5(nonlinear_solver=NewtonNonlinearSolver())`), plotted the vector field to make sure it is well behaved (it seems to be), and enabled float64, but nothing helped. If I parameterize the vector field directly as a vector-valued output of the NN, it works for `adjoint=diffrax.BacksolveAdjoint()` (but gives the same error as above for the default adjoint method). I ported the code into a Google Colab here.

Do you have any idea what the reasons for this behaviour might be?
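For reference, the setup is roughly of this shape (a stripped-down sketch, with a toy quadratic potential standing in for the neural network):

```python
import diffrax
import jax
import jax.numpy as jnp


# Toy stand-in; in the real code this is a learned potential network.
def potential(y):
    return jnp.sum(y**2) / 2


def vector_field(t, y, args):
    # The vector field is the gradient of a scalar potential.
    return -jax.grad(potential)(y)


def loss(y0):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Dopri5(),
        t0=0.0, t1=1.0, dt0=0.1, y0=y0,
        stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-5),
        adjoint=diffrax.BacksolveAdjoint(),  # optimise-then-discretise adjoint
    )
    return jnp.sum(sol.ys**2)


grads = jax.grad(loss)(jnp.ones(2))
```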