pymc-devs / nutpie

Python wrapper for nuts-rs
MIT License

Add `var_names` arg to PyMC compiled model #100

Open · fonnesbeck opened this issue 7 months ago

fonnesbeck commented 7 months ago

To allow customizing which variables are stored in the trace, it would be helpful to add a similar var_names argument to CompiledPyMCModel, so that unwanted variables can be excluded from the trace.

btschroer commented 5 days ago

👍 Yes, that would be really nice to have. I have a hierarchical model with lots of groups.

@aseyboldt , @fonnesbeck: Is it at all possible that my nutpie sampling is slowed down because my trace has to store lots of group-specific random variables? I only need samples of the population parameters.

Can you suggest how to work around this until var_names is available as an argument?
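[Editor's note: until var_names lands, one possible stopgap is to drop the unwanted variables from the returned trace after sampling. This does not reduce sampling-time cost, only the data you keep afterwards. A minimal xarray sketch with made-up variable names:]

```python
import numpy as np
import xarray as xr

# Hypothetical posterior: one population parameter plus a large
# group-level variable we don't actually need to keep.
posterior = xr.Dataset(
    {
        "mu": (("chain", "draw"), np.zeros((2, 10))),
        "group_effects": (("chain", "draw", "group"), np.zeros((2, 10, 500))),
    }
)

# Drop the group-level variable after sampling; with nutpie/ArviZ this would
# be idata.posterior.drop_vars([...]).
slim = posterior.drop_vars(["group_effects"])
print(list(slim.data_vars))  # ['mu']
```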

aseyboldt commented 5 days ago

I agree that this would be nice to have (and it wouldn't be that hard to implement: it only needs some changes in the compile_pymc_model function. If someone wants to give it a go, I'd be glad to help).

If the model you are looking at is somewhat like the one you posted in the other thread, I'd be surprised if storing the trace is an issue though.

The simplest way to speed it up is probably to switch to float32 (set the environment variable PYTENSOR_FLAGS=floatX=float32). You could also try running it on the GPU; if the dataset is large, that might help a lot.
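[Editor's note: for reference, a minimal shell sketch of the float32 switch; the script name is a placeholder for your own sampling script.]

```shell
# Tell PyTensor to use single precision; this must be set before the
# model is compiled, i.e. before the Python process imports pytensor.
export PYTENSOR_FLAGS="floatX=float32"

# Then launch the sampling script in the same shell, e.g.:
#   python my_model.py
echo "$PYTENSOR_FLAGS"
```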

And then I'd double-check the parametrization, and make sure your predictors are not too correlated. An easy way to see whether that can help is to look at the "gradients/draw". If that is large (say, more than 15 or 30 or so), there is probably quite some room for improvement. This number is pretty much proportional to the runtime, all else being equal, so if you can get it from 100 down to 10, that's a 10x speedup.

btschroer commented 5 days ago

Thanks @aseyboldt for the suggestions, that's really helpful. I also looked into compile_pymc_model, but it essentially just calls the "compile to numba" or "compile to jax" functions; I'm assuming the change would need to be made there, right? One point that is not entirely clear to me is why we need to exclude variables at "compilation time". I thought we'd still want to sample e.g. the group-specific random variables, but then exclude them when writing the trace, i.e. the exclusion would happen at "runtime"?
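[Editor's note: a toy illustration, not nutpie's actual internals, of why compile-time filtering can matter. The compiled "expand" step maps each raw draw to the variables that get stored; if a variable is dropped when that function is built, it is never computed at all, whereas filtering at write time still pays for computing it every draw. All names below are made up.]

```python
def make_expand(var_names):
    # Hypothetical per-variable compute functions; group_effects is the
    # expensive one we'd like to avoid ever evaluating.
    computations = {
        "mu": lambda raw: raw[0],
        "group_effects": lambda raw: [raw[0] + g for g in range(500)],
    }
    # Filter at "compile time": unselected variables are simply not wired in.
    selected = {k: f for k, f in computations.items() if k in var_names}

    def expand(raw):
        return {k: f(raw) for k, f in selected.items()}

    return expand

expand = make_expand(["mu"])
print(expand([1.5]))  # {'mu': 1.5}; group_effects is never computed
```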

I checked the sampling stats (cf. the output below), but I wasn't able to find the number of gradient evaluations. Is it correct to assume that this corresponds to 'n_steps'?

sample_stats
<xarray.Dataset> Size: 500kB
Dimensions:               (chain: 6, draw: 1000)
Coordinates:
  * chain                 (chain) int64 48B 0 1 2 3 4 5
  * draw                  (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables:
    depth                 (chain, draw) uint64 48kB 4 4 4 4 4 4 ... 4 4 4 4 4 4
    maxdepth_reached      (chain, draw) bool 6kB False False ... False False
    index_in_trajectory   (chain, draw) int64 48kB -8 8 12 -3 7 ... -6 4 5 3 11
    logp                  (chain, draw) float64 48kB 1.633e+04 ... 1.631e+04
    energy                (chain, draw) float64 48kB -1.617e+04 ... -1.614e+04
    diverging             (chain, draw) bool 6kB False False ... False False
    energy_error          (chain, draw) float64 48kB -0.4984 0.3992 ... -0.6857
    step_size             (chain, draw) float64 48kB 0.3677 0.3677 ... 0.354
    step_size_bar         (chain, draw) float64 48kB 0.3677 0.3677 ... 0.354
    mean_tree_accept      (chain, draw) float64 48kB 0.989 0.7761 ... 0.7476
    mean_tree_accept_sym  (chain, draw) float64 48kB 0.7821 0.8038 ... 0.8219
    n_steps               (chain, draw) uint64 48kB 15 15 15 15 ... 15 15 15 15
Attributes:
    created_at:     2024-10-22T11:45:06.969615+00:00
    arviz_version:  0.20.0
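[Editor's note: if n_steps is indeed the leapfrog-step count (each leapfrog step costs roughly one gradient evaluation), the "gradients/draw" figure can be estimated from the stats above. A sketch using made-up values that mimic the printed output:]

```python
import numpy as np

# Hypothetical (chain, draw)-shaped n_steps, mimicking the output above,
# where every draw used a depth-4 trajectory of 15 leapfrog steps.
n_steps = np.full((6, 1000), 15, dtype=np.uint64)

# Each leapfrog step needs about one gradient evaluation, so the mean of
# n_steps approximates the "gradients/draw" figure discussed above.
grads_per_draw = float(n_steps.mean())
print(grads_per_draw)  # 15.0 here, well below the ~30 that would signal trouble
```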