kaushikb11 closed this issue 2 years ago.
Are these two implementations synced? Is there any benefit due to the speed of development we can see by leveraging the FairScale implementation? cc @ananthsub who might be able to follow up the right threads.
How about adding another strategy, NativeFSDP, that comes from PT instead of replacing it? They have mentioned that fairscale will still be developed and that some of its features will be upstreamed to PT in the future:
In the near future, FairScale FSDP will stay in the FairScale repository for research projects, while
generic and widely adopted features will be upstreamed to PyTorch incrementally and hardened
accordingly.
How about adding another strategy, NativeFSDP, that comes from PT instead of replacing it? They have mentioned that fairscale will still be developed and that some of its features will be upstreamed to PT in the future.
Not sure if this is better; it may even be a bit confusing for users who just want to use FSDP. I think I'm leaning towards using the FSDP found in FairScale, which seems to be the faster-improving ground. Since all these wrappers are relatively new, it's more important to rely on the one that will iterate the fastest.
There are some benefits to adding another strategy for the torch.distributed.fsdp flavor of FSDP. After discussing with folks internally on Facebook's PyTorch team, there are good reasons to add a new strategy. We could do one of the following:
class DDPFullyShardedStrategy2(DDPFullyShardedStrategy):
    strategy_name = "ddp_fully_sharded2"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        min_num_params: int = int(1e8),
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        ...
Cons: With this approach, we'd have to keep the DDPFullyShardedStrategy2 and DDPFullyShardedStrategy implementations in sync with respect to their parameters, even though not all parameters of DDPFullyShardedStrategy2 can be used in DDPFullyShardedStrategy and vice versa.
class DDPFullyShardedStrategy2(ParallelStrategy):
    strategy_name = "ddp_fully_sharded2"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        min_num_params: int = int(1e8),
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        ...
Cons: The non-FSDP-related logic in DDPFullyShardedStrategy2 and DDPFullyShardedStrategy could fall out of sync.
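As a runnable sketch of what option 2 implies, here is a plain-Python version with hypothetical stubs standing in for Lightning's `ParallelStrategy` and the plugin types (the real classes live in pytorch_lightning and take more arguments); the class and parameter names mirror the snippet above, everything else is illustrative:

```python
from typing import Any, List, Optional


class ParallelStrategy:
    """Hypothetical stub for Lightning's ParallelStrategy base class."""

    def __init__(self, parallel_devices: Optional[List[Any]] = None) -> None:
        self.parallel_devices = parallel_devices


class DDPFullyShardedStrategy2(ParallelStrategy):
    """Option 2: a standalone strategy for torch.distributed.fsdp.

    Because it does not subclass the fairscale-based DDPFullyShardedStrategy,
    its constructor can track the torch.distributed.fsdp API freely, with no
    need to keep its parameter list in sync with the fairscale strategy.
    """

    strategy_name = "ddp_fully_sharded2"

    def __init__(
        self,
        accelerator: Optional[Any] = None,
        min_num_params: int = int(1e8),  # auto-wrap threshold, in parameters
        parallel_devices: Optional[List[Any]] = None,
        cluster_environment: Optional[Any] = None,
        checkpoint_io: Optional[Any] = None,
        precision_plugin: Optional[Any] = None,
    ) -> None:
        super().__init__(parallel_devices=parallel_devices)
        self.accelerator = accelerator
        self.min_num_params = min_num_params
        self.cluster_environment = cluster_environment
        self.checkpoint_io = checkpoint_io
        self.precision_plugin = precision_plugin


strategy = DDPFullyShardedStrategy2(min_num_params=int(5e7))
print(strategy.strategy_name)   # ddp_fully_sharded2
print(strategy.min_num_params)  # 50000000
```

The only code shared with the fairscale strategy is the common ParallelStrategy base, so only the truly generic parallel logic risks drifting, which is exactly the con noted above.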
cc @ananthsub
I'm still a bit hesitant, as we're throwing another version into the mix. A new user looking for model parallelism in Lightning is already given quite a bit to choose from (DeepSpeed, FSDP, ddp_sharded, and Bagua); including a ddp_fully_sharded2 in the mix may confuse them.
If we do come to a consensus on having a separate sharded plugin, I'd suggest naming it fsdp_native. Since FSDP still seems to be mostly an internal effort at Meta to improve and build out, I'm inclined to trust the opinion of the Meta engineers!
From @SeanNaren: "I think it's more important to rely on the one that will iterate the fastest."
Agree, TorchFSDP should be the future, and the efforts in fairscale and TorchFSDP will presumably merge eventually. Performance-wise, TorchFSDP shows significant improvement. On the feature-coverage side, I think fairscale has more, but TorchFSDP is catching up.
I think we want to have both fsdp (fairscale) and native_fsdp for now. We probably shouldn't choose one as the default: we want to encourage people to use TorchFSDP (so fairscale shouldn't be the default), but we also don't want to silently swap implementations on users (so torch FSDP shouldn't be the default either).
@sisilmehta2000, personally I prefer option 2. As you mentioned, not all params are compatible between the two FSDPs, nor are the APIs. Especially since torch FSDP is rapidly developing, I don't think we should tie the two together.
We probably shouldn't choose one as the default: we want to encourage people to use TorchFSDP (so fairscale shouldn't be the default), but we also don't want to silently swap implementations on users (so torch FSDP shouldn't be the default either).
There is no need to specify a default, as both of these are opted into by users. As long as users have an easy way to select between them, there is no conflict here.
fsdp as a short name is already taken by fairscale: changing what this short name points to would break backward compatibility, so I strongly believe we should not change it. If/when the fairscale and PT versions merge in the future, we can re-route it to the PyTorch distributed version with appropriate advance warnings for users.
native_fsdp as a short name sounds good to me, especially since native is what Lightning already uses for amp_backend to distinguish PyTorch AMP from Apex AMP. Either this or torch_fsdp is short and clarifies the backing implementation.
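The backward-compatibility point can be illustrated with a toy short-name registry. The decorator and lookup helper below are hypothetical (Lightning's actual strategy registry differs), but the sketch shows how fsdp can keep resolving to the fairscale implementation while a new name selects the native one:

```python
# Toy sketch of a strategy short-name registry (hypothetical; not
# Lightning's real StrategyRegistry). The point: "fsdp" keeps pointing
# at the fairscale implementation, while a new short name opts users
# into the torch.distributed.fsdp implementation explicitly.

STRATEGY_REGISTRY = {}


def register_strategy(name: str):
    """Class decorator that maps a short name to a strategy class."""
    def wrapper(cls):
        STRATEGY_REGISTRY[name] = cls
        return cls
    return wrapper


@register_strategy("fsdp")
class DDPFullyShardedStrategy:
    backend = "fairscale"


@register_strategy("fsdp_native")
class DDPFullyShardedNativeStrategy:
    backend = "torch.distributed.fsdp"


def resolve(name: str):
    """Look up a strategy class by its short name."""
    return STRATEGY_REGISTRY[name]


print(resolve("fsdp").backend)         # fairscale
print(resolve("fsdp_native").backend)  # torch.distributed.fsdp
```

Because each short name is opted into explicitly, neither implementation has to be "the" default and existing fsdp users see no behavior change.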
@sisilmehta2000, personally I prefer option 2. As you mentioned, not all params are compatible between the two FSDPs, nor are the APIs. Especially since torch FSDP is rapidly developing, I don't think we should tie the two together.
Major +1 to this. As much as possible, we should limit the use of inheritance for code sharing. Given the APIs are evolving simultaneously, and possibly in different directions, having wholly separate implementations reduces the risk of bugs and unwanted side effects from parent-child inheritance. It's the same motivation as https://github.com/PyTorchLightning/pytorch-lightning/issues/11863
@kaushikb11 hey Kaushik! We have a version of the native_fsdp.py plugin we've been using internally at Meta. Wondering if I can go ahead and submit a PR against this issue?
@sisilmehta2000 Absolutely! Go for it. I agree that we could introduce a new native_fsdp strategy. Feel free to ask if you face any issues. Assigning you to the issue.
I suggest using fsdp_native over native_fsdp so the FSDP strategies appear together when sorted.
Folks, I've put up the PR here: https://github.com/PyTorchLightning/pytorch-lightning/pull/12447
🚀 Feature
Motivation
Blog: https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/
We currently use fairscale's implementation.
cc @SeanNaren @awaelchli @rohitgr7