Open sash-a opened 2 months ago
Is there a way for us to reproduce this error?
Unfortunately I don't have a minimal example; I've only encountered it in a large code base. Any idea what the error could be pointing me to? The traceback just ends where I call into a pmapped function.
Description
Hi there, I'm getting a strange and uninformative error when running on TPU that I don't get when running locally. I was hoping someone here could explain it :thinking:
Any help with what this error could mean would be appreciated. I've restarted the TPU and tried different JAX versions, but they all give the same error.
System info (python version, jaxlib version, accelerator, etc.)
TPU v4-8, Python 3.10, JAX 0.4.24, 0.4.25, and 0.4.26