alexlatif opened this issue 4 days ago.
Hi @alexlatif,

I tested the provided code with JAX-metal on a MacBook Pro M1 Pro. While there were no hanging issues, `model.init` and `model.apply` took longer than with the CPU version. Please find the attached screenshots below.

Thank you.
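One way to make a CPU vs. Metal comparison more concrete is to time the first call, which includes XLA compilation, separately from a repeated call. A slow first call would suggest compilation cost, while a slow repeat would suggest slow execution. The sketch below assumes a pure-JAX stand-in for `model.init`/`model.apply`. The names `init`, `apply`, `time_call`, and `benchmark` are hypothetical and are not from the original reproducer. It uses `jax.default_device` to choose where the work runs.

```python
import time

import jax
import jax.numpy as jnp


# Hypothetical stand-ins for a Flax-style model.init / model.apply pair:
# a single dense layer followed by tanh.
def init(key, x):
    w_key, _ = jax.random.split(key)
    return {
        "w": jax.random.normal(w_key, (x.shape[-1], 16)),
        "b": jnp.zeros((16,)),
    }


def apply(params, x):
    return jnp.tanh(x @ params["w"] + params["b"])


def time_call(fn, *args):
    """Run fn once and return (result, seconds), waiting for async dispatch to finish."""
    start = time.perf_counter()
    out = fn(*args)
    # JAX dispatches work asynchronously; block on every leaf before stopping the clock.
    out = jax.tree_util.tree_map(lambda a: a.block_until_ready(), out)
    return out, time.perf_counter() - start


def benchmark(device, batch=32, features=64):
    """Time init/apply on one device, separating the first (compiling) call from a repeat."""
    with jax.default_device(device):
        key = jax.random.PRNGKey(0)
        x = jnp.ones((batch, features))
        init_jit, apply_jit = jax.jit(init), jax.jit(apply)
        params, init_first = time_call(init_jit, key, x)
        _, init_repeat = time_call(init_jit, key, x)
        y, apply_first = time_call(apply_jit, params, x)
        _, apply_repeat = time_call(apply_jit, params, x)
    return {
        "platform": device.platform,
        "output_shape": y.shape,
        "init_first_s": init_first,
        "init_repeat_s": init_repeat,
        "apply_first_s": apply_first,
        "apply_repeat_s": apply_repeat,
    }


if __name__ == "__main__":
    # Default backend (Metal, if the jax-metal plugin is active) vs. the CPU backend.
    print(benchmark(jax.devices()[0]))
    print(benchmark(jax.devices("cpu")[0]))
```

If the first-call times dominate on Metal but the repeat times are comparable, the slowdown is more likely in compilation than in the kernels themselves.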
You're correct that it does eventually run. However, on a MacBook Air M2 running macOS Sonoma 14.4.1, it took ~5 minutes. Any insight into why it's so much slower on Metal?
Description
To reproduce the working state, uncomment the line that updates the device to CPU.
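The reproducer itself is not shown here, so as an assumption, the commented-out "device update" is taken to be a JAX config call that restricts JAX to the CPU backend. A minimal sketch of that pattern follows. The config update has to run before JAX initializes any backend, that is, before the first array is created.

```python
import jax
import jax.numpy as jnp

# Assumption: the reproducer's "device update" is a config call like this one.
# It must run before any backend is initialized. The older "jax_platform_name"
# flag instead selects the default backend rather than restricting the set.
jax.config.update("jax_platforms", "cpu")

y = jnp.arange(4.0).sum()  # computed on the CPU backend
print(jax.devices()[0].platform, float(y))  # prints "cpu 6.0"
```

Setting the `JAX_PLATFORMS=cpu` environment variable before starting Python has the same effect without editing the script.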
System info (Python version, jaxlib version, accelerator, etc.)