araffin / sbx

SBX: Stable Baselines Jax (SB3 + Jax)
MIT License

Mujoco XLA - MJX Integration #23

Closed matinmoezzi closed 6 months ago

matinmoezzi commented 9 months ago

As the biggest bottleneck of the training performance of SB3 is the environment, I am considering integrating SB3 with MuJoCo XLA (MJX), which is MuJoCo written in JAX. Would this integration increase the performance? MuJoCo XLA was recently released and shows a large performance improvement with Brax, which includes RL algorithms written in JAX. Is SBX fully written in JAX?

araffin commented 9 months ago

Hello,

> As the biggest bottleneck of the training performance of SB3 is the environment

I would actually disagree with this statement. The main reason SBX is much faster than SB3 (PyTorch) is that the bottleneck was the gradient update.

> Would this integration increase the performance?

It might, but first you need to be sure where the bottleneck is and that you have optimized the parameters of SBX before considering a faster env.
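One rough way to check where the time goes is to accumulate wall-clock time per phase of the training loop. The sketch below is only an illustration using the standard library: `PhaseTimer`, `fake_env_step` and `fake_gradient_update` are hypothetical names, not part of SBX or SB3, and the sleeps stand in for a real `env.step()` call and a real gradient update.

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class PhaseTimer:
    """Accumulate wall-clock time per named phase of a training loop."""

    def __init__(self):
        self.totals = defaultdict(float)

    @contextmanager
    def phase(self, name):
        start = time.perf_counter()
        try:
            yield
        finally:
            # Add the elapsed time even if the phase raised an exception.
            self.totals[name] += time.perf_counter() - start

    def fractions(self):
        """Return each phase's share of the total measured time."""
        total = sum(self.totals.values())
        if total <= 0:
            return {}
        return {name: t / total for name, t in self.totals.items()}


def fake_env_step():
    # Hypothetical stand-in for env.step(action).
    time.sleep(0.001)


def fake_gradient_update():
    # Hypothetical stand-in for one gradient update.
    # With JAX, computations are dispatched asynchronously, so the result
    # should be blocked on (e.g. block_until_ready()) before the timer stops.
    time.sleep(0.003)


timer = PhaseTimer()
for _ in range(20):
    with timer.phase("env_step"):
        fake_env_step()
    with timer.phase("gradient_update"):
        fake_gradient_update()

shares = timer.fractions()
for name, share in shares.items():
    print(f"{name}: {share:.0%} of measured time")
```

If the environment phase only accounts for a small share of the measured time, a faster simulator is unlikely to change the overall training speed much.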

> Is SBX fully written in JAX?

It is not: it still uses NumPy/PyTorch for the rollout/replay buffer. Only the gradient updates are written in JAX.