Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Cleanup FSDP integration to not require boilerplate logic #8722

Closed SeanNaren closed 1 year ago

SeanNaren commented 3 years ago

🚀 Feature

Motivated by debugging FSDP in a recent PR made by @carmocca, I think we should try to clean up the interface for FSDP.

Currently FSDP supports a case where we wrap layers inside the configure_sharded_model hook, under the assumption that these layers are defined outside the hook. This is probably because in most cases the model has been defined in setup or __init__.

This can be seen here: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py#L58-L74


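from typing import Any, Dict

import torch
from fairscale.nn import wrap
from tests.helpers.boring_model import BoringModel

class TestFSDPModel(BoringModel):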
    def setup(self, stage: str) -> None:
        if stage != "fit":
            # when running stages like test, validate, and predict, we will skip setting up,
            # will directly use the module itself unless we load from checkpoint
            return
        # resetting call_configure_sharded_model_hook attribute so that we could call
        # configure sharded model
        self.call_configure_sharded_model_hook = False
        # for loading full state dict, we first need to create a new unwrapped model
        # to load state dict and then wrapping
        self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

    def configure_sharded_model(self) -> None:
        for i, layer in enumerate(self.layer):
            if i % 2 == 0:
                self.layer[i] = wrap(layer)
        self.layer = wrap(self.layer)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # when loading full state dict, we first need to create a new unwrapped model
        self.setup("fit")

Also included is a lot of boilerplate logic to handle the case where a user wants to load weights back into the model: we re-create the model, load the weights, then call configure_sharded_model again.

This is a bit unclean: an internal variable (call_configure_sharded_model_hook) needs to be reset, and, more importantly, it assumes that the model state hasn't been altered, when in fact FSDP has altered it by permanently flattening the parameters.

IMO we should move towards this API:

class TestFSDPModel(BoringModel):
    def __init__(self):
        # nn.Module.__init__ must run before submodules can be assigned
        super().__init__()
        self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

    def configure_sharded_model(self) -> None:
        for i, layer in enumerate(self.layer):
            if i % 2 == 0:
                self.layer[i] = wrap(layer)
        self.layer = wrap(self.layer)

and also allow this pattern for large models that take a long time to load into memory (and are quicker to initialize one module at a time):

class TestFSDPModel(BoringModel):
    def __init__(self):
        # No layers are created here; they are built in configure_sharded_model.
        super().__init__()

    def configure_sharded_model(self) -> None:
        self.layer = torch.nn.Sequential(
            wrap(torch.nn.Linear(32, 32)),
            torch.nn.ReLU(),
            wrap(torch.nn.Linear(32, 2)),
        )

How to actually implement this?

Once the model has been set up, ideally we should never need to set it up again unless the model has changed (covered in the RFC https://github.com/PyTorchLightning/pytorch-lightning/issues/8593). This would allow the model to remain the same across stages.

Given the above, I think we'll then be able to rely on primitive state dict functions of the wrapped model via FSDP: https://fairscale.readthedocs.io/en/stable/_modules/fairscale/nn/data_parallel/fully_sharded_data_parallel.html#FullyShardedDataParallel.state_dict

cc @ananthsub

ananthsub commented 3 years ago

I agree. call_configure_sharded_model as a property is fragile. It reminds me of https://github.com/PyTorchLightning/pytorch-lightning/issues/7301

In that issue, either the user checks inside configure_sharded_model whether the model is already sharded, or the framework avoids rewrapping. @SeanNaren both of the new TestFSDPModel examples look far cleaner.
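A dependency-free sketch of the user-side option, where configure_sharded_model is written to be idempotent. `ShardedWrapper` is a hypothetical stand-in for fairscale's FullyShardedDataParallel (it does no sharding) and plain strings stand in for layers; in real code the guard would be an isinstance check against the actual FSDP class.

```python
from typing import Any, List, Union


class ShardedWrapper:
    """Hypothetical stand-in for FullyShardedDataParallel: it only records the wrapped module."""

    def __init__(self, module: Any) -> None:
        self.module = module


class Model:
    def __init__(self) -> None:
        # Layers are created eagerly; strings stand in for Linear/ReLU modules.
        self.layer: Union[List[Any], ShardedWrapper] = ["linear_1", "relu", "linear_2"]

    def configure_sharded_model(self) -> None:
        # Idempotency guard: if the outer container is already wrapped, an earlier
        # call (e.g. during fit) did the work, so return instead of re-wrapping.
        if isinstance(self.layer, ShardedWrapper):
            return
        layers = self.layer
        for i, layer in enumerate(layers):
            if i % 2 == 0:  # wrap the "Linear" layers, as in the examples above
                layers[i] = ShardedWrapper(layer)
        self.layer = ShardedWrapper(layers)


model = Model()
model.configure_sharded_model()  # e.g. called by trainer.fit
model.configure_sharded_model()  # e.g. called again by trainer.test: a no-op
```

With the guard in place, the hook can be called at any stage without tracking a separate flag such as call_configure_sharded_model_hook.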

Regarding the state dict, would the plugin now wrap the LightningModule as a whole with FSDP?

SeanNaren commented 3 years ago

@ananthsub this is a really good point. I realised that once we support https://github.com/PyTorchLightning/pytorch-lightning/issues/8593, there is no reason FSDP cannot wrap the entire module!

I'm not yet sure exactly how the logic would proceed, though; this will need some investigation!

ananthsub commented 3 years ago

@SeanNaren regarding the current test example, I think this is a specific choice made by the use case. If the guiding principle is to deprecate call_configure_sharded_model_hook, then I think the initialization structure you've provided is the natural follow-up. Regardless of the property, it's a much clearer approach.

fcampagnexandr commented 3 years ago

It seems to me there are some issues with the code snippets as written. I stumbled on this issue while looking for information about whether I should still initialize the model in the constructor or only in the hook. I think one of the models could be rewritten as:

class TestFSDPModel(BoringModel):
    def __init__(self, lazy_init: bool = False):
        super().__init__()
        if lazy_init:
            # Need to define the field in the constructor, or the hook will fail.
            self.layer = None
        else:
            # Create layers right away: not efficient for large models, but
            # convenient for testing in isolation from the Trainer.
            self.configure_sharded_model()

    def configure_sharded_model(self) -> None:
        self.layer = torch.nn.Sequential(
            wrap(torch.nn.Linear(32, 32)),
            torch.nn.ReLU(),
            wrap(torch.nn.Linear(32, 2)),
        )

tchaton commented 3 years ago

@SeanNaren Do you recommend not calling setup with FSDP?

SeanNaren commented 3 years ago

@fcampagnexandr TLDR: this works:

from typing import Dict, Any

import torch
from fairscale.nn import wrap
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel

class TestFSDPModel(BoringModel):
    def __init__(self):
        super().__init__()
        self._setup_model()

    def _setup_model(self):
        self.model = torch.nn.Sequential(
            wrap(torch.nn.Linear(32, 32)),
            torch.nn.ReLU(),
            wrap(torch.nn.Linear(32, 2))
        )

    def configure_sharded_model(self) -> None:
        self.model[0] = wrap(self.model[0])
        # index 2 is the second Linear layer; index 1 is the parameter-free ReLU
        self.model[2] = wrap(self.model[2])

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=1e-5)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # re-create the unwrapped model so that the state dict, which has no
        # FSDP references, can be loaded into it before FSDP wraps it again
        self._setup_model()

model = TestFSDPModel()
trainer = Trainer(plugins='fsdp', gpus=1, fast_dev_run=True)
trainer.fit(model)
trainer.save_checkpoint('model.pt')
trainer.test(model, ckpt_path='model.pt')

More details and why this is wrong (especially important to @ananthsub):

The reason we have to restore the model in the on_load_checkpoint hook can be described as below:

import os

import torch
from fairscale.nn import FullyShardedDataParallel
from pytorch_lightning.plugins.environments.lightning_environment import find_free_network_port

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(find_free_network_port())
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
torch.distributed.init_process_group("nccl")

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = FullyShardedDataParallel(
            torch.nn.Sequential(
                FullyShardedDataParallel(torch.nn.Linear(32, 32)),
                torch.nn.ReLU(),
                FullyShardedDataParallel(torch.nn.Linear(32, 2))
            )
        )

model = MyModel()
state_dict = model.state_dict()
# crashes because `load_state_dict` hasn't been called on FSDP model!
model.load_state_dict(state_dict)

Once this issue is resolved, we will always wrap the entire module in FSDP, and the plugin will keep the same reference. This is closer to the intended behaviour and solves the plethora of issues described above.
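A rough sketch of the "wrap the entire module once, keep the reference" behaviour, assuming a hypothetical plugin class; `ShardedWrapper` and `WrapOncePlugin` are illustrative names, not Lightning or fairscale APIs. The plugin wraps the whole module the first time it sees it and returns the same wrapped object for later stages, re-wrapping only if a different module is passed in.

```python
from typing import Any, Optional


class ShardedWrapper:
    """Hypothetical stand-in for FullyShardedDataParallel wrapping a whole module."""

    def __init__(self, module: Any) -> None:
        self.module = module


class WrapOncePlugin:
    """Hypothetical plugin that wraps the entire module once and reuses the reference."""

    def __init__(self) -> None:
        self._wrapped: Optional[ShardedWrapper] = None
        self.wrap_count = 0

    def connect(self, module: Any) -> ShardedWrapper:
        # Re-wrap only if nothing is wrapped yet or a different module was passed;
        # otherwise every stage (fit, validate, test, predict) shares one wrapped model.
        if self._wrapped is None or self._wrapped.module is not module:
            self._wrapped = ShardedWrapper(module)
            self.wrap_count += 1
        return self._wrapped


plugin = WrapOncePlugin()
lightning_module = object()  # stands in for the user's LightningModule
fit_model = plugin.connect(lightning_module)
test_model = plugin.connect(lightning_module)
```

Because the same wrapped object survives from fit to test, there is no need to re-create the unwrapped model in on_load_checkpoint just to wrap it again.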

ananthsub commented 3 years ago

@SeanNaren @tchaton this is on @jjenniferdai's mind and mine, as some of our large text model use cases are hitting CPU OOMs, which relate to model initialization and checkpoint loading (#9406).

  1. I am totally on board with deprecating call_configure_sharded_model and all the lifecycle checks that come with it. There are 2 issues I see with it right now:
    • The Trainer has an inconsistent call order across successive functions: sometimes we call the hook, sometimes we don't. This makes debugging really challenging. The example integration assumes fit is called before test, but that's not always true.
    • It's the LightningModule that actually wraps the layers inside the LightningModule, not the plugin. Therefore, the user is already the one responsible for deciding whether to apply the wrap, not the plugin. This is a departure from existing plugins in Lightning and starts us down the path of supporting generic manual parallelization. Instead of relying on call_configure_sharded_model in the model and the training type plugin, users could just as well check isinstance(self.model, FullyShardedDataParallel) and return early if it's already wrapped. It's in the same spirit as https://github.com/PyTorchLightning/pytorch-lightning/issues/8593, but controlled by the user. TLDR: users should implement configure_sharded_model as idempotent.

Proposal: Given that LightningModule.call_configure_sharded_model_hook is not documented anywhere in the public API of the LightningModule, can we remove all the properties associated with this check without a deprecation process?

  2. configure_sharded_model is not implementation-agnostic: the user needs to know whether they're using FSDP, DeepSpeed, or some other technique to apply the appropriate wrapping. This makes sense, since the libraries involved here can be so different.

  3. Expectations around delayed initialization and checkpoint loading. Right now we restore model states from the checkpoint into the model after setup: https://github.com/PyTorchLightning/pytorch-lightning/blob/381343a79c703f2ccf1ab7c1d87400ad6e31fdf4/pytorch_lightning/trainer/trainer.py#L987-L993

For the checkpoint state to be loaded, all layers must be initialized by the time setup completes. However, configure_sharded_model runs after setup.

This means if the model state dict contains FSDP weights, the LightningModule needs to initialize FSDP before loading the checkpoint. And if the LightningModule wants to load a model state dict without FSDP weights and then configure FSDP, it needs to apply the wrapper only in configure_sharded_model.
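The two orderings can be written out as a small sketch. `restore_with_sharding` and its `checkpoint_has_wrapped_weights` flag are hypothetical; this is not the Trainer's actual restore logic, it only records the order each case requires.

```python
from typing import Callable, List


def restore_with_sharding(
    checkpoint_has_wrapped_weights: bool,
    setup: Callable[[], None],
    configure_sharded_model: Callable[[], None],
    load_state_dict: Callable[[], None],
) -> None:
    setup()
    if checkpoint_has_wrapped_weights:
        # Weights were saved from the wrapped model, so the wrapped structure
        # must exist before they can be loaded into it.
        configure_sharded_model()
        load_state_dict()
    else:
        # Weights were saved from the plain model: load them into the unwrapped
        # layers first, then apply the wrapper.
        load_state_dict()
        configure_sharded_model()


def trace(checkpoint_has_wrapped_weights: bool) -> List[str]:
    """Record the order in which the three steps run."""
    calls: List[str] = []
    restore_with_sharding(
        checkpoint_has_wrapped_weights,
        setup=lambda: calls.append("setup"),
        configure_sharded_model=lambda: calls.append("configure_sharded_model"),
        load_state_dict=lambda: calls.append("load_state_dict"),
    )
    return calls
```

Either way, the LightningModule author has to know which kind of state dict the checkpoint contains.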

This is confusing since:

One potential mitigation is wrapping the entire LightningModule with FSDP and then avoiding rewrapping it later. However, I'm not sure how that will play with:

Do you think a formalization of manual parallelization is an option we could pursue here? In this case:

This latter option might be pretty niche, since it'll be more likely that all params cannot fit on a single device; otherwise users could opt for DDP Sharded + Zero Redundancy Optimizer.

Looking at the FSDP plugin, it's pretty minimal (some of that is due to it currently extending from DDP): https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/training_type/fully_sharded.py

But as long as we call configure_optimizers in the right spot and save checkpoints the right way (either a full state dict from rank 0, or shards from all ranks), we open up a lot of flexibility for users.
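A minimal sketch of the two save strategies, expressed as a plan mapping each rank to the file it writes. `checkpoint_plan` and the `.rank{N}` file suffix are hypothetical; the sketch leaves out the actual gathering or sharding of weights.

```python
from typing import Dict, Optional


def checkpoint_plan(
    world_size: int, full_state_dict: bool, path: str = "model.ckpt"
) -> Dict[int, Optional[str]]:
    """Map each rank to the file it writes (None means the rank writes nothing)."""
    if full_state_dict:
        # Full state dict: only rank 0 writes a single consolidated file.
        return {rank: (path if rank == 0 else None) for rank in range(world_size)}
    # Sharded: every rank writes its own shard to a separate file.
    return {rank: f"{path}.rank{rank}" for rank in range(world_size)}


full_plan = checkpoint_plan(4, full_state_dict=True)
sharded_plan = checkpoint_plan(2, full_state_dict=False)
```

The full variant yields one file that loads without FSDP, while the sharded variant avoids materializing all weights on a single rank.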

carmocca commented 1 year ago

@awaelchli Can we close this?

awaelchli commented 1 year ago

Yes, I believe all the main concerns from the issue description are resolved today.