pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Support for torch.distributed.scatter in PyTorch XLA #4940

Open A9isha opened 1 year ago

A9isha commented 1 year ago

🚀 Feature

Requesting that PyTorch/XLA support torch.distributed.scatter.

Motivation

PyTorch/XLA does not currently support torch.distributed.scatter. trlx (used for RLHF) calls it here: https://github.com/CarperAI/trlx/blob/main/trlx/trainer/accelerate_ppo_trainer.py#L325
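
For context, this is roughly the pattern involved (a minimal, illustrative sketch with made-up names, not the trlx code itself; it assumes the process group is already initialized). The scatter call below is what is unsupported on the xla backend today:

```python
import torch
import torch.distributed as dist

def scatter_chunks(chunks, chunk_shape, device, src=0):
    # Rank `src` provides `chunks` (one tensor per rank); every rank,
    # including `src`, receives its own chunk into `output`.
    output = torch.empty(chunk_shape, device=device)
    if dist.get_rank() == src:
        dist.scatter(output, scatter_list=[c.to(device) for c in chunks], src=src)
    else:
        # Non-source ranks pass no scatter_list and just receive.
        dist.scatter(output, scatter_list=None, src=src)
    return output
```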

JackCaoG commented 1 year ago

Thanks for the request. Could you also share your current workaround?

A9isha commented 1 year ago

Thanks, Jack, for the response. My current workaround is to use reduce_scatter instead.

https://github.com/A9isha/trlx/blob/anisha-test-trlx/trlx/trainer/accelerate_ppo_trainer.py#L348-L356
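
For anyone else hitting this, here is a minimal sketch of one way a scatter can be emulated with reduce_scatter (illustrative only; the function name and arguments are made up and it is not necessarily what the linked branch does). It assumes every rank knows the chunk shape in advance, so non-source ranks contribute zero tensors and a SUM reduce_scatter delivers each rank only the source's data:

```python
import torch
import torch.distributed as dist

def scatter_via_reduce_scatter(chunk_shape, scatter_list, device, src=0):
    # Emulate dist.scatter with dist.reduce_scatter: only `src` holds real
    # data; every other rank contributes zeros, so the SUM reduction gives
    # each rank exactly the chunk that `src` intended for it.
    world_size = dist.get_world_size()
    if dist.get_rank() == src:
        inputs = [t.to(device) for t in scatter_list]
    else:
        # Zero contributions from non-source ranks leave the sum untouched.
        inputs = [torch.zeros(chunk_shape, device=device) for _ in range(world_size)]
    output = torch.empty(chunk_shape, device=device)
    dist.reduce_scatter(output, inputs, op=dist.ReduceOp.SUM)
    return output
```

The trade-off is that every non-source rank has to materialize a full set of zero chunks, so this costs more memory and bandwidth than a native scatter would.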