Open shellysheynin opened 2 years ago
@shellysheynin Thank you for opening the issue and letting us know how you fixed it! We do test for shared and expert params in our unit tests, so I am trying to understand how your model may differ from the one used in the test.
Hi, I'm trying to train a fully sharded transformer. At first I trained the model with use_shard_state=False, but it failed when saving the checkpoint, since there are several flattened params (one for the encoder module and one for the decoder module), and with use_shard_state=False it expects to see only one flattened param. Am I right? After changing to use_shard_state=True, saving worked and wrote the state to several files (32 files, one per GPU). However, when I tried to load the model from the checkpoint, it crashed in the function "consolidate_shard_weights":
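For context, the loading path that crashes looks roughly like this (a minimal sketch; the shard file names and the "shard_metadata" checkpoint key are my assumptions, while consolidate_shard_weights is the fairscale FSDP static helper mentioned above):

```python
# Hedged sketch: consolidating per-rank FSDP shard files back into a full
# state dict. The file naming scheme and checkpoint dict keys are
# illustrative assumptions, not the exact ones used in my training script.
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

num_shards = 32  # one checkpoint file per GPU/rank
shard_weights, shard_metadata = [], []
for rank in range(num_shards):
    ckpt = torch.load(f"checkpoint-shard{rank}.pt", map_location="cpu")
    shard_weights.append(ckpt["model"])            # flattened, sharded params
    shard_metadata.append(ckpt["shard_metadata"])  # per-FSDP-instance metadata

# This call is where the crash happens when shared params are re-linked.
full_state_dict = FSDP.consolidate_shard_weights(shard_weights, shard_metadata)
```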
Here is what the function is trying to do:
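Roughly (this is a sketch of the shared-param handling; the metadata field names and the `rebuild_params` helper are assumptions rather than the exact fairscale source), the consolidation stores regular params under keys that include the FSDP prefix, but then copies shared params without it:

```python
# Hedged sketch of the relevant logic inside consolidate_shard_weights.
# Field names ("param_metadata", "fsdp_path", "shared_param_info") and the
# rebuild_params() helper are placeholders for illustration.
consolidated_weights = {}
for metadata in shard_metadata[0]["param_metadata"]:
    fsdp_path = metadata["fsdp_path"]  # e.g. "decoder" for the decoder's FSDP instance

    # Regular params get keys with the FSDP prefix, e.g. "decoder.embed_tokens.weight".
    for name, full_tensor in rebuild_params(metadata, shard_weights):  # hypothetical helper
        key = ".".join([fsdp_path, name]) if fsdp_path else name
        consolidated_weights[key] = full_tensor

    # ...but shared params are copied using the UNPREFIXED paths, so the lookup
    # fails: "embed_tokens.weight" is not a key, only "decoder.embed_tokens.weight" is.
    for src_path, dest_path in metadata["shared_param_info"]:
        consolidated_weights[dest_path] = consolidated_weights[src_path]  # KeyError here
```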
The problem occurs when it tries to copy shared params that are indeed shared in the model: the weights of decoder.embed_tokens are tied to the weights of decoder.output_projection. However, the lookup key is wrong, because embed_tokens.weight is not in the dictionary, but decoder.embed_tokens.weight is. I fixed this bug by prepending the "fsdp_path" to the src_path and the dest_path, as in the sketch below.
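A minimal sketch of that fix, reusing the assumed variable names from the snippet above:

```python
# Sketch of the fix: prefix the shared-param paths with the owning FSDP
# instance's fsdp_path so they match the keys already stored in
# consolidated_weights (e.g. "decoder.embed_tokens.weight" instead of
# "embed_tokens.weight"). Names are assumptions carried over from the
# sketch above, not the exact fairscale source.
for src_path, dest_path in metadata["shared_param_info"]:
    src = ".".join([fsdp_path, src_path]) if fsdp_path else src_path
    dest = ".".join([fsdp_path, dest_path]) if fsdp_path else dest_path
    consolidated_weights[dest] = consolidated_weights[src]
```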