patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Apache License 2.0

What happened to bounded_while_loop? #522

Closed by dionhaefner 2 weeks ago

dionhaefner commented 2 weeks ago

There used to be an implementation of bounded_while_loop in this repo that implemented recursive checkpointing for arbitrary JAX loops. I can't find it in the current version of the repo; were there fundamental problems with it? I'm looking for something similar to do recursive checkpointing on general JAX functions (ones that don't use Equinox/Diffrax under the hood).
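For readers unfamiliar with the idea, here is a minimal pure-JAX sketch of recursive (nested) checkpointing for a bounded while loop. It is not the Diffrax or Equinox code: `nested_checkpoint_while` is a hypothetical helper, and for simplicity it assumes `max_steps` is an exact power of `base`. The loop is built from `depth` nested `lax.scan`s of length `base`, each wrapped in `jax.checkpoint`, so reverse-mode autodiff keeps roughly `base * depth` carries alive instead of `max_steps`, at the price of recomputing forward work once per level.

```python
import jax
import jax.numpy as jnp
from jax import lax


def nested_checkpoint_while(cond_fun, body_fun, init_val, max_steps, base=16):
    """Run body_fun while cond_fun holds, for at most max_steps iterations.

    Hypothetical helper for illustration; not part of Diffrax or Equinox.
    """
    # Assumption of this sketch: max_steps == base ** depth exactly.
    depth, size = 0, 1
    while size < max_steps:
        size *= base
        depth += 1
    if size != max_steps:
        raise ValueError("max_steps must be a power of base in this sketch")

    def step(val):
        # One iteration; once cond_fun is False the carry passes through unchanged.
        return lax.cond(cond_fun(val), body_fun, lambda v: v, val)

    def nest(inner):
        # Call `inner` `base` times inside a scan, and checkpoint the whole level
        # so only this level's carries are stored during the backward pass.
        @jax.checkpoint
        def level(val):
            def scan_body(carry, _):
                return inner(carry), None

            out, _ = lax.scan(scan_body, val, xs=None, length=base)
            return out

        return level

    fn = step
    for _ in range(depth):
        fn = nest(fn)
    return fn(init_val)


# Usage: keep doubling x while x < 10, with at most 4**2 = 16 iterations.
def cond(x):
    return x < 10.0


def body(x):
    return 2.0 * x


x0 = jnp.array(1.0)
print(nested_checkpoint_while(cond, body, x0, max_steps=16, base=4))  # 16.0
grad_fn = jax.grad(lambda x: nested_checkpoint_while(cond, body, x, 16, 4))
print(grad_fn(x0))  # 16.0
```

Because `lax.scan` always runs its full length, finished iterations are skipped via `lax.cond` rather than by exiting early; that is what makes the loop reverse-mode differentiable where a plain `lax.while_loop` is not.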

patrick-kidger commented 2 weeks ago

This is now available as a library function: equinox.internal.while_loop! It's undocumented because it has a couple of footguns (see its docstring) that made it more complicated to use than I was willing to support officially. But it's still 'semi-public', so feel free to use it.

The implementation that used to live here is now equinox.internal.while_loop(..., kind="bounded"), but in practice I'd now recommend almost never using it. Prefer equinox.internal.while_loop(..., kind="checkpointed") instead, which is much faster.

dionhaefner commented 2 weeks ago

Awesome, thank you!