jeremiecoullon / SGMCMCJax

Lightweight library of stochastic gradient MCMC algorithms written in JAX.
https://sgmcmcjax.readthedocs.io/en/latest/index.html
Apache License 2.0

Optimiser for CV estimator #11

Closed: jeremiecoullon closed this issue 3 years ago

jeremiecoullon commented 3 years ago

We could add an optimiser to find the centering value for the CV estimator. This could use Adam from JAX's built-in optimizers module, or Optax. The idea is to provide a good default optimiser setup that is easy to use.
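For reference, here is a sketch of the standard control-variate (CV) gradient estimator, which is not necessarily the exact expression used in this library. It recentres each minibatch gradient around a fixed centering value $\hat\theta$. With $N$ observations, a minibatch $S_n$ of size $n$, prior $p(\theta)$ and posterior $\pi(\theta)$:

$$
\widehat{\nabla \log \pi}(\theta) = \nabla \log \pi(\hat\theta) + \nabla \log p(\theta) - \nabla \log p(\hat\theta) + \frac{N}{n} \sum_{i \in S_n} \left[ \nabla \log p(x_i \mid \theta) - \nabla \log p(x_i \mid \hat\theta) \right]
$$

Here $\nabla \log \pi(\hat\theta) = \nabla \log p(\hat\theta) + \sum_{i=1}^{N} \nabla \log p(x_i \mid \hat\theta)$ is computed once over the full dataset. The estimator is unbiased for any $\hat\theta$, but its variance is small only when $\theta$ stays close to $\hat\theta$. This is why $\hat\theta$ is usually set to a posterior mode found by an optimiser.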

Example API:

```python
adam_optimiser = build_adam_optimiser(loglikelihood, logprior, data, batch_size, compile=True)
opt_params, loss_array = adam_optimiser(key, Niters, param_IC)
```
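The following is a minimal sketch of how the proposed `build_adam_optimiser` could be implemented. It uses Adam from `jax.example_libraries.optimizers`, JAX's bundled optimizers module. It illustrates the API under stated assumptions and is not the library's implementation. It assumes that `loglikelihood(params, x)` scores a single observation and that `data` is one array with observations along axis 0. The `step_size` argument is a hypothetical addition. The loss being minimised is the negative minibatch estimate of the log posterior, and `loss_array` records it at every iteration.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers


def build_adam_optimiser(loglikelihood, logprior, data, batch_size,
                         step_size=1e-2, compile=True):
    """Hypothetical sketch of the proposed builder (not the library's code).

    Assumes loglikelihood(params, x) scores a single observation x and that
    `data` is one array with observations along axis 0. `step_size` is a
    hypothetical extra argument for Adam's learning rate.
    """
    N = data.shape[0]
    opt_init, opt_update, get_params = optimizers.adam(step_size)

    def neg_logpost_estimate(params, batch):
        # Unbiased minibatch estimate of the log posterior, negated for minimisation
        loglik = jax.vmap(loglikelihood, in_axes=(None, 0))(params, batch)
        return -(logprior(params) + (N / batch_size) * jnp.sum(loglik))

    def step(opt_state, xs):
        i, key = xs
        # Sample a minibatch of indices without replacement
        idx = jax.random.choice(key, N, shape=(batch_size,), replace=False)
        params = get_params(opt_state)
        loss, grads = jax.value_and_grad(neg_logpost_estimate)(params, data[idx])
        return opt_update(i, grads, opt_state), loss

    def adam_optimiser(key, Niters, param_IC):
        keys = jax.random.split(key, Niters)
        xs = (jnp.arange(Niters), keys)
        opt_state, loss_array = jax.lax.scan(step, opt_init(param_IC), xs)
        return get_params(opt_state), loss_array

    if compile:
        # Niters is static: it fixes the number of PRNG keys and scan steps
        return jax.jit(adam_optimiser, static_argnums=1)
    return adam_optimiser


# Usage: posterior mode of a Gaussian mean with a N(0, 10^2) prior
def loglikelihood(theta, x):
    return -0.5 * (x - theta) ** 2


def logprior(theta):
    return -0.5 * theta ** 2 / 100.0


data_key, opt_key = jax.random.split(jax.random.PRNGKey(0))
data = 3.0 + jax.random.normal(data_key, (1000,))
mode = jnp.sum(data) / (data.shape[0] + 1 / 100.0)  # exact posterior mode

adam_optimiser = build_adam_optimiser(loglikelihood, logprior, data, batch_size=100)
opt_params, loss_array = adam_optimiser(opt_key, 3000, jnp.array(0.0))
```

Running the loop under `jax.lax.scan` keeps every iteration inside one compiled function when `compile=True`. The returned `opt_params` could then serve as the centering value of a CV gradient estimator.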

On the other hand, this might be considered outside the scope of the library?

jeremiecoullon commented 3 years ago

Added this.