Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
https://sb3-contrib.readthedocs.io
MIT License

MPIVecEnv #45

Open sgillen opened 2 years ago

sgillen commented 2 years ago

Hello,

I was trying to find a way to make the ARS implementation I was working on in this PR faster. My first thought was a drop-in replacement for SubprocVecEnv that uses mpi4py instead. I implemented a first pass here. It is quick and dirty, but it is still a working proof of concept for checking whether there is any performance to be gained.

I am seeing modest speedups in rollout collection time. For Pendulum-v0 with 10 environments, it is 4-5x faster than both DummyVecEnv and SubprocVecEnv. For HumanoidBulletEnv-v0 with 10 environments, it is 8x faster than DummyVecEnv and 2x faster than SubprocVecEnv. It might be possible to squeeze more performance out of it, but this is probably 80% of what can be achieved with this approach.

These numbers are for rollout collection only. Any end-to-end speedup for an algorithm using this VecEnv will be smaller, but for on-policy algorithms it is probably still significant.
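The gap between rollout-only and end-to-end speedups follows Amdahl's law. Assuming rollout collection takes a fraction $p$ of wall-clock training time and is sped up by a factor $s$, with everything else (e.g. gradient updates) unchanged, the overall speedup is:

```latex
S_{\text{overall}} = \frac{1}{(1 - p) + \dfrac{p}{s}},
\qquad
\lim_{s \to \infty} S_{\text{overall}} = \frac{1}{1 - p}
```

For example, with a hypothetical (not measured) $p = 0.5$ and the 2x gain over SubprocVecEnv reported above, $S_{\text{overall}} = 1 / (0.5 + 0.25) \approx 1.33$, and no rollout speedup could push it past 2x. On-policy algorithms that spend a larger share of their time collecting rollouts have a larger $p$, which is why they benefit most.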

I wanted to ask whether this, or something like it, had been considered. IIRC mpi4py was a big headache to support, but perhaps confining that dependency to contrib/ would make most of the headache disappear. I could also look at torch.distributed, for example, but I think it would cause a similar number of headaches for most likely less speed.

This is another thing I would be interested in contributing (over the following weeks ...), but again, only if there is interest.

araffin commented 2 years ago

Hello, it looks interesting. Do you have a minimal code example of how it works? Is the synchronization done after each step?

In fact, I plan to add a complete MPI version of PPO to SB3 contrib at some point (#11 ), where each core also computes gradients.

sgillen commented 2 years ago

Yes, the examples I made for testing live here. The synchronization happens at every step: each call to Gather/gather is blocking.
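The per-step, blocking synchronization described above can be sketched with mpi4py's pickle-based collectives (`bcast`, `scatter`, `gather`). This is a hypothetical reconstruction, not the linked proof of concept: `LockstepMPIVecEnv`, the command strings, and the env interface (`reset() -> obs`, `step(action) -> (obs, reward, done)`) are assumptions. `LocalComm` is a single-process stand-in exposing the same method names, so the sketch runs without MPI, and `FixedHorizonEnv` is a toy env for illustration.

```python
from typing import Any, Callable, List, Optional


class LocalComm:
    """Single-process stand-in for the subset of mpi4py's Comm used below."""

    def Get_rank(self) -> int:
        return 0

    def Get_size(self) -> int:
        return 1

    def bcast(self, obj: Any, root: int = 0) -> Any:
        return obj

    def scatter(self, sendobj: Optional[List[Any]], root: int = 0) -> Any:
        return sendobj[0]

    def gather(self, sendobj: Any, root: int = 0) -> List[Any]:
        return [sendobj]


class LockstepMPIVecEnv:
    """Hypothetical sketch: one environment per MPI rank, stepped in lockstep."""

    def __init__(self, comm: Any, env_fn: Callable[[int], Any]) -> None:
        self.comm = comm
        self.rank = comm.Get_rank()
        self.num_envs = comm.Get_size()
        self.env = env_fn(self.rank)  # env_fn can seed each env by rank

    def _collective(self, cmd: Optional[str], payloads: Optional[List[Any]]):
        # All ranks enter these collectives in the same order; each one
        # blocks until every rank has reached it.
        cmd = self.comm.bcast(cmd, root=0)
        payload = self.comm.scatter(payloads, root=0)
        if cmd == "reset":
            result = self.env.reset()
        elif cmd == "step":
            obs, reward, done = self.env.step(payload)
            if done:
                # Auto-reset on episode end (the terminal obs is dropped here).
                obs = self.env.reset()
            result = (obs, reward, done)
        else:  # "close"
            result = None
        # gather returns a list on rank 0 and None on the other ranks.
        return cmd, self.comm.gather(result, root=0)

    # Rank 0 (the learner) calls reset/step/close.
    def reset(self) -> List[Any]:
        return self._collective("reset", [None] * self.num_envs)[1]

    def step(self, actions: List[Any]) -> List[Any]:
        return self._collective("step", list(actions))[1]

    def close(self) -> None:
        self._collective("close", [None] * self.num_envs)

    # Every other rank serves commands until "close" arrives.
    def serve(self) -> None:
        while True:
            cmd, _ = self._collective(None, None)
            if cmd == "close":
                break


class FixedHorizonEnv:
    """Hypothetical toy env: an episode ends after `horizon` steps."""

    def __init__(self, horizon: int = 3) -> None:
        self.horizon = horizon
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: float):
        self.t += 1
        return self.t, float(action), self.t >= self.horizon
```

Under `mpiexec -n N python script.py`, every rank would construct `LockstepMPIVecEnv(MPI.COMM_WORLD, make_env)`; rank 0 then calls `reset`/`step`/`close` from the training loop while every other rank calls `serve()`. Because each collective blocks until all ranks arrive, one slow environment stalls the whole step, the same lockstep behaviour as SubprocVecEnv.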

I figure the proposed VecEnv could be used by any existing SB3 algorithm for a modest speedup, and maybe by future "MPI-aware" algorithms for bigger gains.

araffin commented 2 years ago

> Yes, the examples I made for testing live here.

Thanks =) but I'm afraid this won't work with the current SB3 implementation...

sgillen commented 2 years ago

OK, but I don't see why that is?