pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

mean_accept_prob significantly different after warmup #1786

Open jonny-so opened 2 months ago

jonny-so commented 2 months ago

I notice that after warmup, the mean_accept_prob is significantly higher than both the target_accept_prob and the mean_accept_prob observed during warmup, even on a trivial isotropic Gaussian example. Minimum working example:

```python
import jax.numpy as jnp
from jax.lax import scan
from numpyro.infer.hmc import hmc

# Potential energy of a standard isotropic Gaussian: U(x) = 0.5 * ||x||^2
def potential(x):
    return 0.5 * jnp.sum(x**2)

d = 10
nwarmup = 100000
nsamples = 100000

# Plain HMC (not NUTS), adapting only the step size during warmup
init_kernel, sample_kernel = hmc(potential, algo='HMC')
hmc_state = init_kernel(init_params=jnp.zeros(d), num_warmup=nwarmup, adapt_step_size=True, adapt_mass_matrix=False)

# Run the warmup iterations, keeping only the final state
hmc_state = scan(lambda s, _: (sample_kernel(s), None), hmc_state, None, length=nwarmup)[0]
print("post warmup", hmc_state.mean_accept_prob)

# Run the sampling iterations starting from the adapted state
hmc_state = scan(lambda s, _: (sample_kernel(s), None), hmc_state, None, length=nsamples)[0]
print("post samples", hmc_state.mean_accept_prob)
```

Output:

```
post warmup 0.7992318
post samples 0.97852784
```

Am I misusing something here?

fehiepsi commented 2 months ago

In the early phase, I guess the sampler tends to reject many samples, hence the smaller accept_prob compared with the sampling phase. We use dual averaging to adapt the step size and update the step size at the end of the warmup phase: https://github.com/pyro-ppl/numpyro/blob/2f1bccdba2fc7b0a6ec235ca1bd5ce2417a0635c/numpyro/infer/hmc_util.py#L663
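For readers unfamiliar with the scheme, the step-size dual averaging of Hoffman & Gelman (2014) can be sketched in pure Python as below. This is a minimal sketch, not NumPyro's implementation: the constants gamma=0.05, t0=10, kappa=0.75 and the shrinkage target mu = log(10 * step_size0) are the defaults recommended in that paper, and `toy_accept_prob` is a hypothetical deterministic stand-in for the HMC acceptance probability. The thing to notice is that the step size kept after warmup is the exponential of a weighted average of the log step sizes, not an average of the step sizes themselves.

```python
import math

def toy_accept_prob(step_size):
    # Hypothetical stand-in for HMC: acceptance decays as the step size grows.
    return math.exp(-step_size)

def dual_averaging(accept_prob_fn, step_size0=1.0, target=0.8,
                   num_steps=2000, gamma=0.05, t0=10.0, kappa=0.75):
    mu = math.log(10.0 * step_size0)  # shrinkage target for the log step size
    h_bar = 0.0        # running average of (target - accept_prob)
    log_eps_bar = 0.0  # weighted average of log step sizes (eps_bar_0 = 1)
    step_size = step_size0
    for m in range(1, num_steps + 1):
        alpha = accept_prob_fn(step_size)  # acceptance with the current step size
        w = 1.0 / (m + t0)
        h_bar = (1.0 - w) * h_bar + w * (target - alpha)
        log_eps = mu - math.sqrt(m) / gamma * h_bar
        eta = m ** (-kappa)
        log_eps_bar = eta * log_eps + (1.0 - eta) * log_eps_bar
        step_size = math.exp(log_eps)  # step size for the next iteration
    # After warmup, sampling uses the averaged iterate exp(log_eps_bar).
    return math.exp(log_eps_bar)

final_step_size = dual_averaging(toy_accept_prob)
print(final_step_size, toy_accept_prob(final_step_size))
```

With a deterministic acceptance curve the iterates settle and the averaged step size lands near the target acceptance; with the noisy acceptance of a real sampler, the log step sizes keep fluctuating inside the averaging window.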

jonny-so commented 2 months ago

I see that they won't be the same, but the eventual accept rate is almost 100%, suggesting the learned step size is too small. Note that I am targeting the default accept rate of 80%. Could this be the same issue discussed by the Stan developers here? https://github.com/stan-dev/stan/issues/3105

fehiepsi commented 2 months ago

You're right, the step size seems to be small. I'll look into the adaptation dynamics later this week. If you are interested, you can extract more information from the scan body function, such as step_size, accept_prob, etc.
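The suggestion above works because `jax.lax.scan` collects whatever the body function returns as its second output. The pure-Python stand-in below follows the reference semantics given in the `jax.lax.scan` documentation (JAX stacks the outputs into arrays rather than returning a list); `toy_kernel` and its dict-based state are hypothetical placeholders for `sample_kernel` and `HMCState`, used only to show how per-step diagnostics such as the step size and acceptance probability can be recorded.

```python
def scan(f, init, xs, length=None):
    # Pure-Python sketch of jax.lax.scan: thread the carry through f and
    # collect each per-step output (JAX would stack these into arrays).
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

def toy_kernel(state):
    # Hypothetical transition: halves the step size and reports a fake acceptance.
    new_step = state["step_size"] / 2
    return {"i": state["i"] + 1, "step_size": new_step, "accept_prob": 1.0 - new_step}

def body(state, _):
    new_state = toy_kernel(state)
    # Record the step size used for this transition and the resulting acceptance.
    return new_state, (state["step_size"], new_state["accept_prob"])

final_state, trace = scan(body, {"i": 0, "step_size": 1.0, "accept_prob": 0.0}, None, length=3)
step_sizes = [s for s, _ in trace]
accept_probs = [a for _, a in trace]
```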

fehiepsi commented 1 month ago

@jonny-so This turns out to be an issue with the dual averaging algorithm that we use:

```python
import jax.numpy as jnp
from jax.lax import scan
from numpyro.infer.hmc import hmc

# Potential energy of a standard isotropic Gaussian
def potential(x):
    return 0.5 * jnp.sum(x**2)

d = 10
nwarmup = 10000
nsamples = 10000

init_kernel, sample_kernel = hmc(potential, algo='HMC')
hmc_state = init_kernel(init_params=jnp.zeros(d), num_warmup=nwarmup, adapt_step_size=True, adapt_mass_matrix=False)

# Warmup, also recording the step size held in the state before each transition
hmc_state_warmup, step_sizes = scan(lambda s, _: (sample_kernel(s), s.adapt_state.step_size), hmc_state, None, length=nwarmup)
print("post warmup", hmc_state_warmup.mean_accept_prob)

hmc_state = scan(lambda s, _: (sample_kernel(s), None), hmc_state_warmup, None, length=nsamples)[0]
print("post samples", hmc_state.mean_accept_prob)

# Compare the log-space average of the last 50 warmup step sizes with their plain mean
print("exp(mean(log(last_50_step_sizes)))", jnp.exp(jnp.log(step_sizes[-50:]).mean()))
print("mean(last_50_step_sizes)", step_sizes[-50:].mean())
```

```
post warmup 0.7974499
post samples 0.97814786
exp(mean(log(last_50_step_sizes))) 0.8056808
mean(last_50_step_sizes) 1.37691
```

We use dual averaging over the last window buffer (50 steps) of the warmup phase. As a result, the step_size estimate is biased downward: averaging in log space gives exp_mean_log <= mean, by Jensen's inequality. Let me think a bit more about what we can do here. One option is to expose a configuration option for the length of the last window buffer, so that the estimate is better. What do you think?

cc @martinjankowiak, do you have any suggestions for dealing with this issue?
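The bias mentioned above is the gap between a geometric and an arithmetic mean: because log is concave, Jensen's inequality gives exp(mean(log(x))) <= mean(x), with equality only when all values are equal. A minimal pure-Python check on four made-up step sizes (illustrative values, not taken from the run above):

```python
import math

def geometric_mean(xs):
    # exp of the mean of the logs, i.e. what averaging in log space produces
    return math.exp(sum(math.log(x) for x in xs) / len(xs))

def arithmetic_mean(xs):
    return sum(xs) / len(xs)

step_sizes = [0.5, 1.0, 2.0, 4.0]  # illustrative values only
gm = geometric_mean(step_sizes)    # sqrt(2), about 1.414
am = arithmetic_mean(step_sizes)   # 1.875
```

The gap grows with the spread of the log step sizes; for log-normal values the ratio mean / geometric mean is exp(s^2 / 2), where s^2 is the variance of the logs.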

martinjankowiak commented 1 month ago

I'm not sure, but if you wanted to reduce that specific bias, I guess you could use the formula for the mean of a log-normal distribution:

```python
# Log-normal mean: E[step_size] = exp(mean(log step_size) + var(log step_size) / 2)
log_step_sizes = jnp.log(step_sizes[-50:])
jnp.exp(log_step_sizes.mean() + 0.5 * log_step_sizes.var())
```

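The correction relies on the identity that if log(eps) ~ Normal(m, s^2), then E[eps] = exp(m + s^2 / 2); it is exact only under that log-normal assumption. The pure-Python sketch below checks it on synthetic log-normal draws (`random.lognormvariate` with an arbitrary seed and parameters, not data from this issue). The variance uses ddof=0, matching the default of `jnp.var`.

```python
import math
import random

rng = random.Random(0)  # arbitrary seed for reproducibility
mu, sigma = 0.0, 0.5
samples = [rng.lognormvariate(mu, sigma) for _ in range(100_000)]

logs = [math.log(x) for x in samples]
mean_log = sum(logs) / len(logs)
var_log = sum((lv - mean_log) ** 2 for lv in logs) / len(logs)  # ddof=0, like jnp.var

naive = math.exp(mean_log)                      # geometric mean, below E[eps]
corrected = math.exp(mean_log + 0.5 * var_log)  # log-normal mean formula
true_mean = math.exp(mu + 0.5 * sigma ** 2)     # exact E[eps], about 1.133
```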
fehiepsi commented 1 month ago

It looks like the implementation agrees with Algorithm 5 in https://jmlr.org/papers/volume15/hoffman14a/hoffman14a.pdf#page=18.62. I guess it is better to let users control the size of the last window.
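If the final window were made configurable, a warmup schedule could be built along the lines below. This is a hypothetical sketch: `build_schedule` and its `init_buffer`, `term_buffer` and `base_window` parameters are illustrative names, not NumPyro's API. The defaults (75, 50, 25) follow Stan's windowed adaptation: an initial fast buffer, then slow windows that double in size, then a final buffer of `term_buffer` iterations in which only the step size is adapted.

```python
def build_schedule(num_steps, init_buffer=75, term_buffer=50, base_window=25):
    """Hypothetical Stan-style warmup schedule with a configurable final buffer.

    Returns a list of inclusive (start, end) iteration windows.
    """
    if num_steps < 20:
        # Too short to split into phases: a single window covering everything.
        return [(0, num_steps - 1)]
    if init_buffer + base_window + term_buffer > num_steps:
        # Fall back to a 15% / 75% / 10% split, as Stan does for short warmups.
        init_buffer = int(0.15 * num_steps)
        term_buffer = int(0.1 * num_steps)
        base_window = num_steps - init_buffer - term_buffer

    schedule = [(0, init_buffer - 1)]   # initial fast buffer
    end_slow = num_steps - term_buffer  # first iteration of the final buffer
    start, size = init_buffer, base_window
    while start < end_slow:
        if 3 * size <= end_slow - start:
            next_size = 2 * size        # room remains for a doubled window
        else:
            size = end_slow - start     # stretch the last slow window to the end
            next_size = size
        schedule.append((start, start + size - 1))
        start, size = start + size, next_size
    schedule.append((end_slow, num_steps - 1))  # final buffer: step size only
    return schedule

print(build_schedule(1000))
print(build_schedule(1000, term_buffer=200)[-1])
```

A larger `term_buffer` gives the dual averaging in the final window more iterations before its averaged step size is frozen, at the cost of a shorter last slow window.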

jonny-so commented 1 month ago

Sorry for the delay, I've been flat out with the NeurIPS deadline. I need to think about this a bit, but I'm taking a week off to recover... I'll come back to you soon.

jonny-so commented 2 weeks ago

Increasing the fixed window size beyond 50 does indeed seem to resolve the issue; some comments on the Stan ticket I linked suggest they have observed the same. Exposing it as an option would be fine, although I am curious why fixing it to such a small value hasn't been a problem before.