Closed: martiningram closed this issue 2 years ago.
The solution to this problem is to use `chain_method='vectorized'` rather than the default `chain_method='parallel'`, as @aloctavodia pointed out to me (thank you!). I'm closing this issue, but it might be worth adding some logic to switch to this automatically when there are too few devices for parallel execution...?
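A minimal sketch of the automatic fallback suggested here, written in plain JAX rather than against PyMC or blackjax. `choose_chain_method`, `toy_chain` and `run_chains` are hypothetical names, and the Gaussian random walk only stands in for a real sampler. The idea: use `pmap` (one chain per device) only when `jax.local_device_count()` is at least the number of chains, and otherwise fall back to `vmap`, which batches all chains into one program on a single device.

```python
from typing import Optional

import jax
import jax.numpy as jnp


def choose_chain_method(chains: int, n_devices: Optional[int] = None) -> str:
    """Hypothetical helper: pick 'parallel' only if every chain gets its own device."""
    if n_devices is None:
        n_devices = jax.local_device_count()
    return "parallel" if n_devices >= chains else "vectorized"


def toy_chain(key, n_steps: int = 100):
    """Stand-in for one MCMC chain: a Gaussian random walk, not a real sampler."""
    return jnp.cumsum(jax.random.normal(key, (n_steps,)))


def run_chains(chains: int = 4, seed: int = 0):
    # One PRNG key per chain; the leading axis is the chain axis.
    keys = jax.random.split(jax.random.PRNGKey(seed), chains)
    if choose_chain_method(chains) == "parallel":
        # pmap maps the chain axis across devices (needs >= `chains` devices).
        return jax.pmap(toy_chain)(keys)
    # vmap vectorizes over the chain axis, so all chains run on one device.
    return jax.vmap(toy_chain)(keys)


draws = run_chains(chains=4)
print(draws.shape)  # (4, 100)
```

On a machine with a single GPU (or CPU only) this takes the `vmap` path instead of failing in `pmap`. In PyMC, the chosen string would be passed as the `chain_method` argument of the JAX-based sampler; the exact function that accepts it (for example `pymc.sampling.jax.sample_blackjax_nuts`) varies between PyMC versions, so treat that name as an assumption.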
Description of your problem
Hi all,
I'm trying to run a PyMC runtime comparison with `blackjax` on the GPU. This fails with the default settings, producing the following error:

The error makes sense to me: it looks like it's trying to run a `pmap`, but there is only a single GPU. In `numpyro`, I've found that the `vectorized` `chain_method` works well; I think it uses `vmap` to run the multiple chains on a single GPU. In any case, I'd appreciate a pointer on how to run `blackjax` efficiently on a single GPU. Thanks for your help!

Versions and main components