pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Allow sharding a tensor on an arbitrary dim instead of 0 for FSDP-2 #137342

Open mayank31398 opened 1 week ago

mayank31398 commented 1 week ago

🚀 The feature, motivation and pitch

I created a tensor of shape (2, N) in a module and wrapped it with FSDP-2. Across 8 GPUs the resulting shard shapes are: GPUs 0 and 1 -> (1, N); GPUs 2 to 7 -> (0, N).

This is a big limitation when a tensor's dim=0 is not large enough, i.e. smaller than the number of ranks, so some ranks end up with empty shards.
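For illustration only (this mirrors the arithmetic of dim-0 chunking, not the FSDP-2 code path), the per-rank shard shapes described above can be reproduced like this:

```python
import torch

# Chunking a (2, N) parameter on dim 0 across 8 ranks leaves 6 ranks with
# empty shards; N = 4096 is a placeholder size for the second dim.
N = 4096
w = torch.empty(2, N)
shards = torch.tensor_split(w, 8, dim=0)
print([tuple(s.shape) for s in shards])
# [(1, 4096), (1, 4096), (0, 4096), (0, 4096), (0, 4096), (0, 4096), (0, 4096), (0, 4096)]
```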

Is it possible to modify FSDP-2 to choose the shard dim like this:

- if dim=0 is big enough: use Shard(0)
- else: iterate over the dims to find the largest one and use Shard(largest_dim)
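A minimal sketch of that selection logic, for illustration only: `pick_shard_dim` is a hypothetical helper, and wiring its result into FSDP-2 would depend on whatever hook the library exposes (e.g. the work linked in the comment below).

```python
import torch
# Shard placement type from DTensor (older PyTorch versions expose it under
# torch.distributed._tensor instead).
from torch.distributed.tensor import Shard


def pick_shard_dim(param: torch.Tensor, world_size: int) -> Shard:
    # Prefer dim-0 sharding when dim 0 is large enough to give every rank a
    # non-empty shard.
    if param.size(0) >= world_size:
        return Shard(0)
    # Otherwise fall back to the largest dim.
    largest_dim = max(range(param.dim()), key=param.size)
    return Shard(largest_dim)


# Example: a (2, N) parameter on 8 ranks would be sharded on dim=1.
p = torch.empty(2, 4096)
print(pick_shard_dim(p, world_size=8))  # e.g. Shard(dim=1)
```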

Alternatives

No response

Additional context

No response

cc @zhaojuanmao @mrshenli @rohan-varma @awgu @fegin @kwen2501 @chauhang

awgu commented 1 week ago

Can you give an example of this use case for us to understand the motivation better?

Anyway, we are working on this in https://github.com/pytorch/pytorch/pull/136221, but it will take some time to mature.

  1. We need efficient copy kernels to support non-dim-0 sharding.
  2. We need additional copies compared to dim-0 sharding because we cannot pre-pad the sharded parameter and must un-pad the sharded gradient so that they have contiguous strides (as required by some kernels, e.g. Apex fused Adam); see the sketch after this list.
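A toy illustration of point 2 (just a sketch of the strides involved, not FSDP-2 internals): a dim-0 shard of a row-major parameter is a contiguous row block, while a non-dim-0 shard is a strided view that needs an extra copy to become contiguous.

```python
import torch

# Contiguity of dim-0 vs. non-dim-0 shards of a row-major (2, 6) parameter.
w = torch.arange(2 * 6, dtype=torch.float32).reshape(2, 6)

# Dim-0 shard: a row block is already contiguous, so a pre-padded shard can be
# stored and consumed directly.
dim0_shard = w[0:1]                 # rank 0's (1, 6) shard under Shard(0)
print(dim0_shard.is_contiguous())   # True

# Dim-1 shard: a column block is a strided view; it (and its gradient) must be
# copied to obtain contiguous strides, as required by kernels like Apex fused Adam.
dim1_shard = w[:, 0:3]              # rank 0's (2, 3) shard under Shard(1)
print(dim1_shard.is_contiguous())   # False
```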

Concretely, dim-0 sharding will be the most performant (lowest extra overhead), but we will try our best to reduce the extra overheads of sharding on a different dim.