michaelosthege opened this issue 2 years ago
That's a good question! I haven't double-checked, but I think you get `sum(2 ** tree_depth)` logp and grad(logp) evaluations. This is only an approximation, I think, because of `max(2 ** (tree_depth - 1), n_steps)`.
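A minimal sketch of what that estimate looks like, assuming we already have the per-draw `tree_depth` and `n_steps` sampler stats as arrays (the values below are made up for illustration, not from a real trace):

```python
import numpy as np

# Hypothetical per-draw tree depths, as recorded by a NUTS-style sampler stat.
tree_depth = np.array([3, 4, 4, 5, 3])

# Upper-bound estimate: a full binary tree of depth d would cost
# 2**d logp-and-gradient evaluations per draw.
upper = int(np.sum(2 ** tree_depth))

# With early termination the tree is usually not full, so a per-draw
# estimate combines the depth with the recorded number of integration steps.
n_steps = np.array([7, 15, 11, 31, 7])  # hypothetical per-draw step counts
estimate = int(np.maximum(2 ** (tree_depth - 1), n_steps).sum())
```

Either way this is reconstructed after the fact from sampler stats, which is exactly why a dedicated counter would be more reliable.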
It wouldn't be terrible to add a dedicated counter, if @junpenglao doesn't have a better answer than mine.
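For what a dedicated counter could look like: a thin wrapper around the compiled logp function that increments on every call. This is a hypothetical sketch, not PyMC's actual API — the class name and attribute are invented for illustration:

```python
class CountingLogp:
    """Hypothetical wrapper that counts how often a logp function is evaluated."""

    def __init__(self, logp_fn):
        self.logp_fn = logp_fn
        self.n_evals = 0  # would be reported as a sampler stat per draw

    def __call__(self, *args, **kwargs):
        self.n_evals += 1
        return self.logp_fn(*args, **kwargs)


# Usage: wrap a toy logp and evaluate it a few times.
counted = CountingLogp(lambda x: -0.5 * x**2)
for x in [0.0, 1.0, 2.0]:
    counted(x)
# counted.n_evals is now 3
```

The same idea would apply to the gradient function, with a separate counter.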
I'm not sure how big a problem it is, but the other thing I had mentioned is that these estimates seem to be algorithm/step-specific. If my model has discrete and continuous parameters and a compound step is used, or if I'm using one of the growing number of algorithms available via blackjax, then estimates based on NUTS aren't necessarily relevant.
Yes, it's more complicated than it sounds, because with CompoundStep there are multiple levels from which we must aggregate the information.
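To make the aggregation problem concrete, here is a toy sketch of what combining a hypothetical `n_evals` stat across the step methods of a compound step could look like (the stat name and structure are assumptions, not PyMC internals):

```python
# Each inner step method reports its own stats dict per draw;
# the per-draw total is the sum across step methods.
stats_per_step = [
    [{"n_evals": 15}, {"n_evals": 31}],  # e.g. NUTS on continuous variables
    [{"n_evals": 2}, {"n_evals": 2}],    # e.g. Metropolis on discrete variables
]

# zip(*...) pairs up the step methods' stats for each draw.
totals = [sum(s["n_evals"] for s in draw) for draw in zip(*stats_per_step)]
# totals == [17, 33]
```

The hard part in practice is that each step method defines its own stat names, shapes, and dtypes, so the aggregation has to know how the levels map onto each other.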
More background in #4602
With, e.g., blackjax, don't we already lose the instrumentation inside the actual sampler (beyond what it provides)? It seems like if the PyMC sampler implemented a gradient or logp counter and it proved useful, other samplers might follow suit.
IMO we should treat sampler stats similarly to how we treat posterior draws: pre-registration, shapes, dtypes, names, coords. And storing them during MCMC should not be technically different from storing draws.
But yes, samplers that are implemented in Aesara/JAX and prioritize speed over observability are a problem.
This question came up during office hours:
Which sampler stat should be used to answer how often a model was evaluated during MCMC?
@junpenglao @ColCarroll
For example, these variables are tracked by NUTS:
If we're not currently tracking the information of how often the model was evaluated, we should definitely add that stat!