Open Idate96 opened 1 year ago

I was wondering if there are any plans to release multi-GPU training code? Naively pmapping and using DDPPO does not seem to scale well, as the GPUs remain idle while syncing the gradients.
Ahh that's a good idea. It was not on the roadmap, but I would imagine doing something like it would not be that difficult. Do you think it would just largely involve pmean-ing the grad updates?
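Something like this, I'd guess -- a minimal sketch where a toy quadratic loss and plain SGD stand in for the actual PPO loss and optax update (the model, names, and shapes here are made up for illustration, not PureJaxRL code):

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Toy quadratic loss standing in for the PPO loss.
    x, y = batch
    pred = x @ params["w"]
    return jnp.mean((pred - y) ** 2)

def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # The one multi-GPU-specific line: average grads across devices.
    grads = jax.lax.pmean(grads, axis_name="devices")
    # Plain SGD for the sketch; real code would apply the optax optimizer.
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, loss

p_train_step = jax.pmap(train_step, axis_name="devices")

n_dev = jax.local_device_count()
# Replicate params across devices and shard the batch along a leading
# device axis; pmap maps over that axis.
params = {"w": jnp.ones((n_dev, 4, 1))}
batch = (jnp.ones((n_dev, 32, 4)), jnp.zeros((n_dev, 32, 1)))
params, loss = p_train_step(params, batch)
```

Since the rest of the step is element-wise per device, the pmean should be the only cross-device communication per update.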
also interested in this. +1
We adapted the PureJaxRL PPO+RNN implementation to multi-GPU with pmap in XLand-MiniGrid, and it scales well (almost linearly from 1 up to 8 A100 GPUs)!
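Roughly, the wiring is: pmap the whole train function and give each device its own RNG key, so every GPU rolls out its own batch of envs and only the gradient pmean communicates across devices. A schematic sketch (the trivial train body is a placeholder, not our actual code):

```python
import jax
import jax.numpy as jnp

def make_train():
    # Placeholder for a PureJaxRL-style make_train; the real body is the
    # full PPO+RNN loop with lax.pmean on the grads, as in the sketch above.
    def train(rng):
        # Each device rolls out its own envs from its own key.
        return jnp.mean(jax.random.normal(rng, (1024,)))
    return train

rng = jax.random.PRNGKey(0)
rngs = jax.random.split(rng, jax.local_device_count())  # one key per GPU
train_fn = jax.pmap(make_train(), axis_name="devices")
out = jax.block_until_ready(train_fn(rngs))  # shape (num_devices,)
```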
Awesome! I took a quick look -- I see that env steps per second scale linearly with the number of GPUs; however, do you know how performance scales with time?
@luchris429 It just takes a bit longer to compile in general (if I understood correctly that by time you mean the total number of timesteps). I didn't notice any other performance dips on the 10-minute and ~8-hour runs. GPU utilization stays at 100%, and OOMs do not happen.