fonnesbeck opened 7 months ago
👍 Yes, that would be really nice to have. I have a hierarchical model with lots of groups.
@aseyboldt, @fonnesbeck: Is it possible that my nutpie sampling is slowed down because my trace has to store lots of group-specific random variables? I only need samples of the population-level parameters. Can you suggest how to work around this until var_names is available as an argument?
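One partial workaround (a sketch only, and not something that speeds up sampling itself): since the returned posterior is an xarray Dataset, the group-level variables can be dropped after sampling, before the trace is saved to disk. The variable names ("mu", "sigma", "group_effects") and shapes below are hypothetical placeholders.

```python
import numpy as np
import xarray as xr

# Stand-in for idata.posterior with one large group-level variable
# (hypothetical names and shapes).
posterior = xr.Dataset(
    {
        "mu": (("chain", "draw"), np.zeros((4, 100))),
        "sigma": (("chain", "draw"), np.ones((4, 100))),
        "group_effects": (("chain", "draw", "group"), np.zeros((4, 100, 500))),
    }
)

# Keep only the population-level parameters before saving to disk.
keep = ["mu", "sigma"]
slim = posterior[keep]
print(list(slim.data_vars))  # ['mu', 'sigma']
```

This only reduces storage after the fact; memory during sampling is unaffected, which is why a var_names argument at compile time would still be useful.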
I agree that this would be nice to have (and it wouldn't be that hard to implement; it only needs some changes in the compile_pymc_model function. If someone wants to give it a go, I'd be glad to help).
If the model you are looking at is somewhat like the one you posted in the other thread, I'd be surprised if storing the trace is an issue though.
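To see why trace storage is rarely the bottleneck, a back-of-envelope estimate helps: the trace costs roughly chains × draws × parameters × 8 bytes for float64. The parameter count below is hypothetical.

```python
# Rough memory estimate for storing group-level draws in the trace.
chains, draws = 6, 1000
n_group_params = 5000      # hypothetical number of group-specific parameters
bytes_per_value = 8        # float64

total_bytes = chains * draws * n_group_params * bytes_per_value
print(total_bytes / 1e9)   # 0.24 (GB)
```

Even with thousands of group-level parameters this is a fraction of a gigabyte, so slow sampling usually points elsewhere (dtype, parametrization, correlated predictors).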
The simplest way to speed it up is probably to switch to float32 (set the environment variable PYTENSOR_FLAGS=floatX=float32).
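If setting the flag from within Python rather than the shell, note that it has to happen before pytensor is first imported (a small sketch):

```python
import os

# PYTENSOR_FLAGS must be set before pytensor/pymc are imported,
# otherwise the floatX setting has no effect for this process.
os.environ["PYTENSOR_FLAGS"] = "floatX=float32"

# import pymc  # only import after the flag is set
```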
You could also try running it on the GPU; if the dataset is large, that might help a lot.
And then I'd double-check the parametrization and make sure your predictors are not too correlated. An easy way to see whether that can help is to look at the "gradients/draw". If that number is large (say, above 15 or 30), there is probably quite some room for improvement. It is pretty much proportional to the runtime, all other things being equal, so if you can get it from 100 down to 10, that's a 10x speedup.
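As a sketch of how to read that number off the sampling stats (assuming, as for NUTS generally, that each leapfrog step costs one gradient evaluation, so gradients/draw is approximately the mean of n_steps; the values below are hypothetical, mirroring the output shown later in this thread):

```python
import numpy as np

# Hypothetical n_steps array with shape (chain, draw), as in
# idata.sample_stats["n_steps"].
n_steps = np.full((6, 1000), 15)

# One gradient evaluation per leapfrog step -> gradients per draw.
grads_per_draw = n_steps.mean()
print(grads_per_draw)  # 15.0 -- well under the ~30 threshold
```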
Thanks @aseyboldt for the suggestions, that's really helpful. I also looked into compile_pymc_model, but it essentially just calls the "compile to numba" or "compile to jax" functions, so I assume the change would need to be made there, right? One point that is not entirely clear to me is why we would need to exclude variables at "compilation time". I thought we'd still want to sample e.g. the group-specific random variables, but then exclude them when writing the trace, i.e. the exclusion would happen at "runtime"?
I checked the sampling stats (cf. output below), but I wasn't able to find the "number of gradient evaluations". Is it correct to assume that this corresponds to 'n_steps'?
sample_stats
<xarray.Dataset> Size: 500kB
Dimensions: (chain: 6, draw: 1000)
Coordinates:
* chain (chain) int64 48B 0 1 2 3 4 5
* draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables:
depth (chain, draw) uint64 48kB 4 4 4 4 4 4 ... 4 4 4 4 4 4
maxdepth_reached (chain, draw) bool 6kB False False ... False False
index_in_trajectory (chain, draw) int64 48kB -8 8 12 -3 7 ... -6 4 5 3 11
logp (chain, draw) float64 48kB 1.633e+04 ... 1.631e+04
energy (chain, draw) float64 48kB -1.617e+04 ... -1.614e+04
diverging (chain, draw) bool 6kB False False ... False False
energy_error (chain, draw) float64 48kB -0.4984 0.3992 ... -0.6857
step_size (chain, draw) float64 48kB 0.3677 0.3677 ... 0.354
step_size_bar (chain, draw) float64 48kB 0.3677 0.3677 ... 0.354
mean_tree_accept (chain, draw) float64 48kB 0.989 0.7761 ... 0.7476
mean_tree_accept_sym (chain, draw) float64 48kB 0.7821 0.8038 ... 0.8219
n_steps (chain, draw) uint64 48kB 15 15 15 15 ... 15 15 15 15
Attributes:
created_at: 2024-10-22T11:45:06.969615+00:00
arviz_version: 0.20.0
To allow customizing which variables are stored in the trace, it would be helpful to have a similar argument for CompiledPyMCModel, so that unwanted variables can be ignored by the trace.