Closed astanziola closed 1 year ago
First of all, thank you for providing such a careful benchmark. Second, sorry for taking so long to get back to you -- tackling this turned out to be an interesting problem, and it took longer than I expected! (It also turned up two XLA bugs along the way: https://github.com/google/jax/issues/16663, https://github.com/google/jax/issues/16661.)
Anyway, I'm happy to say that as of #276, the performance is now much closer. On my V100:
```
scan: 0.042259128065779805
scan AD: 0.10428679105825722
SemiImplicitEuler: 0.04464115505106747
SemiImplicitEuler AD: 0.15305515099316835
```
Note that I did make one change to the benchmark to ensure a fair comparison: I switched Diffrax's `stepsize_controller` to `diffrax.StepTo`, as this is the appropriate analogue of `lax.scan`. (In contrast, note how your current program depends only on `t[0]`, `t[1]`, and `t[-1]`, but not on any of the other values of `t`. This means that you could change just those values and get a different number of steps -- but you wouldn't have to recompile! This extra flexibility at runtime is responsible for part of the overhead you're measuring.)
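To make the recompilation point concrete, here is a hedged plain-JAX sketch (illustrative only, not the benchmark code): the scan-based integrator below reads only `t[0]` and `t[1]` (to get `dt`) plus the length of `t`, so changing interior values of `t` cannot affect its output.

```python
import jax
import jax.numpy as jnp

def integrate(t, y0):
    # dt comes only from the first two time points; the number of steps
    # (len(t) - 1) is baked in at trace time.
    dt = t[1] - t[0]

    def step(y, _):
        y = y + dt * (-y)  # explicit Euler step for dy/dt = -y
        return y, None

    y_final, _ = jax.lax.scan(step, y0, xs=None, length=len(t) - 1)
    return y_final

t_uniform = jnp.linspace(0.0, 1.0, 11)
t_scrambled = t_uniform.at[5].set(0.123)  # change an interior value only
y0 = jnp.array(1.0)

# Same t[0], t[1], and length => identical result:
same = bool(jnp.allclose(integrate(t_uniform, y0), integrate(t_scrambled, y0)))
```

By contrast, `diffrax.StepTo` pins the solver to an explicit grid of times, which is why it is the fair analogue of the scan above.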
As for the changes that I made: most of the overhead turned out to be due to the extra complexity of the recursive checkpointing (as opposed to simply checkpointing on every step). The relevant changes are in https://github.com/patrick-kidger/equinox/pull/415.
This improvement will appear in the next release of Diffrax. And there's clearly still a small discrepancy on the backward pass -- it looks like that still needs some more careful profiling. Let me know what this looks like for your actual use-case!
Thanks a lot, that looks great! I'm on holidays until the end of next week, but as soon as I'm back I'll give it a try on the actual simulator.
Sorry this took a while, but I just tested it in the simulator and it works amazingly :smile:
I get very good performance (roughly 2.2x for the backward pass). Interestingly enough, I get a very minor performance boost using `ConstantStepSize` instead of `StepTo`, but it's really negligible (though consistent).
I do have a couple of warnings at startup due to the deprecation of `equinox.filter_custom_vjp.defvjp`, but that's a different story.
Also... did I just get a runtime XLA error based on array values?? What is this new amazing sorcery :heart_eyes:??
Thanks again for fixing this!
Marvellous!
The slight difference between `ConstantStepSize` and `StepTo` -- I have no explanation for this :D
> I do have a couple of warnings at startup due to deprecation of `equinox.filter_custom_vjp.defvjp`, but that's a different story.
Yup, next release of Diffrax (in the next few days) will avoid that code path.
> Also... did I just get a runtime XLA error based on array values?? What is this new amazing sorcery 😍??
Yes you did! Equinox recently added public support for these. Documentation available here. Indeed "sorcery" is the appropriate word, since this is something new under the JAX sun.
Just released the new version of Diffrax. I think everything discussed here should now work, fast, without warnings.
As such I'm closing this, but please do re-open it if you think this isn't fixed.
Hello again!
I am still on my quest to add proper integration + adjoint calculation (and checkpointing) to my wave simulator :smile:
I appreciate that `diffrax` offers a range of methods for calculating adjoints, each with its own trade-off between computational complexity and memory requirements. However, for smaller simulations it might be beneficial to maximize checkpoint usage and potentially save the entire ODE trajectory for reverse-mode AD. This approach takes full advantage of the GPU memory, thereby reducing computation times.
My understanding is that this can be achieved by using `RecursiveCheckpointAdjoint` with a large value of `checkpoints`, potentially as high as the number of steps in the forward integrator. I've attempted to implement this without much success. To be precise, while I am obtaining the correct numerical results, the computation times are far longer than expected.
Here is an MWE:
I get the following timings on an RTX 4000:
As expected, for `scan`, the AD calculation takes roughly twice the execution time of the forward pass. This can be made almost exactly 2x if the `jax.checkpoint` decorator is removed. For the forward pass of `SemiImplicitEuler`, the timings I get are approximately twice those of the scan alone. However, this could easily be attributed to the more sophisticated implementation of the `diffrax` integrator, so overall that's completely fine. The timings for performing AD, though, are about 7x those required by the scan method. In a more complex example within my simulator, it can reach up to 30x the time required by the equivalent scan integrator.
Am I missing something about the correct approach to calculating the adjoint?
Also, I'm not sure whether `RecursiveCheckpointAdjoint` is using the same `solver` as the forward integrator (based on my understanding of the documentation, it isn't), and I can't seem to find a way to pass a specific `solver` to it. Would it be necessary to define a new class derived from `AbstractAdjoint` with a custom `loop` method to achieve this? Thanks a lot!