luchris429 / purejaxrl

Really Fast End-to-End Jax RL Implementations
Apache License 2.0
738 stars 62 forks

Multi GPU support #9

Open Idate96 opened 1 year ago

Idate96 commented 1 year ago

I was wondering if there are any plans to release multi-GPU training code? Naively pmapping and using DDPPO does not seem to scale well, as the GPUs remain idle while syncing the gradients.

luchris429 commented 1 year ago

Ahh that's a good idea. It was not on the roadmap, but I would imagine doing something like it would not be that difficult.

Do you think it would just largely involve pmean-ing the grad updates?
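For reference, the pmean-based approach mentioned above can be sketched roughly as below. This is a minimal toy illustration, not PureJaxRL's actual training code: the names (`update_step`, `loss_fn`) and the quadratic loss standing in for the PPO objective are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    # Toy quadratic loss standing in for the PPO loss (hypothetical).
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)


@partial(jax.pmap, axis_name="devices")
def update_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average gradients across all devices so every replica
    # applies the same update (data-parallel sync step).
    grads = jax.lax.pmean(grads, axis_name="devices")
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-2 * g, params, grads)
    return params, loss


n_dev = jax.local_device_count()
# Replicate the params across devices and shard the batch along
# a leading device axis, as pmap expects.
params = {"w": jnp.zeros((3, 1))}
params = jax.tree_util.tree_map(lambda x: jnp.stack([x] * n_dev), params)
batch = {
    "x": jnp.ones((n_dev, 8, 3)),
    "y": jnp.ones((n_dev, 8, 1)),
}
params, loss = update_step(params, batch)
```

The same pattern should drop into the existing PPO update: wrap the update function in `jax.pmap` with an `axis_name`, and insert one `jax.lax.pmean` on the gradients (and on any batch statistics) before the optimizer step.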

ugurbolat commented 1 year ago

also interested in this. +1

Howuhh commented 11 months ago

We adapted the PureJaxRL PPO+RNN implementation to multi-GPU with pmap in XLand-MiniGrid, and it scales well (almost linearly from 1 up to 8 A100 GPUs)!

luchris429 commented 11 months ago

Awesome! I took a quick look -- I see that env steps per second scale linearly; however, do you know how performance scales with time?

Howuhh commented 11 months ago

@luchris429 It just takes a bit longer to compile in general (if I understood correctly that by "time" you mean the total number of timesteps). I didn't notice any other performance dips for the 10-minute and ~8-hour runs. GPU utilization stays at 100%, and OOMs do not happen.