sail-sg / envpool

C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments.
https://envpool.readthedocs.io
Apache License 2.0

[Feature Request] XLA interface speed comparison #213

Open 51616 opened 1 year ago

51616 commented 1 year ago

Motivation

If I understand correctly, envpool's speedup comes from its C++ implementation as opposed to Python. So I wonder whether the XLA interface provides any further speedup when jitted by JAX, since the environment step could then be fused with the NN policy using jax.lax.scan. It would be nice to have a benchmark for the XLA version of the environments. I suppose different hardware would yield different results, which would be useful for deciding when to jit the environment and when not to.

Solution

A benchmark comparing the non-jitted XLA interface, the jitted XLA interface, and the default C++ implementation.
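To make the request concrete, here is a minimal sketch of the kind of timing harness such a comparison could use. It is not envpool code: `measure_fps`, `toy_rollout`, and all parameter names are hypothetical, and the toy rollout is only a stand-in for whichever variant is being measured. The warm-up call matters for the jitted case, since otherwise compilation time would be counted as step time.

```python
import time
from typing import Any, Callable, Optional


def measure_fps(
    run_steps: Callable[[int], Any],
    num_steps: int,
    num_envs: int,
    sync: Optional[Callable[[Any], Any]] = None,
    warmup: int = 1,
    repeats: int = 3,
) -> float:
    """Return the best observed environment steps per second.

    run_steps(num_steps) should advance all num_envs environments by
    num_steps steps. sync, if given, is called on the result to wait for
    asynchronous work to finish before the clock is stopped.
    """
    # Warm-up runs are not timed (this excludes one-off costs such as JIT compilation).
    for _ in range(warmup):
        out = run_steps(num_steps)
        if sync is not None:
            sync(out)

    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        out = run_steps(num_steps)
        if sync is not None:
            sync(out)
        best = min(best, time.perf_counter() - start)

    # Guard against a zero reading from a coarse timer.
    return num_steps * num_envs / max(best, 1e-12)


def toy_rollout(num_steps: int) -> int:
    """Hypothetical stand-in for a rollout; replace with the variant under test."""
    state = 0
    for _ in range(num_steps):
        state = (state * 31 + 7) % 1009
    return state


fps = measure_fps(toy_rollout, num_steps=1000, num_envs=8)
print(f"toy rollout: {fps:.0f} steps/s")
```

For a jitted JAX rollout, `sync` would probably need to block on the returned arrays (for example with `jax.block_until_ready`), because JAX dispatches work asynchronously and the timer would otherwise stop too early. Running the same harness over each of the three variants with identical `num_steps` and `num_envs` would give directly comparable numbers.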


vwxyzjn commented 1 year ago

Hey @51616, we have a working version of PPO that leverages the XLA interface (see docs here and tweet here). It is about 3x faster than openai/baselines' PPO when tested across 57 Atari games using the default Atari parameters (e.g., num_envs=8).

I also have a preliminary profile in a wandb report here, which shows that combining the XLA interface with inference improves speed even further.
