Open nhuurre opened 3 years ago
Thanks, @nhuurre. I thought our code already did this! I imagine you're right about why it's coded this way for simplicity and to keep transitions stateless. Transitions should be stateless and only depend on the previous state, but it's also OK to store functions of that previous state. It's not OK to store things like moving averages and condition jumps on that.
Given that gradients are the bottleneck for the whole algorithm, and this requires (L + 1) gradient evaluations rather than L, the speedup should depend on the average number of gradient evaluations L per iteration. In really simple models, the other bits of the code take up a larger share of run time relative to gradients. So the optimal speedup would be in a problem with easy geometry (low L) and expensive gradients.
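As a rough back-of-the-envelope version of that argument, suppose each gradient costs $c_g$ and everything else in an iteration costs $c_o$. Both symbols are assumptions introduced here for illustration, not quantities measured in Stan. The expected speedup from dropping the extra gradient is then approximately:

$$
\text{speedup} \approx \frac{(L + 1)\,c_g + c_o}{L\,c_g + c_o}
\;\xrightarrow{\;c_o \ll c_g\;}\; \frac{L + 1}{L} = 1 + \frac{1}{L}
$$

Under this simplified model, a chain averaging $L = 7$ leapfrog steps could gain at most about 14%, while one averaging $L = 1023$ would gain about 0.1%. When $c_o$ is comparable to $c_g$, as in very cheap models, the gain shrinks further.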
Some history: when the Markov chain Monte Carlo code was designed so many years ago we had no idea what scope of Markov transitions we might implement down the road. This included potential improvements or variants of the original No-U-Turn and Hybrid Monte Carlo samplers that had been implemented for continuous spaces but also potential samplers for discrete spaces. In order to make the code as flexible as possible for these potential samplers the Markov transitions were abstracted as stateless transformations, with the Markov chain Monte Carlo code working with only a generic sample object. That way Markov transitions could incorporate whatever auxiliary information they needed into their internal states. Moreover this allowed for composite samplers where different transitions were applied to one common state. This is also why the CmdStan argument structure is so flexible; it had to accommodate all of the possible samplers that might have been incorporated in what was then the near future.
With the promise of other Markov transitions never being realized this generic functionality hasn't been used. At some point the very generic code that called all of the transitions was organized into the existing services library that forces iteration with a single Markov transition such that the external Markov chain Monte Carlo sample and internal Markov transition state are always aligned.
Given the current state of the code, removing the `this->hamiltonian_.init(this->z_, logger)` call as suggested would be valid, but technically there's nothing preventing new API routes from calling `transition` in different ways that would cause the external sample and internal state to drift. We could change the `transition` signature, but that would require touching lots of other code, and I think it would be more productive to consider a full redesign of the Markov chain Monte Carlo code around the current dynamic Hamiltonian Monte Carlo sampler instead of the full generality.
I think a safer and more precise change would be to add `if (z.q == this->z_.q) return;` to the `base_hamiltonian::init` base method and its overrides in derived classes so that they no-op if the updated state does indeed match the internal state.
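A minimal sketch of that guard, using toy stand-ins rather than Stan's real classes. `ToyPoint`, `ToyHamiltonian`, `ToySampler`, the `initialized` flag, and the gradient counter are all hypothetical names introduced for illustration, and placing the check in the sampler's own `init` is an assumption made to keep the example self-contained:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a phase-space point (not Stan's ps_point).
struct ToyPoint {
  std::vector<double> q;     // position
  std::vector<double> g;     // cached gradient of the potential at q
  bool initialized = false;  // true once g has been computed for q
};

// Hypothetical Hamiltonian with potential V(q) = 0.5 * q.q, so grad V = q.
struct ToyHamiltonian {
  std::size_t n_grad_evals = 0;  // counts the "expensive" gradient calls

  void update_gradient(ToyPoint& z) {
    z.g = z.q;
    ++n_grad_evals;
  }
};

struct ToySampler {
  ToyHamiltonian hamiltonian_;
  ToyPoint z_;  // internal state kept between transitions

  // Guarded init: skip the gradient when the incoming position already
  // matches the internal state, otherwise resynchronize and recompute.
  void init(const std::vector<double>& q) {
    if (z_.initialized && q == z_.q) return;
    z_.q = q;
    hamiltonian_.update_gradient(z_);
    z_.initialized = true;
  }
};

// Runs three init calls and reports how many gradients were evaluated.
std::size_t count_grad_evals_demo() {
  ToySampler s;
  s.init({1.0, 2.0});  // first call: gradient is computed
  s.init({1.0, 2.0});  // same position: guard makes this a no-op
  s.init({0.5, 2.0});  // external sample drifted: gradient recomputed
  return s.hamiltonian_.n_grad_evals;
}
```

In this toy version the guard costs one vector comparison per call, which is cheap next to a gradient, and it still recovers correctly if a caller passes in a sample that does not match the internal state.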
I disagree with that; Markov chains are naturally stateful and the sampler API should not hide it.
Markov chains are stateful, but Markov transitions are not. The API design separates Markov chains from Markov transitions, and the state in the Markov transition implementations is technically internal with no guarantees.
(Incidentally: why does it call `hamiltonian_.sample_p()` before calling `hamiltonian_.init()`? In principle a new point does not have its metric set until `init()` and sampling a momentum needs the metric. This doesn't really matter for Euclidean metrics, which are constant, but I'd expect it to break the softabs metric. I couldn't test that theory because I don't know how to expose a sampler with softabs in the services.)
It works for the same reason that removing `this->hamiltonian_.init(this->z_, logger)` works; if the `init_sample` argument is just the output of the previous `transition` call then all of the relevant information is already in `this->z_` so that calling `this->hamiltonian_.init(this->z_, logger)` again is just an expensive no-op.
Technically this is a bug because `init_sample` isn't guaranteed to be the output of the previous `transition` call, it just happens to be everywhere we call `transition` in the current API. That's also why it was never caught in empirical testing.
If we're making any changes we should swap those two lines as well.
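To illustrate why the order matters once the incoming position really is new, here is a small sketch with a toy position-dependent metric. Everything in it is hypothetical: `ToyPoint`, `ToyHamiltonian`, and the metric M(q) = 1 + q^2 are made up for the example and are not the softabs metric or Stan's actual classes.

```cpp
#include <cmath>
#include <random>

// Hypothetical 1-D phase-space point that caches its metric.
struct ToyPoint {
  double q = 0.0;       // position
  double p = 0.0;       // momentum
  double metric = 1.0;  // cached metric, valid for q = 0 at construction
};

// Hypothetical Hamiltonian with a made-up metric M(q) = 1 + q^2.
struct ToyHamiltonian {
  // Refresh the cached, position-dependent quantities at z.q.
  void init(ToyPoint& z) const { z.metric = 1.0 + z.q * z.q; }

  // Draw p ~ N(0, M) using whatever metric is currently cached in z.
  // Returns the standard deviation that was actually used.
  template <class RNG>
  double sample_p(ToyPoint& z, RNG& rng) const {
    std::normal_distribution<double> std_normal(0.0, 1.0);
    const double sd = std::sqrt(z.metric);
    z.p = sd * std_normal(rng);
    return sd;
  }
};

// Seed a new position, run the two steps in the requested order, and
// report the standard deviation used to sample the momentum.
double momentum_sd(double q_new, bool init_first) {
  ToyHamiltonian h;
  ToyPoint z;  // cached metric still corresponds to q = 0
  std::mt19937 rng(1);
  z.q = q_new;
  double sd = 0.0;
  if (init_first) {
    h.init(z);  // metric refreshed for q_new before sampling
    sd = h.sample_p(z, rng);
  } else {
    sd = h.sample_p(z, rng);  // uses the stale metric from the old point
    h.init(z);
  }
  return sd;
}
```

With `q_new = 3.0`, calling `init` first samples with standard deviation `sqrt(10)`, while sampling first uses the stale value `1`. When the incoming position equals the internal state, both orders agree, which matches the explanation above for why the current order has not caused visible problems.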
Summary:
The number of log prob gradient evaluations per sample is one greater than the reported `n_leapfrog` for that sample. This does not need to be so; the gradient for the starting point was calculated in the previous iteration and has been saved by the sampler.

Description:
Every transition in the sampler begins with initializing `this->z_` and sampling the momentum https://github.com/stan-dev/stan/blob/274a93c2d670edea4918bb80b1f278f55a9de654/src/stan/mcmc/hmc/nuts/base_nuts.hpp#L82-L85 and ends with copying the selected sample to `this->z_` https://github.com/stan-dev/stan/blob/274a93c2d670edea4918bb80b1f278f55a9de654/src/stan/mcmc/hmc/nuts/base_nuts.hpp#L201 But `stan::services::util::generate_transitions()` (the only function that calls `sampler.transition()`) just passes around the sample it got from `sampler.transition()`. What's the point of copying `z_` back and forth and re-initializing the gradient when `this->z_` already has the correct gradient from the previous iteration?

(Incidentally: why does it call `hamiltonian_.sample_p()` before calling `hamiltonian_.init()`? In principle a new point does not have its metric set until `init()` and sampling a momentum needs the metric. This doesn't really matter for Euclidean metrics, which are constant, but I'd expect it to break the `softabs` metric. I couldn't test that theory because I don't know how to expose a sampler with `softabs` in the services.)

Simply deleting the redundant `this->hamiltonian_.init(this->z_, logger);` causes no change in output and, for example, a 3D Gaussian model runs a couple percent faster. That difference is barely distinguishable from noise and would be expected to be even less for a more complex model. As far as performance is concerned this doesn't matter much. I guess the current design was chosen because it leads to a (marginally simpler?) API that treats the sampler as stateless. I disagree with that; Markov chains are naturally stateful and the sampler API should not hide it. For example, split the `transition()` method into three phases. Then `generate_transitions()` does not need to call `set_sample()`.

Current Version:
v2.28.1