dionhaefner closed this issue 2 weeks ago.
This is now available as a library function under equinox.internal.while_loop! It's not documented because it has a couple of footguns (see its docstring) that made it more complicated to use than I was willing to officially support. But it's still "semi-public", so you should feel free to use it.
The implementation that used to live here is available as equinox.internal.while_loop(..., kind="bounded"), but practically speaking I now recommend almost never using that and instead preferring equinox.internal.while_loop(..., kind="checkpointed"), which is much faster.
Awesome, thank you!
There used to be an implementation of bounded_while_loop in this repo that implemented recursive checkpointing for arbitrary JAX loops. I can't see it in the current version of the repo; were there fundamental problems with it? I'm looking for something similar to do recursive checkpointing on general JAX functions (ones that don't use Equinox or Diffrax under the hood).
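For anyone after the same thing without Equinox, here is a minimal plain-JAX sketch of one way to get a reverse-mode-differentiable, recursively checkpointed while loop. It is not the removed bounded_while_loop or the Equinox implementation. The names bounded_while_loop and _repeat and the base parameter are hypothetical, chosen for illustration. The sketch assumes max_steps is a Python int and base >= 2. It nests depth levels of jax.checkpoint-wrapped lax.scan calls, each of length base, choosing depth so that base**depth >= max_steps. Once cond_fun is False, or the step budget is used up, every remaining step is a no-op.

```python
import jax
import jax.numpy as jnp
from jax import lax


def _repeat(fn, length):
    # Hypothetical helper: run `fn` `length` times with lax.scan. Because of
    # jax.checkpoint, only the input carry is saved on the forward pass, and
    # the inner iterations are recomputed during the backward pass.
    @jax.checkpoint
    def repeated(carry):
        def scan_body(c, _):
            return fn(c), None

        carry, _ = lax.scan(scan_body, carry, xs=None, length=length)
        return carry

    return repeated


def bounded_while_loop(cond_fun, body_fun, init_val, max_steps, base=16):
    # Smallest depth such that base**depth >= max_steps (assumes base >= 2).
    depth = 1
    while base**depth < max_steps:
        depth += 1

    def step(carry):
        val, n = carry
        # Apply body_fun only while cond_fun holds and the step budget remains;
        # otherwise pass the state through unchanged.
        pred = jnp.logical_and(n < max_steps, cond_fun(val))
        val = lax.cond(pred, body_fun, lambda v: v, val)
        return val, n + 1

    # Nest `depth` levels of checkpointed scans, each of length `base`.
    fn = step
    for _ in range(depth):
        fn = _repeat(fn, base)

    val, _ = fn((init_val, jnp.asarray(0)))
    return val


# Usage: multiply by 1.1 until x reaches 10, then differentiate in reverse mode.
def cond_fun(x):
    return x < 10.0


def body_fun(x):
    return x * 1.1


def loop(x0):
    return bounded_while_loop(cond_fun, body_fun, x0, max_steps=100, base=4)


final = loop(jnp.float32(1.0))
grad = jax.grad(loop)(jnp.float32(1.0))
```

With this layout, the backward pass keeps roughly base * depth carries alive at once and recomputes each step roughly depth extra times. Also note that every call in this sketch executes all base**depth guarded steps, even when cond_fun becomes False early.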