google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

custom loop pjit example #365

Closed: fabianp closed this 1 year ago

fabianp commented 1 year ago

Code by @fllinares, reviewed in #358.