pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Speeding up NUTS and MCMC in tests #1740

Closed: Jacob-Stevens-Haas closed this issue 2 months ago

Jacob-Stevens-Haas commented 5 months ago

In a package that builds regression problems with a variety of solvers, we recently added numpyro and a regularized horseshoe prior. The tests, however, take longer than any other method, and I'm trying to speed them up. Even a small test (num_warmup=1, num_samples=4, data shape = (10, 2)) takes around 5 seconds - substantially longer than our tests for discrete optimizations with gurobipy. Fixing the geometry would probably help in this specific case, but in general, are there any kwargs that will let NUTS/MCMC find a very quick (and bad) solution? Something like setting max_iter=1 in a gradient descent algorithm. I tried setting max_tree_depth to 1 and target_accept_prob to 0.1, but that didn't change the timing appreciably. This would be useful in tests like "make sure that we can pickle our models after fitting" and "make sure that a fit model has certain attributes".

In profiling, the test spends about 41% of the time initializing: 10% in `hmc:186` (initialize kernel) and 31% in `util:303` (`find_valid_initial_params`). Another 44% is in `util:266` (`fori_collect`), which I assume is actual execution.
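For example, if `find_valid_initial_params` is a big chunk of the cost, would pinning the initial point help? A sketch of what I mean (assuming `init_strategy` is the relevant knob; the init value is made up, and I haven't verified this actually speeds anything up):

```python
from numpyro.infer import MCMC, NUTS, init_to_value

# Sketch: skip the search for a valid initial point by supplying one
# explicitly, and keep trajectories as short as possible.
# `model` is the model from the example code below.
kernel = NUTS(
    model,
    init_strategy=init_to_value(values={"beta": 0.0}),  # hypothetical init
    max_tree_depth=1,
)
mcmc = MCMC(kernel, num_warmup=1, num_samples=4)
```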

Totally understand if it's not available, though! It's a great package as is.

Example code:

```python
# %%
import numpy as np
import jax.numpy as jnp
from jax import random

import numpyro
from numpyro.diagnostics import summary
from numpyro.distributions import HalfCauchy, InverseGamma, Normal
from numpyro.infer import MCMC, NUTS


def model(x, y):
    # beta = reg_horseshoe_prior(1e-1, 5, 3, (1, x.shape[1]))
    beta = numpyro.sample("beta", Normal())
    preds = jnp.dot(x, beta.T)
    numpyro.sample("obs", Normal(preds, 1e-1), obs=y)


def reg_horseshoe_prior(
    global_sparsity: float,
    degrees_of_freedom: float,
    slab_var: float,
    shape: tuple[int, ...],
):
    tau = numpyro.sample("tau", HalfCauchy(global_sparsity))
    c_sq = numpyro.sample(
        "c_sq",
        InverseGamma(degrees_of_freedom / 2, degrees_of_freedom / 2 * slab_var**2),
    )
    lamb = numpyro.sample("lambda", HalfCauchy(1.0), sample_shape=shape)
    lamb_squiggle = jnp.sqrt(c_sq) * lamb / jnp.sqrt(c_sq + tau**2 * lamb**2)
    beta = numpyro.sample(
        "beta",
        Normal(jnp.zeros_like(lamb_squiggle), jnp.sqrt(lamb_squiggle**2 * tau**2)),
    )
    return beta


# %%
x = np.random.normal(size=(10, 2))
y = x[:, :1]
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1, num_samples=4)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, x=x, y=y)
summary_dict = summary(mcmc.get_samples(), group_by_chain=False)
```
martinjankowiak commented 5 months ago

don't think there's much that you can do since presumably much of that time is going into jax compilation
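one thing that might help across repeated runs in the same process (a sketch, assuming the data shapes stay fixed so nothing retraces):

```python
# Sketch: construct the MCMC object once and reuse it; with
# jit_model_args=True, subsequent run() calls with same-shaped
# arguments reuse the compiled step function instead of recompiling.
mcmc = MCMC(NUTS(model), num_warmup=1, num_samples=4, jit_model_args=True)
mcmc.run(random.PRNGKey(0), x=x, y=y)
mcmc.run(random.PRNGKey(1), x=x, y=y)  # no recompilation for same shapes
```

that only amortizes compilation within a session though - it won't make the first run any faster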

fehiepsi commented 2 months ago

Closed because this is more specific to JAX. Please check out https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html. As far as I know, it supports TPU and GPU for now and will support CPU in the near future.
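For example (a sketch based on the linked docs; the exact API has changed across JAX versions, so check the page for your version):

```python
import jax

# Sketch: point JAX at a persistent on-disk compilation cache so that
# compiled programs can be reused across processes (per the linked docs).
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
```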

Jacob-Stevens-Haas commented 2 months ago

Thanks! I'll check it out