PaulScemama opened this issue 11 months ago
This is not my experience; what is your environment?
Also, important note re benchmarking JAX: https://jax.readthedocs.io/en/latest/async_dispatch.html
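In short: because dispatch is asynchronous, you need to block on the result before stopping the timer, otherwise you mostly measure how fast Python can enqueue work. A minimal, self-contained sketch (a dummy matmul standing in for the sampler call):

```python
import time

import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (4000, 4000))

start = time.perf_counter()
y = jnp.dot(x, x)        # returns almost immediately: the work is dispatched asynchronously
y.block_until_ready()    # wait for the device to actually finish before reading the clock
print(f"took {time.perf_counter() - start:.3f} s")
```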
@junpenglao I will check back on this later this weekend -- possible that it is an environment problem. I'll get back to you then.
This is replicable on Colab, so I don't think it's an environment issue.
Output for CPU:

```
Jax sees these devices: [CpuDevice(id=0)]
Starting to run nuts for 500 steps
NUTS Call took 0.12702747980753581 minutes
```
Output for GPU:

```
Jax sees these devices: [cuda(id=0)]
Starting to run nuts for 500 steps
NUTS Call took 0.7922836542129517 minutes
```
I think this is more or less expected behavior, though, when the problem is rather small and doesn't involve operations GPUs are particularly good at. There was a similar discussion for NumPyro here; the takeaway was that JAX is quite efficient on CPU and that GPU acceleration only pays off for certain problems.
Note that NUTS is control-flow heavy, which makes it hard to run fast on a GPU.
See the ChEES algorithm, implemented in BlackJax, for a NUTS-like sampler that avoids this problem.
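To give a flavour of the control-flow point above: NUTS spends its time in a data-dependent `lax.while_loop` doing a small amount of work per iteration, which on a GPU means many tiny kernel launches rather than one large, well-utilised computation. A toy sketch of that access pattern (not the actual NUTS kernel):

```python
import jax
import jax.numpy as jnp

def many_tiny_steps(x):
    """Data-dependent loop doing a little work per iteration, NUTS-style."""
    def cond(carry):
        i, _ = carry
        return i < 1_000

    def body(carry):
        i, x = carry
        return i + 1, x + 1e-3 * jnp.sin(x)   # tiny kernel per step

    _, x = jax.lax.while_loop(cond, body, (0, x))
    return x

x = jnp.ones(10)
jax.jit(many_tiny_steps)(x).block_until_ready()
```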
Describe the issue as clearly as possible:
On a trivial example (the one from quickstart.md) I'm running into what appears to be a weird bug with the NUTS sampler on a GPU.
When I run the script (which I copy below) with a GPU for 200 steps I get:
When I run the script with a GPU for 300 steps I get:
When I run the script with a GPU for 500 steps I get:
When I run the script on the CPU for 1000 steps I get:
Steps/code to reproduce the bug:
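The script follows the quickstart example; a minimal sketch of it is below (the log-density, dimension, and NUTS parameters are placeholders rather than the exact quickstart values):

```python
import time

import jax
import jax.numpy as jnp
import blackjax

print("Jax sees these devices:", jax.devices())

def logdensity_fn(x):
    # Placeholder target: standard normal log-density (up to a constant).
    return -0.5 * jnp.sum(x ** 2)

dim = 10
nuts = blackjax.nuts(
    logdensity_fn,
    step_size=1e-1,                      # placeholder value
    inverse_mass_matrix=jnp.ones(dim),   # placeholder value
)

rng_key = jax.random.PRNGKey(0)
state = nuts.init(jnp.zeros(dim))

def one_step(state, rng_key):
    state, info = nuts.step(rng_key, state)
    return state, state.position

num_steps = 500
keys = jax.random.split(rng_key, num_steps)

print(f"Starting to run nuts for {num_steps} steps")
start = time.perf_counter()
final_state, positions = jax.lax.scan(one_step, state, keys)
positions.block_until_ready()  # account for JAX's async dispatch
print(f"NUTS Call took {(time.perf_counter() - start) / 60} minutes")
```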
Expected result:
Error message:
No response
Blackjax/JAX/jaxlib/Python version information:
Context for the issue:
No response