keraJLi / rejax

Apache License 2.0

Add Atari support #11

Open pseudo-rnd-thoughts opened 4 months ago

pseudo-rnd-thoughts commented 4 months ago

This is purely a suggestion, and I understand how much work it would be to get this working. Thankfully, https://github.com/mttga/purejaxql has shown it to be possible, achieving 200 million frames in 1 hour with EnvPool XLA support. It would therefore be cool to do hyperparameter searches etc. over Atari environments.

keraJLi commented 4 months ago

Hi there, and thanks for the suggestion! Atari would be great, and I aim to get benchmarks by October. I'll keep you updated on the progress 🤗

keraJLi commented 2 months ago

I've added preliminary EnvPool support on the envpool_compat branch (see train_envpool.py). As far as I can see, you can already run any Rejax algorithm on any EnvPool environment, and initial experiments on Pong suggest that training is successful.

Unfortunately, the XLA interface of EnvPool is limited: it does not support vmap (which makes sense given its basic architecture). Even creating a separate envpool for each vmapped algorithm instance (or training seed) fails. It might be possible to create a single envpool and take great care that vmapped instances only access their own environments, but this is out of scope for me for now. This means that training in EnvPool environments requires not only wrapping the environment but also some modification of the training algorithm, which is implemented in rejax.compat.envpool2gymnax. Still, Rejax might be faster than other implementations even on a single seed, so I'll continue adding support for EnvPool.
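
To make the "single envpool, disjoint environments per instance" idea a bit more concrete, here is a minimal sketch of the index bookkeeping only. It does not use EnvPool at all: fake_pool_step is a hypothetical stand-in for one non-vmappable step of a shared pool, and NUM_SEEDS, ENVS_PER_SEED, OBS_DIM, and choose_actions are made up for illustration. Whether the same pattern works with EnvPool's actual XLA interface is not verified here.

```python
import jax
import jax.numpy as jnp

NUM_SEEDS = 3      # hypothetical: number of vmapped training instances
ENVS_PER_SEED = 4  # hypothetical: environments owned by each instance
OBS_DIM = 2        # toy observation size


def fake_pool_step(flat_actions):
    """Stand-in for one step of a shared pool (NOT EnvPool's API).

    Takes actions of shape (NUM_SEEDS * ENVS_PER_SEED,) and returns toy
    observations of shape (NUM_SEEDS * ENVS_PER_SEED, OBS_DIM), where the
    first feature is the flat environment index.
    """
    env_ids = jnp.arange(flat_actions.shape[0])
    return jnp.stack([env_ids, flat_actions], axis=-1).astype(jnp.float32)


def choose_actions(seed_obs):
    """Toy per-instance policy: one action per owned environment."""
    return jnp.sum(seed_obs, axis=-1).astype(jnp.int32) % 2


def step_all_seeds(obs):
    # obs has shape (NUM_SEEDS, ENVS_PER_SEED, OBS_DIM).
    # Only the policy computation is vmapped over the seed axis.
    actions = jax.vmap(choose_actions)(obs)  # (NUM_SEEDS, ENVS_PER_SEED)
    # The pool is stepped once, outside vmap, on the flattened batch.
    flat_obs = fake_pool_step(actions.reshape(-1))
    # Row-major reshape: seed i gets flat environments
    # i * ENVS_PER_SEED ... (i + 1) * ENVS_PER_SEED - 1.
    return actions, flat_obs.reshape(NUM_SEEDS, ENVS_PER_SEED, OBS_DIM)


obs = jnp.zeros((NUM_SEEDS, ENVS_PER_SEED, OBS_DIM))
actions, next_obs = step_all_seeds(obs)
```

The point of the sketch is that the pool step never sits inside vmap: the per-seed computations are batched with vmap, flattened for a single pool call, and then reshaped so that each instance only ever sees its own slice of environments.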