jeremiecoullon / SGMCMCJax

Lightweight library of stochastic gradient MCMC algorithms written in JAX.
https://sgmcmcjax.readthedocs.io/en/latest/index.html
Apache License 2.0

Palindrome diffusion solvers #6

Closed by jeremiecoullon 3 years ago

jeremiecoullon commented 3 years ago

How should we add diffusion solvers that are palindromes (e.g. BAOAB or NOGIN)? These will usually (perhaps always?) need a gradient update halfway through the solver, so one solution is to split the update function into two parts. That way the gradient can be recalculated in between.

Example:


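from jax import random

init_fn, update1, update2, get_params = baoab(1e-5)

key = random.PRNGKey(0)
state = init_fn(params)  # assumed: `params` holds the initial parameters
mygrad = grad_log_post(get_params(state), *data)  # gradient for the first half-step
samples = []

for i in range(Nsamples):
  key, subkey1, subkey2 = random.split(key, 3)
  state = update1(i, subkey1, mygrad, state)  # first part, uses the previous gradient
  mygrad = grad_log_post(get_params(state), *data)  # refresh the gradient in between
  state = update2(i, subkey2, mygrad, state)  # second part, uses the new gradient
  samples.append(get_params(state))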
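
To make the two-part idea concrete, here is a minimal sketch of what the two halves of a BAOAB solver could look like. It is not the library's implementation: `baoab_sketch`, the friction `gamma`, the unit mass and temperature, and the standard normal target are all assumptions made for illustration. It uses plain Python on a scalar parameter, with a `random.Random` instance standing in for the JAX key, so that it runs without any dependencies.

```python
import math
import random


def baoab_sketch(dt, gamma=1.0):
    """Hypothetical two-part BAOAB solver for a scalar parameter.

    Assumes unit mass and unit temperature; `gamma` is the friction.
    """
    c1 = math.exp(-gamma * dt)
    c2 = math.sqrt(1.0 - c1 ** 2)

    def init_fn(x):
        return (x, 0.0)  # state = (position, momentum)

    def update1(i, rng, grad, state):
        # First part: B, A, O, A using the gradient at the old position.
        x, p = state
        p = p + 0.5 * dt * grad                  # B: half momentum kick
        x = x + 0.5 * dt * p                     # A: half position drift
        p = c1 * p + c2 * rng.gauss(0.0, 1.0)    # O: Ornstein-Uhlenbeck step
        x = x + 0.5 * dt * p                     # A: half position drift
        return (x, p)

    def update2(i, rng, grad, state):
        # Second part: final B using the gradient at the new position.
        x, p = state
        p = p + 0.5 * dt * grad
        return (x, p)

    def get_params(state):
        return state[0]

    return init_fn, update1, update2, get_params


def grad_log_post(x):
    # Assumed target: standard normal, so grad log p(x) = -x.
    return -x


def run(n_samples, dt=0.1, seed=0):
    rng = random.Random(seed)
    init_fn, update1, update2, get_params = baoab_sketch(dt)
    state = init_fn(0.0)
    grad = grad_log_post(get_params(state))  # gradient for the first half-step
    samples = []
    for i in range(n_samples):
        state = update1(i, rng, grad, state)
        grad = grad_log_post(get_params(state))  # recomputed in between
        state = update2(i, rng, grad, state)
        samples.append(get_params(state))
    return samples
```

In this sketch each iteration costs a single gradient evaluation, because the gradient computed between the two parts is reused by `update1` on the next iteration. Note also that `update2` accepts the random generator only to keep the two signatures identical and never draws from it.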

Some issues:

jeremiecoullon commented 3 years ago

This is merged into master. I created a "diffusion_factory" that generates diffusion decorators. The factory takes two arguments (is_palidrome and is_sghmc). The random key is passed to both updates (even though it seems that the second update might never need randomness).