🚀 Efficiently (pre)training foundation models with native PyTorch features, including FSDP for training and the SDPA implementation of Flash Attention v2.
A compiled FSDP model uses `use_orig_params=True` and operates on the original parameters. As a result, the state_dict in a checkpoint saved from a compiled model is inconsistent with one saved from a non-compiled model.
Although some work has been done to handle some of these inconsistencies automatically, it does not help our fms-to-hf conversion script, because the script hardcodes the key mapping.
We should add a flag indicating whether the checkpoint comes from a compiled model and, if so, use a load-and-offload approach to massage the state dict into a form our script can consume.
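A minimal sketch of the key-massaging step, assuming the only inconsistency is the `_orig_mod.` prefix that `torch.compile` adds in front of parameter and buffer names. The helper `normalize_state_dict` and its `compiled` argument (standing in for the proposed flag) are hypothetical, and the example keys are illustrative rather than the real fms key names. File I/O is left out because the checkpoint layout is not specified here.

```python
from typing import Any, Dict, Mapping

# Assumption: torch.compile's wrapper prefixes every state_dict key with this.
COMPILED_PREFIX = "_orig_mod."


def normalize_state_dict(state_dict: Mapping[str, Any], compiled: bool) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of `state_dict` keyed like a non-compiled model.

    `compiled` mirrors the proposed flag; when False, keys pass through unchanged.
    """
    if not compiled:
        return dict(state_dict)

    normalized: Dict[str, Any] = {}
    for key, value in state_dict.items():
        # Strip every occurrence so nested compiled submodules are handled too.
        new_key = key.replace(COMPILED_PREFIX, "")
        if new_key in normalized:
            # Two keys collapsing to one name means the assumption above is wrong.
            raise KeyError(f"key collision after stripping prefix: {new_key!r}")
        normalized[new_key] = value
    return normalized


# Illustrative keys only; values would normally be tensors.
sd = {"_orig_mod.embedding.weight": 1, "_orig_mod.layers.0.attn.query.weight": 2}
print(normalize_state_dict(sd, compiled=True))
# {'embedding.weight': 1, 'layers.0.attn.query.weight': 2}
```

In the load-and-offload flow, the checkpoint would be loaded onto CPU, passed through this step when the flag is set, and the normalized dict handed to the existing hardcoded key mapping, so the conversion script itself stays unchanged.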