huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

FSDP misconfigurations #3061

Closed: evkogs closed this issue 2 weeks ago

evkogs commented 2 months ago

System Info

Latest main version, torch nightly, cuda 12.6

Information

Tasks

Reproduction

Call accelerator.get_state_dict(model) on a model prepared with an FSDP-enabled Accelerator.

Expected behavior

The buggy behavior was introduced in #2959. Instead of first unwrapping the model and then entering with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config): (which breaks the FSDP handler), we should do:

elif self.distributed_type == DistributedType.FSDP:
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    # Pass the still-wrapped FSDP model here, not the unwrapped module
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
        print(f"using this config: {StateDictType.FULL_STATE_DICT}")
        state_dict = model.state_dict()

I'm not 100% sure that's the fix, but it seems to work
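
For context, roughly how this gets hit on my side (a sketch with a placeholder model, not my actual training code):

import torch
from accelerate import Accelerator

accelerator = Accelerator()           # FSDP enabled via `accelerate config`
model = torch.nn.Linear(8, 8)         # placeholder; any nn.Module
model = accelerator.prepare(model)    # accelerate wraps the model in FSDP here

# ... training loop ...

state_dict = accelerator.get_state_dict(model)  # this is the call that currently breaks
if accelerator.is_main_process:
    torch.save(state_dict, "pytorch_model.bin")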

evkogs commented 2 months ago

Current behavior:

[rank0]:   File "/home/ubuntu/efs_gpu/libs/accelerate/src/accelerate/accelerator.py", line 3340, in get_state_dict
[rank0]:     with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
[rank0]:   File "/home/ubuntu/efs_gpu/miniconda3/envs/diffusion_torch2.5/lib/python3.11/contextlib.py", line 144, in __exit__
[rank0]:     next(self.gen)
[rank0]:   File "/home/ubuntu/efs_gpu/miniconda3/envs/diffusion_torch2.5/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 835, in state_dict_type
[rank0]:     FullyShardedDataParallel.set_state_dict_type(
[rank0]:   File "/home/ubuntu/efs_gpu/miniconda3/envs/diffusion_torch2.5/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 711, in set_state_dict_type
[rank0]:     state_dict_config_type = _state_dict_type_to_config[state_dict_type]
[rank0]:                              ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]: KeyError: None

This happens because FSDP's state_dict_type context manager expects an FSDP-wrapped module:

    @staticmethod
    @contextlib.contextmanager
    def state_dict_type(
        module: nn.Module,
        state_dict_type: StateDictType,
        state_dict_config: Optional[StateDictConfig] = None,
        optim_state_dict_config: Optional[OptimStateDictConfig] = None,
    ) -> Generator:
        """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module.

        This context manager has the same functions as :meth:`set_state_dict_type`. Read the document of
        :meth:`set_state_dict_type` for the detail.

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> model = DDP(FSDP(...))
            >>> with FSDP.state_dict_type(
            >>>     model,
            >>>     StateDictType.SHARDED_STATE_DICT,
            >>> ):
            >>>     checkpoint = model.state_dict()

        Args:
            module (torch.nn.Module): Root module.
            state_dict_type (StateDictType): the desired ``state_dict_type`` to set.
            state_dict_config (Optional[StateDictConfig]): the model ``state_dict``
                configuration for the target ``state_dict_type``.
            optim_state_dict_config (Optional[OptimStateDictConfig]): the optimizer
               ``state_dict`` configuration for the target ``state_dict_type``.
        """
        prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
            module,
            state_dict_type,
            state_dict_config,
            optim_state_dict_config,
        )
        yield
        FullyShardedDataParallel.set_state_dict_type(
            module,
            prev_state_dict_settings.state_dict_type,
            prev_state_dict_settings.state_dict_config,
            prev_state_dict_settings.optim_state_dict_config,
        )
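
To make it concrete, here is a minimal sketch of my reading of the failure (assuming a plain module with no FSDP submodules, e.g. an already-unwrapped model):

import torch.nn as nn
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

plain_module = nn.Linear(8, 8)  # stand-in for an already-unwrapped model: no FSDP submodules

cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(plain_module, StateDictType.FULL_STATE_DICT, cfg):
    sd = plain_module.state_dict()
# On exit, the context manager calls set_state_dict_type again to restore the previous
# settings; with no FSDP submodules the recorded previous state_dict_type is None, so
# _state_dict_type_to_config[None] raises KeyError: None, matching the traceback above.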
evkogs commented 2 months ago

Unwrapping the model before state_dict = model.state_dict() produces an incomplete checkpoint.

Printing the keys of state_dict = fsdp_model.state_dict() shows no FSDP prefixes, only the original module names.

So I guess unwrapping is unnecessary, at least with the FULL_STATE_DICT type that I'm using.
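
A quick way to see this (a sketch, assuming fsdp_model is the FSDP-wrapped model and FULL_STATE_DICT is the active state-dict type):

sd = fsdp_model.state_dict()
print(any("_fsdp_wrapped_module" in k for k in sd))  # expected: False, no FSDP wrapper prefixes
print(list(sd)[:5])                                  # keys use the original module/parameter names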

muellerzr commented 2 months ago

@evkogs yes I can see why that'd be an issue. Would you like to put in the fix/refactor? (sorry it took a bit to find your issue)

evkogs commented 2 months ago

@muellerzr Yeah, I think I will. I need to answer two questions first:

1. With torch 2.4 and 2.5, the recently introduced default behavior leads to this error. Why did @alex-jw-brooks implement it this way? Did he test it? Maybe an older PyTorch version required it?
2. I didn't run the tests thoroughly and only used FULL_STATE_DICT. With other state-dict types, skipping the unwrap might not leave us with a correctly unwrapped original module, which is currently done automatically.

alex-jw-brooks commented 2 months ago

Hey @evkogs - yes, I tested it, but against 2.3; this was at the end of July, very close to the 2.4 release. Here is the issue with more details and a minimal repro of the core problem the PR was intended to fix. It also has the accelerate config etc. that I was using, which was also using FULL_STATE_DICT.

The reason for the PR was that we were seeing issues with models saved through the Hugging Face Trainer when Accelerate was enabled with FSDP: tied weights / shared pointers were not being resolved correctly at export time, when the Trainer deletes duplicated weights. This caused problems when loading tuned models into optimized inference engines that have their own weight-loading mechanisms, e.g. vLLM. With PyTorch 2.3, adding the unwrap call here made the tied weights resolve to the same data pointer, which allowed the HF Trainer to mark them as duplicates and remove them when saving the model.
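
If it's useful, a quick check for whether the tie survives unwrapping (a sketch, assuming a standard HF causal LM with tied input/output embeddings; this is not code from the PR):

# After unwrapping, tied embeddings should share storage so the Trainer can drop
# the duplicate at save time.
unwrapped = accelerator.unwrap_model(model)
tied = (
    unwrapped.get_input_embeddings().weight.data_ptr()
    == unwrapped.get_output_embeddings().weight.data_ptr()
)
print(f"tied weights share storage: {tied}")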

evkogs commented 2 months ago

Yeah, it's tricky. With torch 2.5 I tried unwrapping inside the handler before model.state_dict() and ended up with an incomplete checkpoint. I do run into various issues with corrupted checkpoints (mostly of untrained modules) when using accelerate with FSDP, but so far not a single corrupted state_dict of the target model.

zhichaoxu-shufe commented 1 month ago

I am having the same problem with accelerate==0.34.2 and torch==2.4.1; the following function works for me at the moment:

import os

import torch
from accelerate import Accelerator
from transformers import PreTrainedModel, PreTrainedTokenizer


def save_with_accelerate(
    accelerator: Accelerator,
    model: torch.nn.Module,
    tokenizer: PreTrainedTokenizer,
    output_dir: str,
    use_lora: bool = False,
) -> None:
    if accelerator.is_main_process and (not os.path.isdir(output_dir)):
        os.mkdir(output_dir)

    if use_lora:
        # When using lora, the unwrapped model is a PeftModel, which doesn't support the
        # is_main_process argument and has its own save_pretrained function for saving only
        # the lora modules. We have to handle is_main_process manually outside save_pretrained.
        if accelerator.is_main_process:
            model.base_model = model.base_model.merge_and_unload()
            torch.save(model.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))
    else:
        # don't use safetensors for saving for now
        if isinstance(model, PreTrainedModel):  # TODO: this needs to be tested
            model.save_pretrained(
                output_dir,
                is_main_process=accelerator.is_main_process,
                save_function=accelerator.save,
            )
        else:
            torch.save(model.state_dict(), os.path.join(output_dir, 'pytorch_model.bin'))

    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)
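
For completeness, a hypothetical call site (names assumed, not part of my script above):

accelerator.wait_for_everyone()
save_with_accelerate(
    accelerator,
    model,
    tokenizer,
    output_dir="output/final_checkpoint",
    use_lora=False,
)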
github-actions[bot] commented 3 weeks ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.