jax-ml / bayeux

State-of-the-art inference for your Bayesian models.
https://jax-ml.github.io/bayeux/
Apache License 2.0

Can Bayeux help leverage TFP for Apple Silicon GPU? #30

Closed: zwelitunyiswa closed this issue 8 months ago

zwelitunyiswa commented 8 months ago

I am on an M1 Ultra with the Apple Silicon GPU. Given that TensorFlow has Apple Silicon GPU support, can I leverage that via Bayeux? I installed TensorFlow and the TensorFlow Metal library, and TensorFlow sees the GPU, but I cannot figure out how to tell Bayeux to have TensorFlow Probability use its TensorFlow backend instead of JAX, and thus the GPU.

[Screenshot attached, 2024-02-21 12:51 PM]

ColCarroll commented 8 months ago

Bayeux is JAX-only! There's a chance you could use jax2tf to convert the function, but I've never tried this.

Note that at that point, you'd be doing something like Python -> PyTensor -> JAX -> TensorFlow.

zwelitunyiswa commented 8 months ago

Ah! Ok. I get it. I might try that just to see if I can get it to work! Thank you.

ColCarroll commented 8 months ago

Good luck! Let me know how it goes.