Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Use PyTorch's Fully Sharded Data Parallel API with 1.11+ #12334

Closed kaushikb11 closed 2 years ago

kaushikb11 commented 2 years ago

🚀 Feature

Motivation

Blog: https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/

We currently use FairScale's implementation. PyTorch 1.11 introduces a native fully sharded data parallel API (torch.distributed.fsdp) that Lightning should support.
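For concreteness, here is a minimal, illustrative sketch of the native API introduced in PyTorch 1.11 (not Lightning integration code; it assumes a process group has already been initialized, e.g. via torch.distributed.init_process_group):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Wrap a module so its parameters, gradients and optimizer state are sharded
# across the data-parallel ranks.
model = torch.nn.Linear(1024, 1024).cuda()
sharded_model = FSDP(model)

# The optimizer must be created after wrapping, on the flattened/sharded params.
optimizer = torch.optim.SGD(sharded_model.parameters(), lr=1e-3)
loss = sharded_model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()
optimizer.step()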



cc @SeanNaren @awaelchli @rohitgr7

SeanNaren commented 2 years ago

Are these two implementations kept in sync? Given the speed of development, is there any benefit we'd see from continuing to leverage the FairScale implementation? cc @ananthsub who might be able to follow up on the right threads.
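For context, the two implementations live at different import paths and are not drop-in replacements for each other; a rough sketch (my_module is a placeholder for any nn.Module):

# FairScale's FSDP, which Lightning's current DDPFullyShardedStrategy ("fsdp") wraps
from fairscale.nn.data_parallel import FullyShardedDataParallel as FairScaleFSDP

# The FSDP upstreamed into PyTorch, available from torch 1.11
from torch.distributed.fsdp import FullyShardedDataParallel as TorchFSDP

# Both expose a similar wrapper-style API, but their constructor arguments and
# feature sets differ, so the wrapped modules are not interchangeable.
fairscale_wrapped = FairScaleFSDP(my_module)
torch_wrapped = TorchFSDP(my_module)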

rohitgr7 commented 2 years ago

How about adding another strategy, NativeFSDP, that comes from PyTorch, instead of replacing the existing one? They have mentioned that FairScale will still be developed and that some of its features will be upstreamed to PyTorch in the future:

In the near future, FairScale FSDP will stay in the FairScale repository for research projects, while generic and widely adopted features will be upstreamed to PyTorch incrementally and hardened accordingly.

SeanNaren commented 2 years ago

How about adding another strategy, NativeFSDP, that comes from PyTorch, instead of replacing the existing one? They have mentioned that FairScale will still be developed and that some of its features will be upstreamed to PyTorch in the future.

I'm not sure this is better; it might just be confusing for users who simply want to use FSDP. I'm leaning towards the FSDP found in FairScale, which seems to be the faster-moving implementation. Since all these wrappers are relatively new, I think it's more important to rely on the one that will iterate the fastest.

sisilmehta2000 commented 2 years ago

There are some benefits to adding another strategy for the torch.distributed.fsdp flavor of FSDP. After discussing with folks internally on Facebook's PyTorch team, here are some reasons to add a new strategy:

  1. torch.distributed.fsdp doesn't currently implement all the features of the FairScale version.
  2. We don't want the torch.distributed API to feel restricted by the current FairScale integration. By keeping them separate, we can adapt to whatever recommendation the torch.distributed team provides.
  3. It's cheap/easy for us to add a new strategy class.
  4. We want people to start using the new torch.distributed.fsdp version without introducing a breaking change for those using the FairScale version.

We could do one of the following

  1. Inherit from strategies/fully_sharded.py
class DDPFullyShardedStrategy2(DDPFullyShardedStrategy):
    strategy_name = "ddp_fully_sharded2"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        min_num_params: int = 1e8,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ) -> None:
        ...

Cons: With this approach, we'd have to keep both DDPFullyShardedStrategy2 and DDPFullyShardedStrategy in sync with respect to their parameters, since not all parameters of DDPFullyShardedStrategy2 can be used in DDPFullyShardedStrategy and vice versa.

  2. Inherit from strategies/parallel_strategy.py
class DDPFullyShardedStrategy2(ParallelStrategy):
    strategy_name = "ddp_fully_sharded2"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        min_num_params: int = 1e8,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ) -> None:
        ...

Cons: The non-FSDP-related logic in DDPFullyShardedStrategy2 and DDPFullyShardedStrategy could fall out of sync.
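To make option 2 concrete, here is a hypothetical sketch (not the actual implementation) of a standalone strategy that wraps the model with the native torch.distributed.fsdp API; the class name, method name, and shortname are illustrative only, and it assumes ParallelStrategy is importable from pytorch_lightning.strategies:

import torch
from pytorch_lightning.strategies import ParallelStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class DDPFullyShardedNativeSketch(ParallelStrategy):
    # Placeholder shortname; the registered name is discussed later in this thread.
    strategy_name = "fsdp_native_sketch"

    def _setup_model(self, model: torch.nn.Module) -> FSDP:
        # Once the process group is initialized, wrapping with the native FSDP
        # shards parameters, gradients and optimizer state across ranks.
        return FSDP(model)

Keeping this class fully separate from the FairScale-backed DDPFullyShardedStrategy lets the two constructors and wrapping logic evolve independently.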

cc @ananthsub

SeanNaren commented 2 years ago

I'm still a bit hesitant, as we're throwing another version into the mix. If a new user looks for model-parallel training in Lightning, they're already given quite a bit to choose from (DeepSpeed, FSDP, ddp_sharded and Bagua). Including a ddp_fully_sharded2 in the mix may add to the confusion.

If we do reach a consensus on having a separate sharded plugin, I'd suggest naming it fsdp_native. Since FSDP still seems to be mostly an internal effort at Meta to build and improve, I'm inclined to trust the opinion of the Meta engineers!

four4fish commented 2 years ago

From @SeanNaren: "I think it's more important to rely on the one that will iterate the fastest."

Agreed, TorchFSDP should be the future, and the FairScale and TorchFSDP efforts will presumably merge eventually. Performance-wise, TorchFSDP shows significant improvements. On the feature-coverage side, I think FairScale has more, but TorchFSDP is catching up.

I think we want to have both fsdp (FairScale) and native_fsdp for now. We probably shouldn't choose one as the default: we want to encourage people to use TorchFSDP (so FairScale shouldn't be the default), but we also don't want to silently swap implementations on users (so TorchFSDP shouldn't be the default either).

@sisilmehta2000 personally I prefer option 2. As you mentioned, not all params are compatible between the two FSDP implementations, and neither are the APIs. Especially since torch FSDP is developing rapidly, I don't think we should tie the two together.

ananthsub commented 2 years ago

We probably shouldn't choose one as the default: we want to encourage people to use TorchFSDP (so FairScale shouldn't be the default), but we also don't want to silently swap implementations on users (so TorchFSDP shouldn't be the default either).

There is no need to specify a default, as both of these are opted into by users. As long as users have an easy way to select between them, there is no conflict here.

fsdp as a shortname is already taken by fairscale: changing what this shortname points to will break backward compatibility, so I strongly believe we should not change this. If/when in the future the fairscale and PT versions are merged, then we can re-route this to rely on the PyTorch distributed version with the appropriate advance warnings for users.

native_fsdp as a short name sounds good to me, especially since native is already what's used in Lightning for amp_backend to distinguish PyTorch AMP from Apex AMP. Either this or torch_fsdp is short and clarifies the backing implementation.
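As a usage sketch of how selection could look from the Trainer (the fsdp_native shortname is only a proposal at this point, and the device arguments are illustrative):

from pytorch_lightning import Trainer

# Existing FairScale-backed strategy keeps its current shortname for backward compatibility.
trainer = Trainer(accelerator="gpu", devices=8, strategy="fsdp")

# New PyTorch-native strategy under a separate shortname (proposed: "fsdp_native" or "torch_fsdp").
trainer = Trainer(accelerator="gpu", devices=8, strategy="fsdp_native")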

@sisilmehta2000 personally I prefer option 2. As you mentioned, not all params are compatible between the two FSDP implementations, and neither are the APIs. Especially since torch FSDP is developing rapidly, I don't think we should tie the two together.

Major +1 to this. As much as possible we should limit use of inheritance for code sharing. Given the APIs are simultaneously evolving, and possibly in different directions, having wholly separate implementations reduces the risk of bugs / unwanted side effects from parent-child inheritance. It's the same motivation as https://github.com/PyTorchLightning/pytorch-lightning/issues/11863

sisilmehta2000 commented 2 years ago

@kaushikb11 hey Kaushik! We have a version of the "native_fsdp.py" plugin that we've been using internally at Meta. Wondering if I can go ahead and submit a PR against this issue?

kaushikb11 commented 2 years ago

@sisilmehta2000 Absolutely! Go for it. I agree that we could introduce a new native_fsdp strategy. Feel free to ask if you face any issues. Assigning you to the issue.

carmocca commented 2 years ago

I suggest using fsdp_native over native_fsdp so they appear together when sorted.

sisilmehta2000 commented 2 years ago

Folks, I've put up the PR here: https://github.com/PyTorchLightning/pytorch-lightning/pull/12447