Status: Open. jonny-so opened this issue 7 months ago.
In the early phase, I guess the sampler tends to reject many samples, so the accept_prob during warmup is smaller than in the sampling phase. We use dual averaging to adapt the step size and update the step size at the end of the warm-up phase: https://github.com/pyro-ppl/numpyro/blob/2f1bccdba2fc7b0a6ec235ca1bd5ce2417a0635c/numpyro/infer/hmc_util.py#L663
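For reference, here is a minimal pure-Python sketch of the dual averaging update from Hoffman & Gelman (2014), Algorithm 5, with the constants suggested in that paper (gamma=0.05, t0=10, kappa=0.75, mu=log(10 * initial step size)). It is not NumPyro's code. The acceptance curve `toy_accept_prob` is a hypothetical, deterministic stand-in for a real sampler. The point is that warmup samples with the noisy iterate `log_eps`, while the step size kept after warmup is the averaged `log_eps_bar`, so accept rates before and after warmup need not match.

```python
import math
from dataclasses import dataclass


@dataclass
class DAState:
    mu: float                 # shrinkage target for the log step size
    log_eps: float = 0.0      # current iterate, used while warming up
    log_eps_bar: float = 0.0  # weighted average of iterates, used after warmup
    h_bar: float = 0.0        # running average of (target - accept_prob)
    t: int = 0                # number of updates so far


def da_init(step_size: float) -> DAState:
    return DAState(mu=math.log(10.0 * step_size), log_eps=math.log(step_size))


def da_update(s: DAState, accept_prob: float, target: float = 0.8,
              gamma: float = 0.05, t0: float = 10.0, kappa: float = 0.75) -> DAState:
    t = s.t + 1
    h_bar = (1.0 - 1.0 / (t + t0)) * s.h_bar + (target - accept_prob) / (t + t0)
    log_eps = s.mu - math.sqrt(t) / gamma * h_bar
    eta = t ** (-kappa)  # weight on the newest iterate shrinks over time
    log_eps_bar = eta * log_eps + (1.0 - eta) * s.log_eps_bar
    return DAState(s.mu, log_eps, log_eps_bar, h_bar, t)


def toy_accept_prob(step_size: float) -> float:
    # Hypothetical acceptance curve: larger steps are accepted less often.
    return math.exp(-step_size ** 2)


state = da_init(1.0)
for _ in range(2000):
    state = da_update(state, toy_accept_prob(math.exp(state.log_eps)))

final_step_size = math.exp(state.log_eps_bar)  # step size kept after warmup
print(final_step_size)
```

With this deterministic toy curve the averaged step size should settle close to sqrt(-log 0.8) ≈ 0.47, where the toy acceptance equals the 0.8 target; with a real sampler the acceptance is noisy, which is where the averaging bias discussed below comes in.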
I see that they won't be the same, but the eventual accept rate is almost 100%, suggesting the learned step size is too small. Note that I am targeting the default accept rate of 80%. Could this be the same issue discussed by the Stan developers here? https://github.com/stan-dev/stan/issues/3105
You're right, the step size seems to be small. I'll look into the adaptation dynamics later this week. If you are interested, you can extract more information from the `scan` body function, such as `step_size`, `accept_prob`, etc.
@jonny-so This turns out to be an issue with the dual averaging algorithm that we use:
```python
import jax.numpy as jnp
from jax.lax import scan
from numpyro.infer.hmc import hmc


def potential(x):
    # Potential energy of a standard isotropic Gaussian
    return 0.5 * jnp.sum(x**2)


d = 10
nwarmup = 10000
nsamples = 10000

init_kernel, sample_kernel = hmc(potential, algo='HMC')
# Adapt only the step size; keep the identity mass matrix
hmc_state = init_kernel(init_params=jnp.zeros(d), num_warmup=nwarmup, adapt_step_size=True, adapt_mass_matrix=False)
# Warmup: also record the step size used at every iteration
hmc_state_warmup, step_sizes = scan(lambda s, _: (sample_kernel(s), s.adapt_state.step_size), hmc_state, None, length=nwarmup)
print("post warmup", hmc_state_warmup.mean_accept_prob)
# Sampling: adaptation has stopped, so the step size is fixed
hmc_state = scan(lambda s, _: (sample_kernel(s), None), hmc_state_warmup, None, length=nsamples)[0]
print("post samples", hmc_state.mean_accept_prob)
print("exp(mean(log(last_50_step_sizes)))", jnp.exp(jnp.log(step_sizes[-50:]).mean()))
print("mean(last_50_step_sizes)", step_sizes[-50:].mean())
```
```
post warmup 0.7974499
post samples 0.97814786
exp(mean(log(last_50_step_sizes))) 0.8056808
mean(last_50_step_sizes) 1.37691
```
We use dual averaging over the last window buffer (50 steps) of the warmup phase. Because dual averaging averages the log step sizes, the resulting estimate of `step_size` is biased low relative to the plain mean (`exp(mean(log(x))) <= mean(x)`, by Jensen's inequality). Let me think a bit more about what we can do here. One option is to expose a configuration for the length of the last window buffer, so that the estimate improves. What do you think?
cc @martinjankowiak, do you have any suggestions for dealing with this issue?
I'm not sure, but if you wanted to reduce that specific bias, I guess you could use the formula for the mean of a log-normal distribution:
```python
log_step_sizes = jnp.log(step_sizes[-50:])
jnp.exp(log_step_sizes.mean() + 0.5 * log_step_sizes.var())
```
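To see what that correction does, here is a small NumPy sketch on synthetic data. The step sizes are hypothetical draws that are exactly log-normal (log-mean 0, log-sd 0.5), not values from the run above. If log X ~ N(m, s^2), then E[X] = exp(m + s^2 / 2), so adding half the log-variance recovers the arithmetic mean that the plain `exp(mean(log(x)))` underestimates. The correction is only exact under the log-normal assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 0.5
# Hypothetical step sizes: exactly log-normal with log-mean 0 and log-sd sigma.
step_sizes = np.exp(rng.normal(0.0, sigma, size=100_000))

log_s = np.log(step_sizes)
geometric = np.exp(log_s.mean())                            # plain average of logs, biased low
arithmetic = step_sizes.mean()                              # what we would like to match
lognormal_mean = np.exp(log_s.mean() + 0.5 * log_s.var())   # suggested correction

print(geometric, arithmetic, lognormal_mean)
```

Here `geometric` lands near 1.0, while `arithmetic` and `lognormal_mean` both land near exp(0.125) ≈ 1.13. Whether the arithmetic mean is the right target for the final step size is a separate question.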
It looks like the implementation agrees with Algorithm 5 in Hoffman & Gelman (https://jmlr.org/papers/volume15/hoffman14a/hoffman14a.pdf#page=18.62). I guess it is better to let users control the size of the last window.
Sorry for the delay, I've been flat out with the NeurIPS deadline. I need to think about this a bit, but I'm taking a week off to recover; I'll get back to you soon.
Increasing the fixed window size beyond 50 does indeed seem to resolve the issue; some comments on the Stan ticket I linked suggest they have observed the same. Exposing it as an option would be fine, although I am curious why fixing it at such a small value hasn't been a problem before.
I spoke with some of the Stan developers recently; they said this has come up a number of times before and that the default should probably just be bigger. It would make a little more sense to me to have it proportional to the length of the warmup period, or something like that.
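A sketch of what a configurable final buffer could look like, modeled on Stan's windowed warmup (initial buffer of 75 steps, a first adaptation window of 25 that doubles each time, and a final buffer of 50 where only the step size is adapted). `build_schedule` and `proportional_end_buffer` are hypothetical names, not NumPyro API; the 10% fraction is an arbitrary illustrative choice, and unlike Stan this sketch simply raises an error instead of shrinking the buffers when warmup is too short.

```python
from typing import List, Tuple


def build_schedule(
    num_steps: int,
    end_buffer_size: int = 50,
    start_buffer_size: int = 75,
    init_window_size: int = 25,
) -> List[Tuple[int, int]]:
    """Split warmup into inclusive (start, end) windows: an initial buffer,
    doubling mass-matrix windows, and a final step-size-only buffer."""
    if start_buffer_size + init_window_size + end_buffer_size > num_steps:
        raise ValueError("warmup is too short for the requested buffers")
    schedule = [(0, start_buffer_size - 1)]
    end_window_start = num_steps - end_buffer_size
    start, size = start_buffer_size, init_window_size
    while start < end_window_start:
        # If the next window (twice as long) would not fit after this one,
        # stretch this window up to the start of the final buffer.
        if 3 * size > end_window_start - start:
            size = end_window_start - start
        schedule.append((start, start + size - 1))
        start, size = start + size, 2 * size
    schedule.append((end_window_start, num_steps - 1))
    return schedule


def proportional_end_buffer(num_steps: int, fraction: float = 0.1, minimum: int = 50) -> int:
    """Hypothetical policy: final buffer scales with warmup length, never below `minimum`."""
    return max(minimum, int(fraction * num_steps))


print(build_schedule(10000)[-1])  # (9950, 9999)
print(build_schedule(10000, proportional_end_buffer(10000))[-1])  # (9000, 9999)
```

With 10000 warmup steps the fixed default leaves only 50 steps for the final step-size estimate, while the proportional policy leaves 1000; short warmups still fall back to the minimum of 50.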
Hi @jonny-so, I think we can expose this configuration in the HMC/NUTS constructors. Do you want to make a PR for this?
I notice that after warmup, the `mean_accept_prob` is significantly higher than both `target_accept_prob` and the `mean_accept_prob` observed during warmup, even on a trivial isotropic Gaussian example. Minimum working example:

outputs:

Am I misusing something here?