blackjax-devs / sampling-book

Tutorials and sampling algorithm comparisons
https://blackjax-devs.github.io/sampling-book

MLP Classifier Bug #39

Closed PaulScemama closed 1 year ago

PaulScemama commented 1 year ago

When running through the code in a local notebook, I get the following error from this code block:

from fastprogress.fastprogress import progress_bar

import blackjax
from blackjax.sgmcmc.gradients import grad_estimator

data_size = len(y_train)
batch_size = 512
step_size = 4.5e-5

num_warmup = (data_size // batch_size) * 20
num_samples = 1000

# Batch the data
rng_key = jax.random.PRNGKey(1)
batches = batch_data(rng_key, (X_train, y_train), batch_size, data_size)

# Set the initial state
state = jax.jit(model.init)(rng_key, jnp.ones(X_train.shape[-1]))

# Build the SGLD kernel with a constant learning rate
grad_fn = grad_estimator(logprior_fn, loglikelihood_fn, data_size)
sgld = blackjax.sgld(grad_fn, learning_rate=.01)

# Sample from the posterior
accuracies = []
steps = []

pb = progress_bar(range(num_warmup))
for step in pb:
    _, rng_key = jax.random.split(rng_key)
    batch = next(batches)
    state = jax.jit(sgld)(rng_key, state, batch, step_size)
    if step % 100 == 0:
        accuracy = compute_accuracy(state, X_test, y_test)
        accuracies.append(accuracy)
        steps.append(step)
        pb.comment = f"| error: {100*(1-accuracy): .1f}"
TypeError: Expected a callable value, got SamplingAlgorithm(init=<function ...init_fn at 0x7f038c23ea70>, step=<function ...step_fn at 0x7f038c23d990>)

which is complaining about the following line:

# Set the initial state
state = jax.jit(model.init)(rng_key, jnp.ones(X_train.shape[-1]))
albcab commented 1 year ago

Until we cut a new release, you'll want to pip install blackjax-nightly to use the latest developments.

To initialize an SGLD kernel you need a gradient estimator function; the step size (constant or given by a schedule that returns a value at each sampling step) is passed to the kernel at every step rather than at construction. For data_size data samples:

grad_fn = blackjax.sgmcmc.gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size)
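
For reference, grad_estimator combines a log-prior over the parameters with a minibatch estimate of the log-likelihood. A minimal sketch of the two functions it expects, assuming a flat parameter vector, a linear binary classifier in place of the notebook's MLP, and the convention that the log-likelihood is written for a single data point (all illustrative assumptions, not the notebook's definitions):

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy.stats import norm

def logprior_fn(position):
    # Standard normal prior over all parameters (illustrative choice).
    flat_params, _ = ravel_pytree(position)
    return jnp.sum(norm.logpdf(flat_params))

def loglikelihood_fn(position, data):
    # Bernoulli log-likelihood of a single (x, y) pair for a linear
    # classifier; stands in for the notebook's MLP log-likelihood.
    x, y = data
    logit = x @ position
    return y * jax.nn.log_sigmoid(logit) + (1 - y) * jax.nn.log_sigmoid(-logit)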

We can now initialize the sgld kernel:

sgld = blackjax.sgld(grad_fn)
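
For context, blackjax.sgld returns a SamplingAlgorithm namedtuple rather than a bare callable (this is the object shown in the TypeError above), so it cannot be passed to jax.jit directly; only its init and step fields are callables:

sgld = blackjax.sgld(grad_fn)
print(sgld._fields)  # ('init', 'step')
# jax.jit(sgld) raises the "Expected a callable value" error; jit sgld.step instead.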

Assuming we have an iterator batches that yields minibatches of data, we can perform one step (a sketch of such an iterator follows the snippet below):

step_size = 1e-3
minibatch = next(batches)
new_position = sgld.step(rng_key, position, minibatch, step_size)
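
Such an iterator can be a plain Python generator. A minimal sketch matching the batch_data call in the original snippet (drawing minibatch indices at random with replacement is an assumption, not necessarily what the notebook does):

def batch_data(rng_key, data, batch_size, data_size):
    # Generator that yields random minibatches of (X, y) indefinitely.
    X, y = data
    while True:
        rng_key, idx_key = jax.random.split(rng_key)
        idx = jax.random.choice(idx_key, data_size, shape=(batch_size,))
        yield X[idx], y[idx]

batches = batch_data(jax.random.PRNGKey(0), (X_train, y_train), batch_size, data_size)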

Kernels are not jit-compiled by default so you will need to do it manually:

step = jax.jit(sgld.step)
new_position = step(rng_key, position, minibatch, step_size)

Do this only once, i.e. outside the for loop (see #486).
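
Putting the pieces together, the sampling loop from the notebook would then look roughly like this (a sketch reusing the notebook's names such as model, X_train, batches and num_warmup; the initial position comes from model.init as in the original snippet):

# Compile the transition kernel once, outside the loop.
step = jax.jit(sgld.step)

# Initial position from the model's init function, as in the notebook.
position = model.init(rng_key, jnp.ones(X_train.shape[-1]))

for i in range(num_warmup):
    rng_key, sample_key = jax.random.split(rng_key)
    minibatch = next(batches)
    position = step(sample_key, position, minibatch, step_size)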

You can also initialise the state and use it to iterate:

state = sgld.init(position)
new_state = sgld.step(rng_key, state, minibatch, step_size)

but this is not necessary for SGLD.

@jmsull hope this helps. The MLP classifier example notebook needs updating... great first contribution if you feel like taking care of it :smile: