flatironinstitute / nemos

NEural MOdelS, a statistical modeling framework for neuroscience.
https://nemos.readthedocs.io/
MIT License

Make use of `jaxopt.StochasticSolver.run_iterator` #194

Open bagibence opened 3 months ago

bagibence commented 3 months ago

The example in the docs currently uses a custom loop to implement stochastic gradient descent.

An alternative would be to use `jaxopt.StochasticSolver.run_iterator` and add support for stochastic solvers, potentially including optimizers implemented in Optax through the `jaxopt.OptaxSolver` wrapper.
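To illustrate the pattern (not the actual jaxopt signature): a `run_iterator`-style driver pulls mini-batches from any iterator and applies one solver update per batch, so users no longer write a custom training loop. The sketch below is a plain-numpy stand-in; the quadratic loss, learning rate, and function names are assumptions for illustration only:

```python
import numpy as np


def run_iterator(init_params, iterator, grad_fn, learning_rate=0.05, maxiter=800):
    """Stand-in for a run_iterator-style API: one gradient update per
    batch pulled from the iterator, stopping after maxiter updates."""
    params = init_params
    for step, batch in enumerate(iterator):
        if step >= maxiter:
            break
        params = params - learning_rate * grad_fn(params, batch)
    return params


# Illustrative noiseless least-squares problem: recover w from y = X @ w.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 3))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true


def epochs(n_epochs=100, batch_size=32):
    """Yield mini-batches by slicing a reshuffled dataset, epoch by epoch,
    the kind of iterator a dataloader over streamed data would provide."""
    for _ in range(n_epochs):
        order = rng.permutation(len(y))
        for start in range(0, len(y), batch_size):
            idx = order[start:start + batch_size]
            yield X[idx], y[idx]


def grad_fn(params, batch):
    """Mean-squared-error gradient on one mini-batch."""
    Xb, yb = batch
    return 2.0 * Xb.T @ (Xb @ params - yb) / len(yb)


w_hat = run_iterator(np.zeros(3), epochs(), grad_fn)
```

The key design point is that the solver owns the update logic while the caller owns batching, which is what makes the same driver work for in-memory arrays, generators, or dataloaders.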

Additionally, for data that fits in memory, it could be useful to add a faster version of this loop that samples mini-batches and updates parameters directly (similarly to `ProxSVRG.run` in https://github.com/flatironinstitute/nemos/pull/184).
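For the fits-in-memory case, the loop can skip the iterator entirely and sample batch indices from arrays held in memory on each step. A minimal sketch, loosely in the spirit of the `ProxSVRG.run` loop referenced above; the function name, loss, and hyperparameters are illustrative assumptions:

```python
import numpy as np


def run_in_memory(init_params, X, y, grad_fn, learning_rate=0.05,
                  batch_size=32, maxiter=500, seed=0):
    """Sketch of an in-memory stochastic loop: sample mini-batch indices
    directly from the full arrays and update parameters each step,
    avoiding any iterator/dataloader machinery."""
    rng = np.random.default_rng(seed)
    params = init_params
    for _ in range(maxiter):
        idx = rng.integers(0, len(y), size=batch_size)
        params = params - learning_rate * grad_fn(params, X[idx], y[idx])
    return params


# Illustrative noiseless least-squares problem.
rng = np.random.default_rng(1)
X = rng.normal(size=(256, 3))
w_true = np.array([0.5, 1.5, -1.0])
y = X @ w_true


def grad_fn(params, Xb, yb):
    """Mean-squared-error gradient on one mini-batch."""
    return 2.0 * Xb.T @ (Xb @ params - yb) / len(yb)


w_hat = run_in_memory(np.zeros(3), X, y, grad_fn)
```

In a JAX implementation this whole loop could additionally be compiled (e.g. with a scan over steps), which is where the speedup over a Python-level batching loop would come from.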

BalzaniEdoardo commented 2 months ago

We could define a superclass that provides a general interface for stochastic solvers, with an inner/outer loop structure: the inner loop should be abstract, as should `run_iterator` for batching. The class should still provide the `run` method for data that fits in memory.
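One possible shape for that interface, sketched with `abc`; all names are hypothetical, and for illustration both outer loops are made concrete here with only the per-batch step left abstract:

```python
from abc import ABC, abstractmethod

import numpy as np


class StochasticSolver(ABC):
    """Sketch of the proposed general interface: outer loops that either
    consume an iterator (run_iterator) or sample batches from in-memory
    arrays (run), with the per-batch inner step left to subclasses."""

    def __init__(self, maxiter=500, batch_size=32, seed=0):
        self.maxiter = maxiter
        self.batch_size = batch_size
        self.rng = np.random.default_rng(seed)

    @abstractmethod
    def inner_step(self, params, batch):
        """One parameter update on a single mini-batch."""

    def run_iterator(self, init_params, iterator):
        """Outer loop over an arbitrary iterator of batches."""
        params = init_params
        for step, batch in enumerate(iterator):
            if step >= self.maxiter:
                break
            params = self.inner_step(params, batch)
        return params

    def run(self, init_params, X, y):
        """Outer loop for data that fits in memory: sample batches directly."""
        params = init_params
        for _ in range(self.maxiter):
            idx = self.rng.integers(0, len(y), size=self.batch_size)
            params = self.inner_step(params, (X[idx], y[idx]))
        return params


class SGDLeastSquares(StochasticSolver):
    """Minimal concrete subclass: plain SGD on mean squared error."""

    def __init__(self, learning_rate=0.05, **kwargs):
        super().__init__(**kwargs)
        self.learning_rate = learning_rate

    def inner_step(self, params, batch):
        Xb, yb = batch
        grad = 2.0 * Xb.T @ (Xb @ params - yb) / len(yb)
        return params - self.learning_rate * grad


# Illustrative noiseless least-squares fit using the in-memory loop.
rng = np.random.default_rng(2)
X = rng.normal(size=(256, 3))
w_true = np.array([-1.0, 0.5, 2.0])
y = X @ w_true
w_hat = SGDLeastSquares().run(np.zeros(3), X, y)
```

A wrapper around an Optax optimizer would then just be another subclass whose `inner_step` delegates to the wrapped optimizer's update rule.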