mrkulk opened this issue 3 years ago
Hey @mrkulk, I had started working on a DDP Sharded + Ray integration in this PR https://github.com/ray-project/ray_lightning/pull/16, but that work is now outdated since PyTorch Lightning made some major changes to distributed accelerators/plugins in 1.2. With the latest PyTorch Lightning release, I think we would have to create a RayShardedPlugin that subclasses RayPlugin, similar to this: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/training_type/sharded_spawn.py#L28.
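A rough, untested sketch of what that subclass could look like, mirroring the fairscale wrapping in PL's DDPSpawnShardedPlugin linked above and assuming RayPlugin still exposes the usual DDPSpawnPlugin attributes (`self.model`, `self._model`, `self.lightning_module`); exact imports may need adjusting for the PL version:

```python
# Untested sketch: RayShardedPlugin mirrors PL's DDPSpawnShardedPlugin but
# builds on RayPlugin, so the training processes are Ray actors.
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
from fairscale.optim import OSS
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel

from ray_lightning import RayPlugin


class RayShardedPlugin(RayPlugin):
    def configure_ddp(self) -> None:
        # Shard the optimizer state first, then wrap the model so gradients
        # are reduced to the rank that owns the corresponding optimizer shard.
        self._wrap_optimizers()
        self._model = ShardedDataParallel(
            LightningShardedDataParallel(self.model),
            sharded_optimizer=self.lightning_module.trainer.optimizers,
        )

    def _wrap_optimizers(self) -> None:
        trainer = self.lightning_module.trainer
        if trainer.testing:
            return
        wrapped = []
        for optimizer in trainer.optimizers:
            if isinstance(optimizer, LightningOptimizer):
                optimizer = optimizer._optimizer
            if not isinstance(optimizer, OSS):
                # Re-create the optimizer as fairscale's OSS so each rank
                # only keeps its own shard of the optimizer state.
                optimizer = OSS(params=optimizer.param_groups,
                                optim=type(optimizer),
                                **optimizer.defaults)
            wrapped.append(optimizer)
        trainer.optimizers = wrapped
```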
If you are interested, please feel free to take a stab at a PR.
Also, do you mind sharing some more about your use case? Are your models too big to fit on a single GPU? What are your model size and GPU memory? Thanks.
Re: ddp_sharded -- great to know, I will check it out.
Regarding the use case -- yes, the models are too big to fit on a single GPU, and we also want larger batch sizes. We typically use V100s or A100s, but that may change in the future. The model size is growing every month as we make progress. Since DeepSpeed can handle billions of parameters (with stage 3 optimization), it would be great to have that, or to work on something along the same lines that eventually goes beyond it.
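For context, this is roughly how stage-3 DeepSpeed is enabled in plain PyTorch Lightning (no Ray) today; this assumes a PL version whose DeepSpeedPlugin supports stage 3 (~1.3+), and it is the kind of setup we would love to have behind ray_lightning:

```python
# Plain PyTorch Lightning with DeepSpeed ZeRO stage 3: optimizer state,
# gradients, and parameters are all sharded across the GPUs.
import pytorch_lightning as pl
from pytorch_lightning.plugins import DeepSpeedPlugin

trainer = pl.Trainer(
    gpus=4,
    precision=16,                      # DeepSpeed ZeRO expects mixed precision
    plugins=DeepSpeedPlugin(stage=3),  # stage 3 = full parameter sharding
)
```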
Fairscale integration has now been merged (#42). Check it out for low-memory distributed training!
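A minimal usage sketch (hedged: this assumes the merged plugin is exported as RayShardedPlugin and takes the same constructor arguments as RayPlugin):

```python
# Usage sketch for the merged fairscale integration; plugin name and
# constructor arguments (num_workers, use_gpu) are assumed from RayPlugin.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
import ray
from ray_lightning import RayShardedPlugin


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, = batch
        out = self.layer(x)
        return torch.nn.functional.mse_loss(out, torch.zeros_like(out))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    ray.init()  # connect to (or start) the Ray cluster
    train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    trainer = pl.Trainer(
        max_epochs=1,
        plugins=[RayShardedPlugin(num_workers=2, use_gpu=True)],
    )
    trainer.fit(LitModel(), train_loader)
```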
Is ray_lightning currently compatible with the deepspeed accelerator in PTL?