pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

`sample_blackjax_nuts` fails with default settings on single GPU #6176

Closed: martiningram closed this issue 2 years ago.

martiningram commented 2 years ago

Description of your problem

Hi all,

I'm trying to run a PyMC runtime comparison with BlackJAX on a single GPU. This fails with the default settings, producing the following error:

PyMC BlackJAX GPU
/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/pymc/sampling_jax.py:37: UserWarning: This module is experimental.
  warnings.warn("This module is experimental.")
Compiling...
Compilation time =  0:00:01.213512
Sampling...
Traceback (most recent call last):
  File "/home/martin/projects/pymc_vs_stan_revamp/mcmc_runtime_comparison/fit_pymc_blackjax.py", line 28, in <module>
    hierarchical_trace = pymc.sampling_jax.sample_blackjax_nuts(random_seed=seed)
  File "/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/pymc/sampling_jax.py", line 355, in sample_blackjax_nuts
    states, _ = map_fn(get_posterior_samples)(keys, init_params)
  File "/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/jax/_src/api.py", line 2202, in cache_miss
    out_tree, out_flat = f_pmapped_(*args, **kwargs)
  File "/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/jax/_src/api.py", line 2075, in pmap_f
    out = pxla.xla_pmap(
  File "/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/jax/core.py", line 2062, in bind
    return map_bind(self, fun, *args, **params)
  File "/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/jax/core.py", line 2094, in map_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/jax/core.py", line 2065, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/jax/core.py", line 701, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 948, in xla_pmap_impl
    compiled_fun, fingerprint = parallel_callable(
  File "/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/jax/linear_util.py", line 306, in memoized_fun
    ans = call(fun, *args)
  File "/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 1198, in parallel_callable
    pmap_executable = pmap_computation.compile()
  File "/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/jax/_src/profiler.py", line 313, in wrapper
    return func(*args, **kwargs)
  File "/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 1481, in compile
    self._executable = PmapExecutable.from_hlo(self._hlo, **self.compile_args)
  File "/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 1510, in from_hlo
    raise ValueError(msg.format(shards.num_global_shards,
jax._src.traceback_util.UnfilteredStackTrace: ValueError: compiling computation that requires 4 logical devices, but only 1 XLA devices are available (num_replicas=4, num_partitions=1)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/martin/projects/pymc_vs_stan_revamp/mcmc_runtime_comparison/fit_pymc_blackjax.py", line 28, in <module>
    hierarchical_trace = pymc.sampling_jax.sample_blackjax_nuts(random_seed=seed)
  File "/home/martin/miniconda3/envs/pymc_comparison/lib/python3.10/site-packages/pymc/sampling_jax.py", line 355, in sample_blackjax_nuts
    states, _ = map_fn(get_posterior_samples)(keys, init_params)
ValueError: compiling computation that requires 4 logical devices, but only 1 XLA devices are available (num_replicas=4, num_partitions=1)

The error makes sense to me: `sample_blackjax_nuts` appears to run the chains with `pmap`, which needs one device per chain (four here), but only a single GPU is available. In NumPyro, I've found that `chain_method='vectorized'` works well; I think it uses `vmap` to run all the chains on a single GPU. In any case, I'd appreciate a pointer on how to run BlackJAX efficiently on a single GPU. Thanks for your help!
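To illustrate the difference, here is a minimal JAX sketch. It is not PyMC's actual sampling code: the toy `log_density` function and the three-dimensional chain positions are assumptions standing in for a model. `jax.vmap` batches the four chains into one computation on whatever device is available, whereas `jax.pmap` maps each chain to its own device and, on a machine with fewer than four devices, fails with the same kind of ValueError as in the traceback above.

```python
import jax
import jax.numpy as jnp


def log_density(x):
    # Toy standard-normal log density (up to a constant); an assumed stand-in for a model.
    return -0.5 * jnp.sum(x ** 2)


num_chains = 4
positions = jnp.zeros((num_chains, 3))  # one row of parameters per chain

# vmap: all chains are batched into a single computation on one device.
vectorized = jax.vmap(log_density)(positions)

# pmap: one chain per device, so it needs at least num_chains local devices.
try:
    parallel = jax.pmap(log_density)(positions)
    pmap_ok = True
except ValueError:
    # Same failure mode as the traceback: more chains than available devices.
    pmap_ok = False
```

On a single-GPU (or single-CPU) machine, `vectorized` holds one value per chain while the `pmap` call is expected to raise.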


martiningram commented 2 years ago

The solution to this problem is to use `chain_method='vectorized'` instead of the default `chain_method='parallel'`, as @aloctavodia pointed out to me (thank you!). I'm closing this issue, but it might be worth adding logic that switches to vectorized sampling automatically when there are fewer devices than chains, since parallel execution cannot work in that case.
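As a sketch of the automatic fallback suggested above, a hypothetical helper (`pick_chain_method` is not part of PyMC) could compare the number of chains with the number of available devices and downgrade `'parallel'` to `'vectorized'` with a warning. It takes the device count as an argument so it can be checked without a GPU; in practice one would pass `jax.local_device_count()`.

```python
import warnings


def pick_chain_method(chains: int, device_count: int, requested: str = "parallel") -> str:
    """Hypothetical helper, not a PyMC API: choose a chain_method that can run.

    'parallel' uses pmap and needs one device per chain; 'vectorized' uses vmap
    and can run all chains on a single device.
    """
    if requested == "parallel" and device_count < chains:
        warnings.warn(
            f"Only {device_count} device(s) available for {chains} chains; "
            "falling back to chain_method='vectorized'."
        )
        return "vectorized"
    return requested


# Usage sketch (assumes JAX and PyMC are installed and a model context is active):
# import jax
# method = pick_chain_method(chains=4, device_count=jax.local_device_count())
# trace = pymc.sampling_jax.sample_blackjax_nuts(chains=4, chain_method=method, random_seed=seed)
```

Warning rather than silently switching keeps the user aware that the chains are no longer spread across devices.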