jeremiecoullon / SGMCMCJax

Lightweight library of stochastic gradient MCMC algorithms written in JAX.
https://sgmcmcjax.readthedocs.io/en/latest/index.html
Apache License 2.0

BAOAB and BADODAB samplers seem broken? #69

Open zfurman56 opened 1 year ago

zfurman56 commented 1 year ago

Using the code provided in this example notebook for either BAOAB or BADODAB gives samples that are constant, even with large dt. I haven't been able to get them to give non-constant output. Is this a bug, or is that notebook just out-of-date?

Output from BADODAB:

[Screenshot: trace plot of BADODAB samples (constant)]

Output from SGNHT, for comparison:

[Screenshot: trace plot of SGNHT samples (non-constant)]

Minimal reproducing code, available in this Colab notebook:

import matplotlib.pyplot as plt
import numpy as np

from jax import random

from sgmcmcjax.models.logistic_regression import gen_data, loglikelihood, logprior
from sgmcmcjax.samplers import build_badodab_sampler

# Generate synthetic logistic regression data
key = random.PRNGKey(42)
dim = 10
Ndata = 100000

theta_true, X, y_data = gen_data(key, dim, Ndata)

data = (X, y_data)

# Use 1% of the dataset per minibatch
batch_size = int(0.01*X.shape[0])

# Build the BADODAB sampler with dt=1e-3
my_sampler = build_badodab_sampler(1e-3, loglikelihood, logprior, data, batch_size)

# Run the sampler, starting from the true parameters
key = random.PRNGKey(0)
Nsamples = 10000
samples = my_sampler(key, Nsamples, theta_true)

# Plot the trace of one coordinate against its true value
idx = 7
plt.plot(np.array(samples)[:,idx])
plt.axhline(theta_true[idx], c='r')
plt.show()

Replacing build_badodab_sampler with build_sgnht_sampler (and changing dt to 1e-5) works as expected.
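For reference, the working variant only swaps the builder and the step size; a minimal sketch, assuming build_sgnht_sampler takes the same (dt, loglikelihood, logprior, data, batch_size) arguments as build_badodab_sampler:

from sgmcmcjax.samplers import build_sgnht_sampler

# Same data, model, and batch size as above; only the builder and dt change
my_sampler = build_sgnht_sampler(1e-5, loglikelihood, logprior, data, batch_size)
samples = my_sampler(key, Nsamples, theta_true)  # produces non-constant samples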

zfurman56 commented 1 year ago

Bisected and traced the issue to 1a43e5ea391139a52b634f68432e76f88cb2b1c5, which seems to have introduced a typo while refactoring. Submitted a PR to fix.

jeremiecoullon commented 11 months ago

Hi @zfurman56, sorry for never replying to this; I was on holiday at the time and then forgot to get back to it 😓

Thanks very much for this PR; I'll have a look this week!

jeremiecoullon commented 11 months ago

Just merged your PR into master; thanks very much for this fix!