Closed SeanNaren closed 1 year ago
I agree, call_configure_sharded_model as a property is fragile. It reminds me of https://github.com/PyTorchLightning/pytorch-lightning/issues/7301, where either the user can check inside of configure_sharded_model whether the model is already sharded, or the framework avoids rewrapping.
@SeanNaren both the new TestFSDPModel examples look far cleaner. Regarding the state dict, would the plugin now wrap the LightningModule as a whole with FSDP?
@ananthsub this is a really good point. I realised that once we support https://github.com/PyTorchLightning/pytorch-lightning/issues/8593, there is no reason FSDP cannot wrap the entire module! I am a bit unsure exactly how the logic would proceed currently; it will need some investigation.
@SeanNaren regarding the current test example, I think this is a specific choice by the use case. If the guiding principle is to deprecate call_configure_sharded_model_hook, then I think the initialization structure you've provided is the natural follow-up. Regardless of the property, it's a much clearer approach.
Seems to me there are some issues with the code snippets as written. I stumbled on this issue looking for information about whether I should still init the model in the constructor, or only in the hook. I think one model could be rewritten as:
import torch
from fairscale.nn import wrap
from tests.helpers.boring_model import BoringModel

class TestFSDPModel(BoringModel):
    def __init__(self, lazy_init: bool = False):
        super().__init__()
        if lazy_init:
            # Need to define the field in the constructor, or the hook will fail.
            self.layer = None
        else:
            # Create the layers right away. Not memory efficient with large models,
            # but convenient for testing in isolation from the trainer.
            self.configure_sharded_model()

    def configure_sharded_model(self) -> None:
        self.layer = torch.nn.Sequential(
            wrap(torch.nn.Linear(32, 32)),
            torch.nn.ReLU(),
            wrap(torch.nn.Linear(32, 2)),
        )
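A quick sketch of how the two construction modes could be used (assuming fairscale's wrap(), which is a no-op outside an enable_wrap context):

# Built lazily: the trainer calls configure_sharded_model inside the plugin's wrap context.
model = TestFSDPModel(lazy_init=True)

# Built eagerly: layers exist immediately (unwrapped here, since wrap() is a no-op
# outside an enable_wrap context), which is handy for quick checks without a Trainer.
model = TestFSDPModel()
out = model.layer(torch.randn(4, 32))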
@SeanNaren Do you recommend not calling setup with FSDP?
@fcampagnexandr TLDR: this works:
from typing import Dict, Any

import torch
from fairscale.nn import wrap
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel

class TestFSDPModel(BoringModel):
    def __init__(self):
        super().__init__()
        self._setup_model()

    def _setup_model(self):
        # plain, unwrapped layers; FSDP wrapping is applied later in configure_sharded_model
        self.model = torch.nn.Sequential(
            torch.nn.Linear(32, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 2),
        )

    def configure_sharded_model(self) -> None:
        # wrap the two Linear layers (indices 0 and 2) inside the plugin's wrap context
        self.model[0] = wrap(self.model[0])
        self.model[2] = wrap(self.model[2])

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=1e-5)

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # re-creates the plain model before FSDP wraps it, so the state dict,
        # which doesn't have FSDP references, can be loaded.
        self._setup_model()

model = TestFSDPModel()
trainer = Trainer(plugins='fsdp', gpus=1, fast_dev_run=True)
trainer.fit(model)
trainer.save_checkpoint('model.pt')
trainer.test(model, ckpt_path='model.pt')
More details, and why this is wrong (especially important to @ananthsub): the reason we have to restore the model in the on_load_checkpoint hook is demonstrated below:
import os

import torch
from fairscale.nn import FullyShardedDataParallel
from pytorch_lightning.plugins.environments.lightning_environment import find_free_network_port

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(find_free_network_port())
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
torch.distributed.init_process_group("nccl")

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = FullyShardedDataParallel(
            torch.nn.Sequential(
                FullyShardedDataParallel(torch.nn.Linear(32, 32)),
                torch.nn.ReLU(),
                FullyShardedDataParallel(torch.nn.Linear(32, 2)),
            )
        )

model = MyModel()
state_dict = model.state_dict()

# crashes because `load_state_dict` hasn't been called on FSDP model!
model.load_state_dict(state_dict)
When this issue has been resolved, we will always wrap the entire module in FSDP, and the plugin keeps the same reference. This is closer to intended behaviour and solves a plethora of issues as described.
@SeanNaren @tchaton this is on @jjenniferdai's and my mind, as some of our large text model cases are having issues with CPU OOMs, which relates to model initialization and checkpoint loading (#9406).

I'd like to remove call_configure_sharded_model and all the lifecycle checks it comes with. There are 2 issues I see with it right now:

1. The current logic assumes fit is called before test, but that's not always true.
2. Instead of tracking call_configure_sharded_model in the model and the training type plugin, users could just as well check for isinstance(self.model, FullyShardedDataParallel) and return early if it's already wrapped. It's the same spirit as https://github.com/PyTorchLightning/pytorch-lightning/issues/8593, but controlled by the user. TLDR: users should implement configure_sharded_model as idempotent (a sketch follows below).

Proposal: given that LightningModule.call_configure_sharded_model_hook is not documented anywhere on the public API of the LightningModule, can we remove all the properties associated with this check without a deprecation process?
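A minimal sketch of that idempotency idea (the class and layer below are illustrative only, assuming fairscale's wrap/FullyShardedDataParallel):

import torch
from fairscale.nn import FullyShardedDataParallel, wrap
from pytorch_lightning import LightningModule

class IdempotentFSDPModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def configure_sharded_model(self) -> None:
        if isinstance(self.layer, FullyShardedDataParallel):
            # already wrapped by an earlier stage (e.g. fit ran before test): nothing to do
            return
        self.layer = wrap(self.layer)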
configure_sharded_model is not implementation agnostic: the user needs to know whether they're using FSDP/DeepSpeed/some other technique to apply the appropriate wrapping. This makes sense, since the libraries involved here can be so different.
Expectations around delayed initialization & checkpoint loading.
Right now we restore model states from the checkpoint into the model after setup: https://github.com/PyTorchLightning/pytorch-lightning/blob/381343a79c703f2ccf1ab7c1d87400ad6e31fdf4/pytorch_lightning/trainer/trainer.py#L987-L993
For the checkpoint state to be loaded, all layers must be initialized by the time setup completes. However, configure_sharded_model runs after setup. This means that if the model state dict contains FSDP weights, the LightningModule needs to initialize FSDP before loading the checkpoint. And if the LightningModule wants to load a model state dict without FSDP weights and then configure FSDP, it needs to apply the wrapper only in configure_sharded_model.
This is confusing, since we apply model_sharded_context only around configure_sharded_model and not around setup. I'm not sure if applying the context around both hooks is viable, because we might want to do one without the other. One potential mitigation is wrapping the entire LightningModule with FSDP and then avoiding rewrapping it later. However, I'm not sure how that will play with MyLightningModule.load_from_checkpoint, because now we'd need to call FSDP(LightningModule).load_state_dict instead? But this FSDP-wrapped LightningModule isn't visible to the user because it's an implementation detail of the trainer. DP/DDP don't face this issue since they're not destructive.

Do you think a formalization of manual parallelization is an option we could pursue here? In this case, users would either shard their model themselves in setup, or initialize the full model in setup, load the checkpoint, and then shard in configure_sharded_model (a rough sketch of the latter follows below). This latter option might be pretty niche, since it'll be more likely that all params cannot fit on a single device; otherwise users could opt for DDP Sharded + a zero-redundancy optimizer.
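A rough sketch of that load-then-shard flow (build_model and the checkpoint path are placeholders, not from this thread):

import torch
from fairscale.nn import wrap
from pytorch_lightning import LightningModule

def build_model() -> torch.nn.Module:
    # placeholder for the user's (potentially very large) model factory
    return torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

class ManualParallelModel(LightningModule):
    def setup(self, stage=None):
        # every layer exists un-sharded here, so a plain (non-FSDP) state dict can be loaded
        self.model = build_model()
        self.model.load_state_dict(torch.load("plain_weights.pt"))  # placeholder path

    def configure_sharded_model(self):
        # sharding is applied only here, inside the plugin's wrap context
        self.model = wrap(self.model)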
Looking at the FSDP plugin, it's pretty minimal (some of that is due to it currently extending from DDP): https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/training_type/fully_sharded.py
But as long as we call configure_optimizers in the right spot, and save the checkpoints in the right way (either the full state dict from rank 0, or shards from all ranks), then we open up a lot of flexibility for the users.
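For reference, a sketch of the two saving modes mentioned above, using fairscale's FSDP API directly (assuming a process group is already initialized, as in the MyModel example earlier):

import torch
import torch.distributed as dist
from fairscale.nn import FullyShardedDataParallel

fsdp_model = FullyShardedDataParallel(torch.nn.Linear(32, 2).cuda())

# (a) full state dict, gathered on every rank but written only from rank 0
full_sd = fsdp_model.state_dict()  # collective call: must run on all ranks
if dist.get_rank() == 0:
    torch.save(full_sd, "full_model.ckpt")

# (b) sharded: each rank saves only its own flattened shard
local_sd = fsdp_model.local_state_dict()
torch.save(local_sd, f"shard_rank{dist.get_rank()}.ckpt")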
@awaelchli Can we close this?
Yes, I believe all the main concerns from the issue description are resolved today.
🚀 Feature
Motivated by debugging FSDP in a recent PR made by @carmocca, I think we should try to clean up the interface for FSDP.
Currently FSDP supports a case where we wrap layers inside the configure_sharded_model hook, with the assumption that these layers are defined outside the hook. This is probably because in most cases the model has been defined in setup or __init__. This can be seen here: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py#L58-L74
Also included is a lot of boilerplate logic to handle the case where a user wants to load weights back into the model, where we re-create the model -> load weights -> call configure_sharded_model again.
This is a bit unclean, as we see an internal variable needing to be reset (call_configure_sharded_model_hook) and, more importantly, we assume that the model state hasn't been altered (which it has, by FSDP, which permanently flattens the parameters). Imo we should move towards this API:
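(An illustrative sketch of such an API, not the exact snippet from the issue: layers are created and wrapped only inside configure_sharded_model, so the unsharded model never has to be fully materialized.)

import torch
from fairscale.nn import wrap
from pytorch_lightning import LightningModule

class MyModel(LightningModule):
    def configure_sharded_model(self) -> None:
        # layers are created here, inside the plugin's sharding context
        self.layer = torch.nn.Sequential(
            wrap(torch.nn.Linear(32, 32)),
            torch.nn.ReLU(),
            wrap(torch.nn.Linear(32, 2)),
        )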
and allow this to happen for large models that take time to load into memory (and are quicker to set up one wrapped module at a time):
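(Again an illustrative sketch, with make_block as a placeholder: each block is wrapped as soon as it is built, one module at a time.)

import torch
from fairscale.nn import wrap
from pytorch_lightning import LightningModule

def make_block() -> torch.nn.Module:
    # placeholder for an expensive-to-construct block of a large model
    return torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU())

class MyLargeModel(LightningModule):
    def configure_sharded_model(self) -> None:
        # each block is wrapped (and can be sharded) before the next one is built
        self.blocks = torch.nn.ModuleList([wrap(make_block()) for _ in range(4)])
        self.head = wrap(torch.nn.Linear(32, 2))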
How to actually implement this?
Once the model has been set up, ideally we should never need to set it up again unless the model has changed (covered in the RFC https://github.com/PyTorchLightning/pytorch-lightning/issues/8593). This would allow the model to remain the same across stages.
Given the above, I think we'll then be able to rely on primitive state dict functions of the wrapped model via FSDP: https://fairscale.readthedocs.io/en/stable/_modules/fairscale/nn/data_parallel/fully_sharded_data_parallel.html#FullyShardedDataParallel.state_dict
cc @ananthsub