zwelitunyiswa closed this issue 8 months ago.
Bayeux is JAX only! There's a chance you could use jax2tf to convert the function, but I've never tried this.
Note that at that point, you'd be doing something like Python -> PyTensor -> JAX -> TensorFlow.
Ah! OK, I get it. I might try that just to see if I can get it to work! Thank you.
Good luck! Let me know how it goes.
I am on an M1 Ultra with the Apple Silicon GPU. Given that TensorFlow has Apple Silicon GPU support, can I leverage that via Bayeux? I installed TensorFlow and the TensorFlow Metal plugin (and TensorFlow sees the GPU), but I cannot figure out how to tell Bayeux to tell TensorFlow Probability to use TensorFlow instead of JAX, and thus use the GPU.