Open · A9isha opened this issue 1 year ago
Requesting PyTorch/XLA to support `torch.distributed.scatter`

🚀 Feature

Requesting PyTorch/XLA to support `torch.distributed.scatter`.

Motivation

PyTorch/XLA today does not support `torch.distributed.scatter`. trlx (for RLHF) uses it here: https://github.com/CarperAI/trlx/blob/main/trlx/trainer/accelerate_ppo_trainer.py#L325
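For context, here is a minimal sketch of the call pattern being requested. It shows standard `torch.distributed.scatter` semantics on a backend that already implements it (gloo here); the `worker` helper and the two-process setup are illustrative, not trlx's actual code. Under the `torch_xla` backend, the same `dist.scatter` call is what currently fails.

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    output = torch.empty(2)
    # Only the source rank supplies the per-rank input chunks;
    # every other rank passes None.
    scatter_list = (
        [torch.full((2,), float(r)) for r in range(world_size)]
        if rank == 0 else None
    )
    dist.scatter(output, scatter_list, src=0)
    # rank r now holds a tensor filled with the value r
    print(f"rank {rank} got {output}")

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)
```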
Thanks for the request. Could you also share your current workaround?
Thanks Jack for the response. My current workaround is to use `reduce_scatter` instead:
https://github.com/A9isha/trlx/blob/anisha-test-trlx/trlx/trainer/accelerate_ppo_trainer.py#L348-L356
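The linked diff shows trlx's actual code; as a general illustration of the same idea, a scatter can be emulated with `reduce_scatter` by having every non-source rank contribute zeros, so a SUM reduction delivers exactly the source rank's chunks. The `scatter_via_reduce_scatter` helper below is a hypothetical sketch of that technique, not the code from the branch above.

```python
import torch
import torch.distributed as dist

def scatter_via_reduce_scatter(output: torch.Tensor,
                               scatter_list,  # list of tensors on src, None elsewhere
                               src: int = 0):
    """Emulate dist.scatter(output, scatter_list, src) using reduce_scatter."""
    world_size = dist.get_world_size()
    if dist.get_rank() == src:
        # The source rank contributes the real per-rank chunks.
        input_list = [t.contiguous() for t in scatter_list]
    else:
        # Zero contributions leave the source's values untouched under SUM,
        # so each rank receives the source's chunk for its position.
        input_list = [torch.zeros_like(output) for _ in range(world_size)]
    dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)
```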