google-deepmind / meltingpot

A suite of test scenarios for multi-agent reinforcement learning.
Apache License 2.0
577 stars 116 forks

Jax-related #192

Closed alexunderch closed 9 months ago

alexunderch commented 9 months ago

Are there any ways to natively wrap an environment so that it is suitable for running an RL algorithm with JAX (i.e., making the environment stateless), or are there plans to add such support?

Thank you!

duenez commented 9 months ago

Hello,

No, it would be very challenging to make the environments stateless and implement them exclusively in JAX. That said, there's no reason you couldn't use JAX on the agent side, and that's what we do internally. We use a setup similar to A3C, where separate processes step the environments on CPU only, while the agent performs inference and learning on the accelerator (GPU/TPU).
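The split described above can be sketched roughly as follows: environments are stepped on the host in plain Python, and only the agent's inference and learning are JAX functions. This is not DeepMind's internal setup. `ToyEnv` is a hypothetical stand-in for a stateful environment such as a Melting Pot substrate, the linear softmax policy and one-step policy-gradient update are illustrative choices, and the actors run sequentially in one process for brevity, whereas the setup described would run them in separate CPU processes.

```python
import numpy as np
import jax
import jax.numpy as jnp

OBS_DIM = 4
NUM_ACTIONS = 3
LEARNING_RATE = 0.1


class ToyEnv:
    """Hypothetical stand-in for a stateful, CPU-bound environment.

    It mutates Python state on every step, so it cannot be traced by
    jax.jit; it is only ever called from ordinary host-side Python.
    """

    def __init__(self, seed):
        self._rng = np.random.default_rng(seed)
        self._t = 0

    def reset(self):
        self._t = 0
        return self._rng.normal(size=OBS_DIM).astype(np.float32)

    def step(self, action):
        self._t += 1
        obs = self._rng.normal(size=OBS_DIM).astype(np.float32)
        reward = 1.0 if action == 0 else 0.0  # toy reward: action 0 is "good"
        done = self._t >= 10
        return obs, reward, done


def logits_fn(params, obs):
    # Linear policy; works for a single observation or a batch.
    return obs @ params["w"] + params["b"]


@jax.jit
def act(params, obs, key):
    # Agent inference: runs on the accelerator if one is available.
    return jax.random.categorical(key, logits_fn(params, obs))


def loss_fn(params, obs, actions, returns):
    # Simplified policy gradient using the immediate reward as the return.
    log_probs = jax.nn.log_softmax(logits_fn(params, obs))
    chosen = jnp.take_along_axis(log_probs, actions[:, None], axis=1)[:, 0]
    return -jnp.mean(chosen * returns)


@jax.jit
def update(params, obs, actions, returns):
    # Agent learning: a plain SGD step on the accelerator.
    grads = jax.grad(loss_fn)(params, obs, actions, returns)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


def run_actor(env, params, key, num_steps):
    # Actor: steps the stateful env on CPU, calling the JAX policy per step.
    obs_buf, act_buf, rew_buf = [], [], []
    obs = env.reset()
    for _ in range(num_steps):
        key, subkey = jax.random.split(key)
        action = int(act(params, jnp.asarray(obs), subkey))
        next_obs, reward, done = env.step(action)
        obs_buf.append(obs)
        act_buf.append(action)
        rew_buf.append(reward)
        obs = env.reset() if done else next_obs
    return (np.stack(obs_buf),
            np.array(act_buf, dtype=np.int32),
            np.array(rew_buf, dtype=np.float32))


def train(num_iterations=50, num_actors=2, steps_per_actor=20, seed=0):
    key = jax.random.PRNGKey(seed)
    params = {"w": jnp.zeros((OBS_DIM, NUM_ACTIONS)), "b": jnp.zeros(NUM_ACTIONS)}
    envs = [ToyEnv(seed=i) for i in range(num_actors)]
    for _ in range(num_iterations):
        # In a real actor/learner setup these would be separate processes.
        batches = []
        for env in envs:
            key, actor_key = jax.random.split(key)
            batches.append(run_actor(env, params, actor_key, steps_per_actor))
        obs, actions, returns = (np.concatenate(x) for x in zip(*batches))
        params = update(params, jnp.asarray(obs), jnp.asarray(actions),
                        jnp.asarray(returns))
    return params
```

Because `ToyEnv` keeps mutable state, it never enters a traced function; the only compiled pieces are `act` and `update`, which receive parameters and arrays as explicit inputs. That separation is what allows an unmodified, stateful environment to drive a JAX agent.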

alexunderch commented 9 months ago

By the looks of things, it's the only feasible option right now. Thank you for the response. Closing the issue.