evkogs closed this issue 2 weeks ago.
Current behavior:
[rank0]:   File "/home/ubuntu/efs_gpu/libs/accelerate/src/accelerate/accelerator.py", line 3340, in get_state_dict
[rank0]:     with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
[rank0]:   File "/home/ubuntu/efs_gpu/miniconda3/envs/diffusion_torch2.5/lib/python3.11/contextlib.py", line 144, in __exit__
[rank0]:     next(self.gen)
[rank0]:   File "/home/ubuntu/efs_gpu/miniconda3/envs/diffusion_torch2.5/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 835, in state_dict_type
[rank0]:     FullyShardedDataParallel.set_state_dict_type(
[rank0]:   File "/home/ubuntu/efs_gpu/miniconda3/envs/diffusion_torch2.5/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 711, in set_state_dict_type
[rank0]:     state_dict_config_type = _state_dict_type_to_config[state_dict_type]
[rank0]:                              ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]: KeyError: None
This happens because FSDP's state_dict_type context manager expects an FSDP-wrapped module:
@staticmethod
@contextlib.contextmanager
def state_dict_type(
    module: nn.Module,
    state_dict_type: StateDictType,
    state_dict_config: Optional[StateDictConfig] = None,
    optim_state_dict_config: Optional[OptimStateDictConfig] = None,
) -> Generator:
    """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.

    This context manager has the same functions as :meth:`set_state_dict_type`. Read the document of
    :meth:`set_state_dict_type` for the detail.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> model = DDP(FSDP(...))
        >>> with FSDP.state_dict_type(
        >>>     model,
        >>>     StateDictType.SHARDED_STATE_DICT,
        >>> ):
        >>>     checkpoint = model.state_dict()

    Args:
        module (torch.nn.Module): Root module.
        state_dict_type (StateDictType): the desired ``state_dict_type`` to set.
        state_dict_config (Optional[StateDictConfig]): the model ``state_dict``
            configuration for the target ``state_dict_type``.
        optim_state_dict_config (Optional[OptimStateDictConfig]): the optimizer
            ``state_dict`` configuration for the target ``state_dict_type``.
    """
    prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
        module,
        state_dict_type,
        state_dict_config,
        optim_state_dict_config,
    )
    yield
    FullyShardedDataParallel.set_state_dict_type(
        module,
        prev_state_dict_settings.state_dict_type,
        prev_state_dict_settings.state_dict_config,
        prev_state_dict_settings.optim_state_dict_config,
    )
Unwrapping the model before calling state_dict = model.state_dict() causes an incomplete checkpoint. Printing the result of state_dict = fsdp_model.state_dict() shows no FSDP prefixes in the keys, only the original module names. So I guess the unwrap is not required; I'm using the FULL_STATE_DICT type.
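For reference, a minimal sketch of that pattern, gathering a full state dict directly from the still-wrapped model (fsdp_model is a placeholder name for a module already wrapped with FSDP; the config values are illustrative, not taken from this thread):

from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

# `fsdp_model` is assumed to already be wrapped with FSDP.
# Gather a full (unsharded) state dict, offloaded to CPU and materialized on rank 0 only.
full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
    state_dict = fsdp_model.state_dict()  # keys carry the original module names, no FSDP wrappers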
@evkogs yes I can see why that'd be an issue. Would you like to put in the fix/refactor? (sorry it took a bit to find your issue)
@muellerzr Yeah, I think I will. I need to answer two questions before that:
1) With torch 2.4 and 2.5, the recently introduced default behavior leads to this error. Why did @alex-jw-brooks implement it this way? Did he test it? Maybe an older PyTorch version required it?
2) I didn't run the tests thoroughly, and I only used FULL_STATE_DICT. Maybe with other state dict types we won't end up with a correctly unwrapped original module, which is done automatically now.
Hey @evkogs - yes, I tested it, but against 2.3; this was at the end of July, very close to the 2.4 release. Here is the issue with more details and a minimal repro of the core problem the PR was intended to fix. It also has the accelerate config etc. that I was using, which also used FULL_STATE_DICT.
The reason for the PR was that we were seeing issues with models saved through the Hugging Face Trainer when Accelerate was enabled with FSDP: tied weights / shared pointers were not being resolved correctly at export time, when the trainer deletes duplicated weights. This was causing problems when loading tuned models into optimized inference engines that have their own weight-loading mechanisms, e.g., vLLM. With PyTorch 2.3, adding the unwrap call here caused the tied weights to resolve to the same data pointer, which allowed the HF Trainer to mark them as duplicated weights and remove them when saving the model.
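To illustrate the mechanism being described, here is a small sketch (not the Trainer's actual de-duplication code; the module and key names are made up): tied parameters share the same storage, so comparing data pointers is one way to spot duplicates before saving.

import torch.nn as nn

# Two modules whose weights are tied, as in an embedding / lm_head pair.
embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight  # tie the weights

state = {"embed.weight": embed.weight, "lm_head.weight": lm_head.weight}

# Group parameter names by underlying data pointer; tied weights collapse into one group.
groups = {}
for name, tensor in state.items():
    groups.setdefault(tensor.data_ptr(), []).append(name)

duplicated = [names for names in groups.values() if len(names) > 1]
print(duplicated)  # [['embed.weight', 'lm_head.weight']] -> one copy can be dropped at save time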
Yeah, it's tricky. With torch 2.5 I tried unwrapping inside the handler before model.state_dict() and ended up with an incomplete checkpoint. I do hit various issues with corrupted checkpoints (mostly of untrained modules) when using accelerate with FSDP, but so far not a single corrupted state_dict of the target model.
I am having the same problem with accelerate==0.34.2 and torch==2.4.1; the following function works for me at the moment:
import os

import torch
from accelerate import Accelerator
from transformers import PreTrainedModel, PreTrainedTokenizer


def save_with_accelerate(
    accelerator: Accelerator,
    model: torch.nn.Module,
    tokenizer: PreTrainedTokenizer,
    output_dir: str,
    use_lora: bool = False,
) -> None:
    if accelerator.is_main_process and (not os.path.isdir(output_dir)):
        os.mkdir(output_dir)
    if use_lora:
        # When using lora, the unwrapped model is a PeftModel, which doesn't support the is_main_process
        # and has its own save_pretrained function for only saving lora modules.
        # We have to manually specify the is_main_process outside the save_pretrained function.
        if accelerator.is_main_process:
            model.base_model = model.base_model.merge_and_unload()
            torch.save(model.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))
    else:
        # don't use safetensors for saving for now
        if isinstance(model, PreTrainedModel):  # TODO: this needs to be tested
            model.save_pretrained(
                output_dir,
                is_main_process=accelerator.is_main_process,
                save_function=accelerator.save,
            )
        else:
            torch.save(model.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)
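A hypothetical call site for the function above (the variable names and output path are placeholders, not from the original comment):

# Called on every rank after training finishes; writes are mostly gated on accelerator.is_main_process inside.
save_with_accelerate(
    accelerator=accelerator,
    model=model,
    tokenizer=tokenizer,
    output_dir="outputs/final_checkpoint",
    use_lora=False,
)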
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info

Information

Tasks
no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)

Reproduction
use accelerator.get_state_dict()
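A minimal repro sketch of that call (my reconstruction; it assumes the script is launched with an FSDP accelerate config, which is not shown here):

# Assumed launch command, e.g.: accelerate launch --config_file fsdp_config.yaml repro.py
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)
model = accelerator.prepare(model)

# With torch 2.4/2.5 this reportedly fails inside FSDP.state_dict_type with `KeyError: None`,
# because the model is unwrapped before the context manager is entered.
state_dict = accelerator.get_state_dict(model)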
Expected behavior
The bugged functionality was added in #2959: instead of first unwrapping the model and then using with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):, which breaks the FSDP handler, we should keep the FSDP-wrapped model when entering that context manager. I'm not 100% sure that's the fix, but it seems to work.