facebookresearch / fairscale

PyTorch extensions for high performance and large scale training.

failed loading state dict with use_shard_state=True #815


shellysheynin commented 2 years ago

Hi, I'm trying to train a fully sharded transformer. Initially I trained the model with use_shard_state=False, but it failed when saving the checkpoint: the model has several flattened params (one for the encoder module and one for the decoder module), and with use_shard_state=False it expects to see only one flattened param. Am I right? After switching to use_shard_state=True, saving worked and produced the state in several files (32 files, one per GPU). However, when I tried to load the model from the checkpoint, it crashed in consolidate_shard_weights:

  File "/private/home/shellysheynin/projects/fairseq-py/fairseq/checkpoint_utils.py", line 508, in load_model_ensemble_and_task
    consolidated_model_state = FSDP.consolidate_shard_weights(
  File "/private/home/shellysheynin/.conda/envs/fairseq_dalle/lib/python3.8/site-packages/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1706, in consolidate_shard_weights
    consolidated_weights[dest_path] = consolidated_weights[src_path]
KeyError: 'embed_tokens.weight'
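
For context, here is a minimal sketch of the sharded save/consolidate flow, assuming fairscale's FSDP local_state_dict / local_metadata_dict API; the fsdp_model, rank, and world_size names and the file naming are illustrative, not the exact fairseq code path:

import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Save side: each rank writes its own shard plus the metadata needed
# to reassemble full parameter names later. fsdp_model, rank, and
# world_size are assumed to be set up by the training script.
shard = {
    "weights": fsdp_model.local_state_dict(),      # this rank's flattened shards
    "metadata": fsdp_model.local_metadata_dict(),  # param names, shapes, sharing info
}
torch.save(shard, f"checkpoint-shard{rank}.pt")

# Load side: one process gathers all shards and consolidates them
# into a single unsharded state dict.
shards = [torch.load(f"checkpoint-shard{r}.pt") for r in range(world_size)]
consolidated_model_state = FSDP.consolidate_shard_weights(
    shard_weights=[s["weights"] for s in shards],
    shard_metadata=[s["metadata"] for s in shards],
)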

Here is what the loading code is doing:

for src_path, dest_path in metadata["shared_param_info"]:
    consolidated_weights[dest_path] = consolidated_weights[src_path]
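
To make the mismatch concrete, here is a self-contained toy example that reproduces the KeyError; the dict contents are made up and the metadata layout is simplified to match the snippet above:

import torch

# consolidate_shard_weights keys the consolidated dict by fully
# qualified names, while the shared-param records in the metadata are
# relative to the FSDP-wrapped submodule (here "decoder").
consolidated_weights = {
    "decoder.embed_tokens.weight": torch.zeros(4),
    "decoder.output_projection.weight": torch.zeros(4),
}
metadata = {
    "fsdp_path": "decoder",
    "shared_param_info": [("embed_tokens.weight", "output_projection.weight")],
}

for src_path, dest_path in metadata["shared_param_info"]:
    # Raises KeyError: 'embed_tokens.weight' -- only the
    # "decoder."-prefixed key exists in the dict.
    consolidated_weights[dest_path] = consolidated_weights[src_path]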

The problem occurs when it tries to copy shared params that are indeed shared in the model: the weights of decoder.embed_tokens are tied to the weights of decoder.output_projection. However, the lookup is wrong, because embed_tokens.weight is not a key in the dictionary while decoder.embed_tokens.weight is. I fixed this bug by prepending the "fsdp_path" to both src_path and dest_path.

for src_path, dest_path in metadata["shared_param_info"]:
    fsdp_path = metadata["fsdp_path"]  # e.g. "decoder" for the decoder module
    full_src_path = fsdp_path + "." + src_path
    full_dest_path = fsdp_path + "." + dest_path
    consolidated_weights[full_dest_path] = consolidated_weights[full_src_path]
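
One caveat with this fix (an assumption on my side, not verified against all cases): for the root FSDP module the "fsdp_path" may be an empty string, in which case no prefix should be added, e.g.:

for src_path, dest_path in metadata["shared_param_info"]:
    prefix = metadata["fsdp_path"]
    # Only prefix when the params live under a wrapped submodule;
    # the root module's fsdp_path may be empty.
    if prefix:
        src_path = prefix + "." + src_path
        dest_path = prefix + "." + dest_path
    consolidated_weights[dest_path] = consolidated_weights[src_path]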
anj-s commented 2 years ago

@shellysheynin Thank you for opening the issue and letting us know how you fixed it! We do test for shared and expert params in our unit tests, so I am trying to understand how your model may differ from the one used in the test.