pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

Question/Feature request: Figuring out how often the model/gradient was evaluated by NUTS #5809

Open michaelosthege opened 2 years ago

michaelosthege commented 2 years ago

This question came up during office hours:

Which sampler stat should be used to answer how often a model was evaluated during MCMC?

@junpenglao @ColCarroll

For example, these variables are tracked by NUTS:


If we're not currently tracking how often the model was evaluated, we should definitely add that stat!

ColCarroll commented 2 years ago

That's a good question! I haven't double-checked, but I think you get sum(2 ** tree_depth) logp and grad(logp) evaluations. This is only an approximation, I think because:

  1. (most of the inaccuracy): PyMC will short-circuit when the U-turn criterion (or a divergence) is hit, so it could be as low as max(2 ** (tree_depth - 1), n_steps)
  2. (a little inaccuracy, maybe): There might be an extra logp evaluation for the Metropolis acceptance step at the end; I'm not sure whether that value is reused from the leapfrog integrator.
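The estimate above can be turned into a rough lower/upper bound with a few lines of NumPy. This is a minimal sketch of the back-of-the-envelope reasoning in this comment, not a PyMC function. It assumes per-draw arrays of tree depths and leapfrog step counts are available (with PyMC + ArviZ these are usually exposed as `tree_depth` and `n_steps` in `sample_stats`, but treat those names as an assumption). It ignores any extra evaluation for the final acceptance step.

```python
import numpy as np


def estimate_nuts_evaluations(tree_depth, n_steps):
    """Rough (lower, upper) bounds on logp + gradient evaluations made by NUTS.

    Sketch of the heuristic discussed above, not a PyMC API:
    - upper: every tree is fully built, i.e. 2 ** tree_depth per draw;
    - lower: the tree was cut short by a U-turn or divergence, i.e.
      max(2 ** (tree_depth - 1), n_steps) per draw.
    Works on any array shape, e.g. (chains, draws); everything is summed.
    """
    tree_depth = np.asarray(tree_depth)
    n_steps = np.asarray(n_steps)
    # exp2 returns floats, so a depth of 0 does not trip integer-power errors.
    upper = np.sum(np.exp2(tree_depth))
    lower = np.sum(np.maximum(np.exp2(tree_depth - 1), n_steps))
    return int(lower), int(upper)
```

For example, three draws with tree depths `[1, 2, 3]` and step counts `[1, 3, 5]` give a lower bound of 9 and an upper bound of 14 evaluations.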

It wouldn't be terrible to add a dedicated counter, if junpeng doesn't have a better answer than mine.
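A dedicated counter could be as simple as wrapping the logp and gradient callables. The sketch below assumes the sampler calls plain Python functions. PyMC's actual compiled functions are not wrapped this way, and `CountingFunction` and `leapfrog` are hypothetical names used only for illustration. The toy leapfrog integrator reuses the end-of-step gradient as the start of the next step, so `n_steps` steps cost `n_steps + 1` gradient evaluations.

```python
import math


class CountingFunction:
    """Hypothetical wrapper that counts how often a callable is invoked."""

    def __init__(self, fn):
        self.fn = fn
        self.n_calls = 0

    def __call__(self, *args, **kwargs):
        self.n_calls += 1
        return self.fn(*args, **kwargs)


# Toy standard-normal log density and its gradient, standing in for a model.
def logp(x):
    return -0.5 * x * x - 0.5 * math.log(2 * math.pi)


def dlogp(x):
    return -x


def leapfrog(q, p, step_size, n_steps, grad):
    """Toy leapfrog integrator: one initial gradient, then one per step."""
    g = grad(q)
    for _ in range(n_steps):
        p = p + 0.5 * step_size * g
        q = q + step_size * p
        g = grad(q)  # reused as the starting gradient of the next step
        p = p + 0.5 * step_size * g
    return q, p


counted_logp = CountingFunction(logp)
counted_dlogp = CountingFunction(dlogp)
```

Running `leapfrog(1.0, 0.0, 0.1, 10, counted_dlogp)` leaves `counted_dlogp.n_calls` at 11. A real counter would have to live wherever the sampler actually calls the compiled functions.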

cluhmann commented 2 years ago

I'm not sure how big a problem it is, but the other thing I had mentioned was that these estimates seem to be algorithm/step-specific. So if my model has both discrete and continuous parameters and a compound step is used, or if I am using one of the growing number of algorithms available via blackjax, etc., then the NUTS-based estimates don't necessarily seem relevant.

michaelosthege commented 2 years ago

Yes, it's more complicated than it sounds, because with CompoundStep there are multiple levels from which we must aggregate the information.
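To illustrate the aggregation problem, here is a minimal sketch that assumes each draw of a compound step yields a list of stats dicts, one per sub-step. It also assumes each sub-step reports a hypothetical `n_logp_evals` counter; no such stat is implied to exist in PyMC today. Sub-steps that do not report the counter contribute zero, and multiple chains would need one more level of summation.

```python
from collections import Counter


def total_evaluations(draw_stats, key="n_logp_evals"):
    """Aggregate a per-sub-step evaluation counter over a compound-step trace.

    `draw_stats`: one entry per draw; each entry is a list of stats dicts,
    one per sub-step. `n_logp_evals` is a hypothetical stat name.
    Returns (grand_total, {sub_step_index: total_for_that_sub_step}).
    """
    per_substep = Counter()
    for substeps in draw_stats:
        for i, stats in enumerate(substeps):
            # Missing counters count as zero evaluations.
            per_substep[i] += stats.get(key, 0)
    return sum(per_substep.values()), dict(per_substep)
```

With two draws of a NUTS + Metropolis compound step, e.g. `[[{"n_logp_evals": 7}, {"n_logp_evals": 1}], [{"n_logp_evals": 15}, {}]]`, this returns a total of 23, split as 22 for sub-step 0 and 1 for sub-step 1.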

More background in #4602

ColCarroll commented 2 years ago

With, e.g., blackjax, don't we already lose the instrumentation inside the actual sampler (other than what it provides)? It seems like if the PyMC sampler implemented a gradient or logp counter and it proved useful, other samplers might follow suit.

michaelosthege commented 2 years ago

IMO we should treat sampler stats the same way we treat posterior draws: pre-registration, shapes, dtypes, names, coords. Storing them during MCMC should not be technically different from storing draws.
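A minimal sketch of what "pre-registered" stats could look like. It assumes each stat is declared up front with a dtype and a per-draw shape, then written into preallocated arrays during sampling, just like draws. `StatsStore` is a hypothetical name, not a PyMC class, and coords are left out for brevity.

```python
import numpy as np


class StatsStore:
    """Hypothetical preallocated storage for sampler stats.

    Every stat is registered with (dtype, per-draw shape) before sampling;
    recording an unregistered stat is an error, as it would be for a draw
    of an undeclared variable.
    """

    def __init__(self, specs, n_draws):
        # specs: {name: (dtype, shape)}; shape () means one scalar per draw.
        self.arrays = {
            name: np.zeros((n_draws, *shape), dtype=dtype)
            for name, (dtype, shape) in specs.items()
        }

    def record(self, draw_idx, stats):
        for name, value in stats.items():
            if name not in self.arrays:
                raise KeyError(f"stat {name!r} was not pre-registered")
            self.arrays[name][draw_idx] = value


specs = {
    "tree_depth": (np.int64, ()),
    "n_logp_evals": (np.int64, ()),  # hypothetical evaluation counter
    "step_size": (np.float64, ()),
}
store = StatsStore(specs, n_draws=3)
store.record(0, {"tree_depth": 2, "n_logp_evals": 4, "step_size": 0.5})
```

Because shapes and dtypes are fixed before sampling, a backend can allocate, store, and later label these arrays with the same machinery it uses for posterior draws.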

But yes, samplers that are implemented in Aesara/JAX and prioritize speed over observability are a problem.