sail-sg / envpool

C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments.
https://envpool.readthedocs.io
Apache License 2.0

[Feature Request] Running XLA functions on TPU #172

Open frasermince opened 2 years ago

frasermince commented 2 years ago

Describe the bug

When I call recv from the XLA functions with JAX TPUs enabled, I get the error: "NotImplementedError: MLIR translation rule for primitive 'AtariGymEnvPool_140079244014352_recv' not found for platform tpu". This happens on a Google TPU VM with Python 3.9.13, JAX 0.3.13, and EnvPool 0.6.2.

To Reproduce

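import jax
import envpool

# Async pool: 64 envs in total, 16 returned per recv call
env = envpool.make("Pong-v5", env_type="gym", num_envs=64, batch_size=16)
# XLA interface: a handle plus recv/send functions (fourth value unused here)
handle, recv, send, _ = env.xla()

# On a TPU VM this call raises the NotImplementedError quoted above
handle2, data = recv(handle)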

Expected behavior

I expect recv to work correctly on a TPU in a jitted JAX program.

Trinkle23897 commented 2 years ago

@mavenlin could you please take a look?

mavenlin commented 2 years ago

As far as I know, there is currently no open API for copying between CPU and TPU, so it is not possible to support jitting on TPU for now.

frasermince commented 2 years ago

I see. So would it be possible to use the XLA functions on the CPU as part of a jitted function and then move the arrays over to the TPU afterwards? I've tried setting the default device to CPU while calling recv, but I get the same error.
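
For reference, the general pattern asked about here (produce a batch on the host, then copy it to the accelerator explicitly) can be sketched in plain JAX. This is only an illustration under assumptions, not a confirmed fix for this issue: host_env_step is a hypothetical stand-in for any environment call made outside jit (for example EnvPool's regular, non-XLA step/recv interface), the observation shape is invented, and mean_pixel is a placeholder for the real jitted computation.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_env_step(batch_size, rng):
    """Hypothetical stand-in for a host-side env step made outside jit.

    Returns a fake batch of stacked 84x84 frames as a NumPy array on the host.
    """
    return rng.integers(0, 256, size=(batch_size, 4, 84, 84), dtype=np.uint8)


@jax.jit
def mean_pixel(obs):
    # Placeholder for the jitted actor/learner computation on the accelerator.
    return jnp.mean(obs.astype(jnp.float32) / 255.0, axis=(1, 2, 3))


rng = np.random.default_rng(0)
# First default-backend device: a TPU core on a TPU VM, the CPU otherwise.
accelerator = jax.devices()[0]

obs_host = host_env_step(16, rng)                # plain Python/NumPy, not traced
obs_dev = jax.device_put(obs_host, accelerator)  # explicit host-to-device copy
values = mean_pixel(obs_dev)                     # jitted compute on that device
print(values.shape)
```

The trade-off in this sketch is that the environment step is no longer part of the compiled program, so each iteration pays for a Python-level call and a host-to-device transfer.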

mavenlin commented 2 years ago

Do you mean that even jitting for CPU fails on a TPU VM? Before we introduced GPU support, jitting for CPU worked fine on a GPU machine.

qlan3 commented 1 year ago

I encountered the same problem.