facebookresearch / audiocraft

Audiocraft is a library for audio processing and generation with deep learning. It features the state-of-the-art EnCodec audio compressor / tokenizer, along with MusicGen, a simple and controllable music generation LM with textual and melodic conditioning.
MIT License

AttributeError Encountered When Using FSDP #358

Open YutoNishimura-v2 opened 10 months ago

YutoNishimura-v2 commented 10 months ago

Hello,

I am attempting to train the medium MusicGen model with FSDP, using the following command:

dora run -d [other options see training docs] fsdp.use=true autocast=false

However, I encountered the following error:

  File "/home/user/audiocraft/lib/python3.10/site-packages/flashy/state.py", line 42, in load_state_dict
    attr.load_state_dict(state)
Traceback (most recent call last):
  File "/path/to/audiocraft/audiocraft/optim/fsdp.py", line 177, in load_state_dict
    purge_fsdp(self)
  File "/path/to/audiocraft/audiocraft/optim/fsdp.py", line 132, in purge_fsdp
    handles = module._handles
  File "/home/user/audiocraft/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 529, in __getattr__
    return getattr(self._fsdp_wrapped_module, name)
  File "/home/user/audiocraft/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'LMModel' object has no attribute '_handles'
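
Side note: the last two frames show that FSDP's __getattr__ forwards unknown attribute lookups to the wrapped module, which is why the error names 'LMModel' rather than the FSDP wrapper itself. A toy stand-in (not FSDP's actual code) that reproduces the same forwarding pattern:

import torch.nn as nn

class ForwardingWrapper(nn.Module):
    """Toy illustration of the attribute forwarding seen in the traceback:
    unknown names fall through to the wrapped module, so the AttributeError
    ends up naming the inner module instead of the wrapper."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self._wrapped = module

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)  # normal nn.Module lookup
        except AttributeError:
            return getattr(super().__getattr__("_wrapped"), name)

wrapped = ForwardingWrapper(nn.Linear(4, 4))
wrapped._handles  # AttributeError: 'Linear' object has no attribute '_handles'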

Upon inspecting the PyTorch implementation, it looks like the handle attribute is no longer a list: each FSDP module now exposes a single _handle instead of _handles. The problem was resolved after I changed the code to the following (I'm not sure if this is the expected behavior):

def purge_fsdp(model: FSDP):
    """Purge the FSDP cached shard inside the model. This should
    allow setting the best state or switching to the EMA.
    """
    from torch.distributed.fsdp._runtime_utils import _reshard  # type: ignore

    for module in FSDP.fsdp_modules(model):
        handle = module._handle  # torch 2.1+: a single FlatParamHandle (or None) per module
        if not handle:
            continue
        unsharded_flat_param = handle._get_padded_unsharded_flat_param()
        storage_size: int = unsharded_flat_param._typed_storage()._size()  # type: ignore
        if storage_size == 0:
            # Nothing cached to free for this module.
            continue
        _reshard(module, handle, True)  # torch 2.1+ _reshard takes a single handle and a bool
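
For anyone stuck on a different PyTorch release, here is a rough, untested sketch of a version-tolerant variant (purge_fsdp_compat is just an illustrative name) that probes for the old list-based _handles attribute first and falls back to the single _handle used from torch 2.1 onward:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._runtime_utils import _reshard  # type: ignore

def purge_fsdp_compat(model: FSDP):
    """Purge the FSDP cached shard, tolerating both the torch 2.0.x API
    (list of handles, list arguments to _reshard) and the torch 2.1+ API
    (single handle, scalar arguments to _reshard)."""
    for module in FSDP.fsdp_modules(model):
        handles = getattr(module, "_handles", None)  # present on torch 2.0.x
        single_handle_api = handles is None
        if single_handle_api:
            handle = getattr(module, "_handle", None)  # torch >= 2.1
            handles = [handle] if handle is not None else []
        if not handles:
            continue
        unsharded_flat_param = handles[0]._get_padded_unsharded_flat_param()
        storage_size: int = unsharded_flat_param._typed_storage()._size()  # type: ignore
        if storage_size == 0:
            continue
        if single_handle_api:
            _reshard(module, handles[0], True)
        else:
            _reshard(module, handles, [True] * len(handles))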

Could this be caused by a difference in PyTorch versions? I am using 2.1.0+cu121 with Python 3.10.10.

Looking forward to your response. Thank you.

YutoNishimura-v2 commented 10 months ago

https://github.com/pytorch/pytorch/blob/v2.0.0/torch/distributed/fsdp/_runtime_utils.py#L334

I found that _handles is used in torch v2.0.0, so I believe this is a version issue: the FSDP internals changed between 2.0 and 2.1.
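
If it helps, a quick way to check which API the installed torch exposes (assuming the packaging package is available):

from packaging import version
import torch

# The FSDP internals changed between releases: 2.0.x keeps a list of
# flat-parameter handles per module (_handles) and _reshard takes lists,
# while 2.1+ keeps a single _handle and _reshard takes scalars.
single_handle_api = version.parse(torch.__version__.split("+")[0]) >= version.parse("2.1.0")
print("torch", torch.__version__, "single-handle FSDP API:", single_handle_api)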

nateraw commented 8 months ago

Can confirm your patch works for me too @YutoNishimura-v2.