Opened by @51616, 1 year ago
Hey @51616, we have a working version of PPO that leverages the XLA interface (see docs here and tweet here). It is about 3x faster than openai/baselines' PPO when tested at the scale of 57 Atari games using the default Atari parameters (e.g., `num_envs=8`).
I have a preliminary profile in a wandb report here, which shows that the XLA interface combined with jitted inference improves speed even further.
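For context on what "fusing the policy with the environment" means here, the sketch below jits a whole rollout with `jax.lax.scan`. The step function and linear policy are toy stand-ins, not envpool's API; all names and sizes are illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp


def env_step(state, action):
    # Dummy pure-function environment step; a stand-in for a jittable env
    # (illustrative only, not envpool's interface).
    new_state = state * 0.99 + 0.01 * action
    reward = -(new_state ** 2).sum(axis=-1)
    return new_state, reward


def policy(params, obs):
    # Toy linear policy; a real agent would use a neural network here.
    return jnp.tanh(obs @ params)


@partial(jax.jit, static_argnums=2)
def fused_rollout(params, state, horizon):
    # jax.lax.scan fuses policy inference and environment stepping into a
    # single XLA program, avoiding a Python round-trip on every step.
    def body(s, _):
        a = policy(params, s)
        s, r = env_step(s, a)
        return s, r

    final_state, rewards = jax.lax.scan(body, state, None, length=horizon)
    return final_state, rewards


num_envs, obs_dim, horizon = 8, 4, 128
params = jnp.eye(obs_dim) * 0.1
state = jnp.ones((num_envs, obs_dim))
final_state, rewards = fused_rollout(params, state, horizon)
print(rewards.shape)  # (128, 8)
```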
Motivation
If I understand correctly, the speedup of envpool comes from its C++ implementation as opposed to Python. So I wonder whether the XLA interface will provide any more speedup when jitted by `jax`, which can then be further fused with the NN policy using `jax.lax.scan`. It would be nice to have a benchmark for the XLA version of the environments. I suppose different hardware would yield different results, which is useful for deciding when to jit the environment and when not to.

Solution
Benchmark of non-jitted XLA vs. jitted XLA vs. the default C++ implementation
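As a starting point for the jitted vs. non-jitted comparison, here is a minimal timing sketch. A toy pure-function step stands in for the real environment, so the step function, sizes, and harness are all illustrative, not envpool's API; the jitted path fuses all steps into one XLA call via `jax.lax.scan`, while the eager path dispatches once per step.

```python
import time

import jax
import jax.numpy as jnp


def env_step(state, action):
    # Toy pure-function step standing in for an XLA-compatible env step.
    new_state = state + 0.01 * action
    reward = -jnp.abs(new_state).sum(axis=-1)
    return new_state, reward


def rollout_eager(state, actions):
    # Non-jitted: one dispatch per step, mimicking the eager interface.
    rewards = []
    for a in actions:
        state, r = env_step(state, a)
        rewards.append(r)
    return state, jnp.stack(rewards)


@jax.jit
def rollout_jitted(state, actions):
    # Jitted: lax.scan compiles the whole rollout into one XLA program.
    def body(s, a):
        s, r = env_step(s, a)
        return s, r

    return jax.lax.scan(body, state, actions)


num_envs, horizon = 8, 256
state = jnp.zeros((num_envs, 4))
actions = jnp.ones((horizon, num_envs, 4))

# Warm up compilation before timing, so compile time is excluded.
rollout_jitted(state, actions)[1].block_until_ready()

t0 = time.perf_counter()
_, r_eager = rollout_eager(state, actions)
r_eager.block_until_ready()
t_eager = time.perf_counter() - t0

t0 = time.perf_counter()
_, r_jit = rollout_jitted(state, actions)
r_jit.block_until_ready()
t_jit = time.perf_counter() - t0

print(f"eager: {t_eager:.4f}s, jitted: {t_jit:.4f}s")
```

Timings will differ across CPU/GPU/TPU, which is exactly the hardware-dependence the benchmark above is meant to expose.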
Checklist