bagibence opened 3 months ago
We could define a super-class that provides a general interface for stochastic solvers, with an inner/outer loop structure in which the inner loop is abstract, as is run_iterator (for batching). The class should still provide the run method for fitting data that is in memory.
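One possible shape for such an interface, as a minimal sketch in plain Python: the inner step and run_iterator are abstract, while run handles in-memory data by building a batch iterator and delegating. The class and method names other than run and run_iterator are illustrative, not an existing nemos or jaxopt API; the toy subclass at the end exists only to show how a concrete solver would plug in.

```python
from abc import ABC, abstractmethod


class StochasticSolver(ABC):
    """Sketch of a general stochastic-solver interface (hypothetical,
    not the nemos API): an outer loop over epochs and an abstract
    inner update applied to each mini-batch."""

    def __init__(self, maxiter=3):
        self.maxiter = maxiter  # number of passes over the data

    @abstractmethod
    def inner_step(self, params, batch):
        """One parameter update computed from a single mini-batch."""

    @abstractmethod
    def run_iterator(self, init_params, batch_iterator):
        """Fit from a batch iterator (data need not fit in memory)."""

    def run(self, init_params, data, batch_size=2):
        """Default in-memory fit: slice `data` into mini-batches for
        `maxiter` epochs and delegate to `run_iterator`."""
        def batches():
            for _ in range(self.maxiter):
                for i in range(0, len(data), batch_size):
                    yield data[i:i + batch_size]
        return self.run_iterator(init_params, batches())


class MeanSolver(StochasticSolver):
    """Toy concrete solver (illustration only): gradient steps on the
    squared distance to each batch's mean, converging to the data mean."""

    def inner_step(self, params, batch):
        grad = params - sum(batch) / len(batch)
        return params - 0.5 * grad

    def run_iterator(self, init_params, batch_iterator):
        params = init_params
        for batch in batch_iterator:
            params = self.inner_step(params, batch)
        return params
```

With this structure, a concrete solver only implements the inner update and the iterator-driven loop, and gets the convenience run method for free.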
The example in the docs currently uses a custom loop to implement stochastic gradient descent. An alternative would be to use jaxopt.StochasticSolver.run_iterator and add support for stochastic solvers, potentially including optimizers implemented in Optax through the jaxopt.OptaxSolver wrapper. Additionally, for data that fits in memory, a faster version of this loop -- one that samples mini-batches and updates parameters (similarly to ProxSVRG.run in https://github.com/flatironinstitute/nemos/pull/184) -- could be useful.
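A sketch of what the in-memory fast loop could look like, in plain Python so it runs standalone: sample mini-batch indices each iteration and apply one gradient update on the sampled slice. The function and parameter names here are hypothetical (this is not ProxSVRG.run or any existing nemos function), and plain SGD stands in for whatever update rule the solver would actually use.

```python
import random


def run_in_memory(grad_fn, init_params, X, y,
                  batch_size=2, stepsize=0.02, maxiter=500, seed=0):
    """Hypothetical in-memory stochastic loop (sketch, not the nemos
    API): each iteration samples a mini-batch by index and applies one
    gradient step, avoiding the overhead of a general batch iterator."""
    rng = random.Random(seed)
    params = init_params
    n = len(X)
    for _ in range(maxiter):
        idx = rng.sample(range(n), batch_size)  # sample a mini-batch
        xb = [X[i] for i in idx]
        yb = [y[i] for i in idx]
        params = params - stepsize * grad_fn(params, xb, yb)
    return params


def grad_linear(w, xb, yb):
    """Mini-batch gradient of mean squared error for the toy model
    y ≈ w * x (illustration only)."""
    return sum(2 * (w * x - t) * x for x, t in zip(xb, yb)) / len(xb)
```

Because the data is already in memory, sampling reduces to index arithmetic, which is the main speed advantage over driving a generic iterator.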