Open mattjj opened 2 years ago
Hi @mattjj
It looks like this issue has been resolved in later versions of JAX. I tried to reproduce it with the latest JAX version, 0.4.26, on a cloud VM with 4 T4 GPUs, and it now runs without any error. Please see the screenshot below for reference.
Thank you.
When I run this code on my machine, I get errors like this:
Could you give me some advice?
Originally posted by @caihao in https://github.com/google/jax/issues/1899#issuecomment-614026930