Open · mayank31398 opened 1 week ago
Can you give an example of this use case for us to understand the motivation better?
Anyway, we are working on this https://github.com/pytorch/pytorch/pull/136221/, but it will take some time to mature.
Concretely, dim-0 sharding will remain the most performant option (lowest extra overhead), but we will try our best to reduce the overhead of sharding on a different dim.
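For reference, here is a minimal DTensor sketch (not FSDP2 itself) of what dim-0 vs. dim-1 sharding of a (2, N) tensor looks like across 8 ranks; the mesh size and N = 4096 are made-up values for illustration:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

# Assumes launch via torchrun with 8 processes, one GPU each.
mesh = init_device_mesh("cuda", (8,))
w = torch.randn(2, 4096)

# Dim-0 sharding: ranks 0-1 get (1, 4096); ranks 2-7 get empty (0, 4096) shards.
dt0 = distribute_tensor(w, mesh, [Shard(0)])
# Dim-1 sharding: every rank gets an even (2, 512) shard.
dt1 = distribute_tensor(w, mesh, [Shard(1)])

print(dt0.to_local().shape, dt1.to_local().shape)
```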
🚀 The feature, motivation and pitch
I created a tensor of shape (2, N) in a module and wrapped it with FSDP2 on 8 GPUs. The resulting per-rank shard shapes are:

- GPUs 0 and 1 -> (1, N)
- GPUs 2 to 7 -> (0, N)
This is a big limitation when a tensor's dim-0 size is smaller than the world size: every rank beyond dim 0's extent receives an empty shard.
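For concreteness, the shapes above follow from torch.chunk-style splitting along dim 0; a tiny sketch reproducing them (the world size is taken from the example, and the chunking rule is assumed to match FSDP2's Shard(0)):

```python
# Per-rank shard shapes for a (2, N) tensor sharded on dim 0 across 8 ranks,
# assuming torch.chunk-style splitting.
rows, world_size = 2, 8
per_rank = -(-rows // world_size)  # ceil(2 / 8) = 1 row per non-empty shard
shapes = [(max(min(rows - r * per_rank, per_rank), 0), "N") for r in range(world_size)]
print(shapes)  # [(1, 'N'), (1, 'N'), (0, 'N'), (0, 'N'), ...] -- ranks 2-7 are empty
```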
Is it possible to modify FSDP2 to choose the shard placement like this?

```python
if tensor.size(0) >= world_size:
    placement = Shard(0)
else:
    # iterate over dims and fall back to the largest one
    placement = Shard(max(range(tensor.ndim), key=tensor.size))
```
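A user-level sketch of that policy, assuming the per-parameter placement hook that the linked PR appears to add to `fully_shard` (the `shard_placement_fn` argument, the `pick_shard_dim` helper, and `MyTinyModule` are assumptions for illustration, not a confirmed API):

```python
import torch
import torch.nn as nn
from torch.distributed.tensor import Shard
from torch.distributed.fsdp import fully_shard  # FSDP2 entry point

def pick_shard_dim(param: nn.Parameter) -> Shard:
    # Prefer dim 0; fall back to the largest dim when dim 0 is too small.
    world_size = torch.distributed.get_world_size()
    if param.size(0) >= world_size:
        return Shard(0)
    return Shard(max(range(param.ndim), key=param.size))

model = MyTinyModule()  # hypothetical module containing the (2, N) tensor
fully_shard(model, shard_placement_fn=pick_shard_dim)
```

As the comment above notes, though, any non-dim-0 placement may carry extra overhead, so this fallback would be a trade-off rather than a free win.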
Alternatives
No response
Additional context
No response
cc @zhaojuanmao @mrshenli @rohan-varma @awgu @fegin @kwen2501 @chauhang