Until we cut a new release, you'll want to pip install blackjax-nightly
to use the latest developments.
To initialize an SGLD kernel one needs to specify a schedule function, which returns a step size at each sampling step, and a gradient estimator function. Here, for a constant step size and data_size data samples:
grad_fn = blackjax.sgmcmc.gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size)
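For reference, logprior_fn and loglikelihood_fn are user-supplied. A minimal sketch with a made-up Gaussian model (the model, the data_size value, and the per-data-point likelihood convention below are illustrative assumptions, not part of BlackJAX itself):

```python
import jax
import jax.numpy as jnp
import blackjax

def logprior_fn(position):
    # Standard normal prior over the parameters (illustrative).
    return jax.scipy.stats.norm.logpdf(position).sum()

def loglikelihood_fn(position, x):
    # Log-likelihood of a single data point `x` (illustrative model);
    # the estimator rescales the minibatch estimate by `data_size`.
    return jax.scipy.stats.norm.logpdf(x, loc=position).sum()

data_size = 1_000  # total number of observations in the full dataset

grad_fn = blackjax.sgmcmc.gradients.grad_estimator(
    logprior_fn, loglikelihood_fn, data_size
)
```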
We can now initialize the SGLD kernel:
sgld = blackjax.sgld(grad_fn)
Assuming we have an iterator batches that yields batches of data, we can perform one step:
step_size = 1e-3
minibatch = next(batches)
new_position = sgld.step(rng_key, position, minibatch, step_size)
Kernels are not jit-compiled by default, so you will need to do it manually:
step = jax.jit(sgld.step)
new_position = step(rng_key, position, minibatch, step_size)
Do this only once, i.e. outside the for loop (see #486).
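Putting the pieces together, a sampling loop might look like the sketch below. The data pipeline, starting position, and number of iterations are placeholder assumptions; only grad_fn from above and the kernel API shown in this thread are reused:

```python
import jax
import jax.numpy as jnp

rng_key = jax.random.PRNGKey(0)
step_size = 1e-3

# Placeholder data pipeline: an infinite iterator over random minibatches.
def make_batches(key, batch_size=32):
    while True:
        key, subkey = jax.random.split(key)
        yield jax.random.normal(subkey, (batch_size,))

batches = make_batches(jax.random.PRNGKey(1))

sgld = blackjax.sgld(grad_fn)  # grad_fn built above
step = jax.jit(sgld.step)      # compile once, outside the loop

position = jnp.zeros(())       # illustrative starting point
samples = []
for _ in range(1_000):
    rng_key, step_key = jax.random.split(rng_key)
    minibatch = next(batches)
    position = step(step_key, position, minibatch, step_size)
    samples.append(position)
```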
You can also initialise the state and use it to iterate:
state = sgld.init(position)
new_state = sgld.step(rng_key, state, minibatch, step_size)
but this is not necessary for SGLD.
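As a sketch, iterating with the state-based API looks the same, just threading the state through the loop (reusing the names from the previous snippet):

```python
step = jax.jit(sgld.step)
state = sgld.init(position)
for _ in range(1_000):
    rng_key, step_key = jax.random.split(rng_key)
    state = step(step_key, state, next(batches), step_size)
```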
@jmsull hope this helps. The MLP classifier example notebook needs updating... great first contribution if you feel like taking care of it :smile:
When running through the code in a local notebook, I get the following error from this code block:
which is complaining about the following line