Open frasermince opened 2 years ago
@mavenlin could you please take a look?
As far as I know, there's currently no open API for copying between CPU and TPU, so it's not possible to support jitting on TPU for now.
I see. So would it be possible to use the XLA functions on the CPU as part of a jitted function and then move arrays over to the TPU afterwards? I've tried setting the default device to CPU while calling recv, but I get the same error.
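For reference, the general JAX pattern being asked about (run a jitted function on the CPU, then move the result to the accelerator) can be sketched as below. This is a minimal illustration only: `cpu_only_step` is a hypothetical stand-in and not envpool's recv, and nothing in this thread confirms that the pattern avoids the error above. It assumes that a jitted function runs on the device its input was explicitly placed on.

```python
import jax
import jax.numpy as jnp


def cpu_only_step(x):
    # Hypothetical stand-in for a step that should run on the CPU.
    # This is NOT envpool's recv; it is plain arithmetic for illustration.
    return x * 2.0 + 1.0


# jit without naming a backend; placement follows the committed input.
jitted_step = jax.jit(cpu_only_step)


def step_then_transfer(x):
    cpu = jax.devices("cpu")[0]
    # Explicitly place the input on the CPU so the jitted call runs there.
    x_cpu = jax.device_put(x, cpu)
    out_cpu = jitted_step(x_cpu)
    # First device of the default backend: a TPU core on a TPU VM,
    # or the CPU itself on a machine without an accelerator.
    target = jax.devices()[0]
    return jax.device_put(out_cpu, target)


result = step_then_transfer(jnp.arange(4.0))
```

On a machine without an accelerator the final transfer is effectively a no-op, so the sketch can be tried anywhere JAX is installed.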
Do you mean that even jitting for CPU fails on a TPU VM? Before we introduced GPU support, jitting for CPU worked fine on a GPU machine.
I encountered the same problem.
Describe the bug
Upon running recv from the XLA functions with JAX TPUs enabled, I get the error: "NotImplementedError: MLIR translation rule for primitive 'AtariGymEnvPool_140079244014352_recv' not found for platform tpu". This is happening in a Google TPU VM with Python version 3.9.13, jax version 0.3.13, and envpool version 0.6.2.
To Reproduce
Expected behavior
I expect recv to work correctly on a TPU in a jitted JAX program.