stan-dev / stan

Stan development repository. The master branch contains the current release. The develop branch contains the latest stable development. See the Developer Process Wiki for details.
https://mc-stan.org
BSD 3-Clause "New" or "Revised" License

One redundant gradient evaluation on every iteration #3077

Open nhuurre opened 3 years ago

nhuurre commented 3 years ago

Summary:

The number of log prob gradient evaluations per sample is one greater than the reported n_leapfrog for that sample. This does not need to be so; the gradient for the starting point was calculated in the previous iteration and has been saved by the sampler.
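The count described above can be reproduced in a toy setting. The following is a minimal sketch, not Stan code: it assumes a fixed-length leapfrog trajectory on a 1-D standard normal instead of NUTS, and the names CountingModel, PhasePoint, and gradients_per_transition are hypothetical. It only illustrates that recomputing the starting gradient costs n_leapfrog + 1 evaluations, while reusing the cached one costs n_leapfrog.

```cpp
#include <cassert>
#include <cstddef>

// Toy 1-D standard normal target that counts its gradient evaluations.
// (Hypothetical stand-in for a model; not part of Stan.)
struct CountingModel {
  std::size_t n_grad = 0;
  double grad_log_prob(double q) {
    ++n_grad;
    return -q;  // d/dq of -q^2 / 2
  }
};

// Position, momentum, and the log-density gradient cached at the position.
struct PhasePoint {
  double q = 0.0;
  double p = 0.0;
  double g = 0.0;
};

// Runs one fixed-length trajectory and returns how many gradient evaluations
// it cost. `reinit_gradient` mimics recomputing the gradient at the start.
std::size_t gradients_per_transition(int n_leapfrog, bool reinit_gradient) {
  CountingModel model;
  PhasePoint z;
  z.q = 1.0;
  z.g = model.grad_log_prob(z.q);  // left over from the "previous" transition
  model.n_grad = 0;                // count only the work of this transition

  if (reinit_gradient) {
    z.g = model.grad_log_prob(z.q);  // redundant: z.g is already correct
  }

  z.p = 0.5;  // fixed momentum keeps the example deterministic
  const double eps = 0.1;
  for (int i = 0; i < n_leapfrog; ++i) {
    z.p += 0.5 * eps * z.g;          // half step in momentum
    z.q += eps * z.p;                // full step in position
    z.g = model.grad_log_prob(z.q);  // the one new gradient per leapfrog step
    z.p += 0.5 * eps * z.g;          // half step in momentum
  }
  return model.n_grad;
}
```

With these assumptions, calling gradients_per_transition(7, true) counts eight evaluations and gradients_per_transition(7, false) counts seven.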

Description:

Every transition in the sampler begins by initializing this->z_ and sampling the momentum
https://github.com/stan-dev/stan/blob/274a93c2d670edea4918bb80b1f278f55a9de654/src/stan/mcmc/hmc/nuts/base_nuts.hpp#L82-L85
and ends by copying the selected sample to this->z_
https://github.com/stan-dev/stan/blob/274a93c2d670edea4918bb80b1f278f55a9de654/src/stan/mcmc/hmc/nuts/base_nuts.hpp#L201
But stan::services::util::generate_transitions() (the only function that calls sampler.transition()) just passes around the sample it got from sampler.transition(). What's the point of copying z_ back and forth and re-initializing the gradient when this->z_ already has the correct gradient from the previous iteration?

(Incidentally: why does it call hamiltonian_.sample_p() before calling hamiltonian_.init()? In principle a new point does not have its metric set until init() and sampling a momentum needs the metric. This doesn't really matter for Euclidean metrics, which are constant, but I'd expect it to break the softabs metric. I couldn't test that theory because I don't know how to expose a sampler with softabs in the services.)

Simply deleting the redundant this->hamiltonian_.init(this->z_, logger); causes no change in output and, for example, a 3d gaussian model runs a couple percent faster. That difference is barely distinguishable from noise and would be expected to be even smaller for a more complex model, so as far as performance is concerned this doesn't matter much.

I guess the current design was chosen because it leads to a (marginally simpler?) API that treats the sampler as stateless. I disagree with that; Markov chains are naturally stateful and the sampler API should not hide it. For example, the transition() method could be split into three phases:

// new stateful API
virtual void set_sample(sample&, logger&) = 0;
virtual void update_state(logger&) = 0;
virtual sample get_sample() = 0;

// old, stateless API implemented with the stateful API
sample transition(sample& s, logger& logger) {
  this->set_sample(s, logger);
  this->update_state(logger);
  return this->get_sample();
}

Then generate_transitions() does not need to call set_sample().
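As a rough illustration of how a driver loop might use the three-phase API, here is a self-contained sketch. Everything in it is hypothetical: Sample and Logger are stand-ins for the real types, CountingSampler is a toy that only counts calls, and run_chain is not the actual generate_transitions(). Seeding the state once before the loop is an assumption of this sketch.

```cpp
#include <cassert>

// Hypothetical stand-ins for the real sample and logger types.
struct Sample { double q = 0.0; };
struct Logger {};

class StatefulSampler {
 public:
  virtual ~StatefulSampler() = default;
  virtual void set_sample(Sample& s, Logger& logger) = 0;
  virtual void update_state(Logger& logger) = 0;
  virtual Sample get_sample() = 0;

  // Old-style entry point, built from the three phases.
  Sample transition(Sample& s, Logger& logger) {
    set_sample(s, logger);
    update_state(logger);
    return get_sample();
  }
};

// Toy sampler that records how often each phase runs.
class CountingSampler : public StatefulSampler {
 public:
  int n_set = 0;
  int n_update = 0;
  void set_sample(Sample& s, Logger&) override { z_ = s; ++n_set; }
  void update_state(Logger&) override { z_.q += 1.0; ++n_update; }
  Sample get_sample() override { return z_; }

 private:
  Sample z_;  // internal state, kept between calls
};

// Sketch of a driver loop: seed the state once, then only advance and read.
Sample run_chain(StatefulSampler& sampler, Sample init, int num_iterations,
                 Logger& logger) {
  sampler.set_sample(init, logger);
  Sample current = init;
  for (int i = 0; i < num_iterations; ++i) {
    sampler.update_state(logger);
    current = sampler.get_sample();
  }
  return current;
}
```

In this toy, five iterations trigger one set_sample() call and five update_state() calls, whereas calling transition() five times would trigger five of each.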

Current Version:

v2.28.1

bob-carpenter commented 3 years ago

Thanks, @nhuurre. I thought our code already did this! I imagine you're right about why it's coded this way for simplicity and to keep transitions stateless. Transitions should be stateless and only depend on the previous state, but it's also OK to store functions of that previous state. It's not OK to store things like moving averages and condition jumps on that.

Given that gradients are the bottleneck for the whole algorithm, and the current code requires (L + 1) gradient evaluations rather than L, the speedup should depend on the average number of gradient evaluations L per iteration. In really simple models, the rest of the code takes up a larger share of run time than the gradients do. So the optimal speedup would be in a problem with easy geometry (low L) and expensive gradients.
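To make the dependence on L concrete, here is a back-of-the-envelope bound. It assumes a fixed L per iteration, a cost c_g per gradient evaluation, and a cost c_o for everything else in the iteration; c_g and c_o are symbols introduced only for this sketch.

```latex
\[
\text{speedup}
  = \frac{(L + 1)\,c_g + c_o}{L\,c_g + c_o}
  \le \frac{L + 1}{L}
  = 1 + \frac{1}{L}
\]
```

The bound is attained only when c_o = 0, which is the "expensive gradients" limit. For L = 7 it allows at most 8/7, roughly a 14% speedup; for L = 63 it allows at most about 1.6%.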

betanalpha commented 3 years ago

Some history: when the Markov chain Monte Carlo code was designed so many years ago, we had no idea what scope of Markov transitions we might implement down the road. This included potential improvements or variants of the original No-U-Turn and Hybrid Monte Carlo samplers that had been implemented for continuous spaces but also potential samplers for discrete spaces. In order to make the code as flexible as possible for these potential samplers the Markov transitions were abstracted as stateless transformations, with the Markov chain Monte Carlo code working with only a generic sample object. That way Markov transitions could incorporate whatever auxiliary information they needed into their internal states. Moreover, this allowed for composite samplers where different transitions were applied to one common state. This is also why the CmdStan argument structure is so flexible; it had to accommodate all of the possible samplers that might have been incorporated in what was then the near future.

With the promise of other Markov transitions never being realized this generic functionality hasn't been used. At some point the very generic code that called all of the transitions was organized into the existing services library that forces iteration with a single Markov transition such that the external Markov chain Monte Carlo sample and internal Markov transition state are always aligned.

Given the current state of the code removing the suggested this->hamiltonian_.init(this->z_, logger) call would be valid, but technically there's nothing preventing new API routes from calling transition in different ways that would cause the external sample and internal state to drift. We could change the transition signature, but that would require touching lots of other code and I think it would be more productive to consider a full redesign of the Markov chain Monte Carlo code around the current dynamic Hamiltonian Monte Carlo sampler instead of the full generality.

I think a safer and more precise change would be to add

if (z.q == this->z_.q) return;

to the base_hamiltonian::init base method and its overrides in derived classes, so that they no-op if the updated state does indeed match the internal state.
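A toy version of such a guard might look like the sketch below. It is not Stan's class hierarchy: GuardedSampler, Point, and the gradient of a standard normal are all hypothetical, and where the comparison would actually live in Stan is left open here. The sketch compares the incoming position against the stored one before overwriting it, and recomputes the gradient only on a mismatch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical phase-space point: position and cached gradient.
struct Point {
  std::vector<double> q;  // position
  std::vector<double> g;  // gradient of the log density at q
};

class GuardedSampler {
 public:
  int n_grad = 0;  // number of gradient evaluations performed so far

  // Re-initializes the internal state at q, skipping the gradient when the
  // stored state already sits at exactly the same position.
  void init(const std::vector<double>& q) {
    if (initialized_ && q == z_.q) {
      return;  // cached gradient is still valid: cheap no-op
    }
    z_.q = q;
    z_.g = gradient(q);
    initialized_ = true;
  }

  const Point& state() const { return z_; }

 private:
  // Gradient of a standard normal log density; counts each evaluation.
  std::vector<double> gradient(const std::vector<double>& q) {
    ++n_grad;
    std::vector<double> g(q.size());
    for (std::size_t i = 0; i < q.size(); ++i) {
      g[i] = -q[i];
    }
    return g;
  }

  Point z_;
  bool initialized_ = false;
};
```

Exact element-wise equality is intended here, since the incoming position is a copy of the stored one when the samples are aligned. Any drift between external sample and internal state, however small, falls through to the full recomputation.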

> I disagree with that; Markov chains are naturally stateful and the sampler API should not hide it.

Markov chains are stateful, but Markov transitions are not. The API design separates Markov chains from Markov transitions, and the state in the Markov transition implementations is technically internal with no guarantees.

> (Incidentally: why does it call hamiltonian_.sample_p() before calling hamiltonian_.init()? In principle a new point does not have its metric set until init() and sampling a momentum needs the metric. This doesn't really matter for Euclidean metrics, which are constant, but I'd expect it to break the softabs metric. I couldn't test that theory because I don't know how to expose a sampler with softabs in the services.)

It works for the same reason that removing this->hamiltonian_.init(this->z_, logger) works; if the init_sample argument is just the output of the previous transition call then all of the relevant information is already in this->z_ so that calling this->hamiltonian_.init(this->z_, logger) again is just an expensive no-op.

Technically this is a bug because init_sample isn't guaranteed to be the output of the previous transition call; it just happens to be everywhere we call transition in the current API. That's also why it was never caught in empirical testing.

If we're making any changes we should swap those two lines as well.
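The ordering issue can be seen in a toy model with a position-dependent metric. The sketch below is hypothetical throughout: the metric M(q) = 1 + 3q^2 is an arbitrary choice, not softabs, and sample_p takes a fixed standard normal draw xi instead of a random number generator so that the result is deterministic. It assumes the momentum is drawn with variance M(q), so p = sqrt(M(q)) * xi.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical phase-space point with a cached, position-dependent metric.
struct Point {
  double q = 0.0;
  double metric = 1.0;
  double p = 0.0;
};

struct ToyHamiltonian {
  // Refreshes the cached metric for the current position.
  void init(Point& z) const { z.metric = 1.0 + 3.0 * z.q * z.q; }

  // Scales a standard normal draw xi by the cached metric.
  void sample_p(Point& z, double xi) const {
    z.p = std::sqrt(z.metric) * xi;
  }
};

// Moves the point to new_q from outside, then draws a momentum using either
// call order, and returns the momentum that results.
double momentum_after_move(double new_q, double xi, bool init_first) {
  ToyHamiltonian h;
  Point z;
  h.init(z);    // state at q = 0, so the cached metric is 1
  z.q = new_q;  // externally supplied position; cached metric is now stale
  if (init_first) {
    h.init(z);
    h.sample_p(z, xi);  // uses the metric at new_q
  } else {
    h.sample_p(z, xi);  // uses the stale metric from q = 0
    h.init(z);
  }
  return z.p;
}
```

With new_q = 1 and xi = 1, calling init first gives p = 2 (metric 4), while sampling the momentum first gives p = 1 (stale metric 1). With a constant metric the two orders agree, which is consistent with the problem going unnoticed for Euclidean metrics.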