foundation-model-stack / fms-fsdp

🚀 Efficiently (pre)training foundation models with native PyTorch features, including FSDP for training and SDPA implementation of Flash attention v2.
https://pytorch.org/docs/stable/fsdp.html
Apache License 2.0

add support for converting compiled model to hf #61

Closed lchu-ibm closed 3 months ago

lchu-ibm commented 3 months ago

To address https://github.com/foundation-model-stack/fms-fsdp/issues/59

To convert compiled model, we can run

python fms_to_hf.py --model_variant 7b --compiled --load_path /fsx/lchu/...

To convert non-compiled model, we can run

python fms_to_hf.py --model_variant 7b --nocompiled --load_path /fsx/lchu/...
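
For context on why two separate flags are needed: torch.compile wraps the module in an OptimizedModule, so a state dict taken from the compiled model carries an "_orig_mod." prefix on every key. A minimal illustration of that behavior on recent PyTorch 2.x (not part of this PR):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
compiled = torch.compile(model)

# The plain module uses bare parameter names...
print(list(model.state_dict().keys()))     # ['weight', 'bias']
# ...while the compiled wrapper prefixes every key with "_orig_mod."
print(list(compiled.state_dict().keys()))  # ['_orig_mod.weight', '_orig_mod.bias']
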
nairbv commented 3 months ago

@JRosenkranz I wonder if this is something that should be included in fms adapters? I didn't realize compiled models may be saved in a different way?

@lchu-ibm alternatively, would it be possible to just change the way we save FSDP checkpoints for compiled models? It seems we really only want the _orig_mod, so we could just save that as 'models' when writing them out?
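
Purely as a sketch of this suggestion, and assuming the compile wrapper is the outermost wrapper around the training model (the actual FSDP checkpoint path uses sharded state dicts, and the real key name in the writer may differ):

# Hypothetical save-side change: if the training model is compiled, unwrap it
# via torch.compile's _orig_mod attribute so the checkpoint keys match the
# non-compiled layout; otherwise save the model as-is.
module_to_save = getattr(model, "_orig_mod", model)
state_dict = {"model": module_to_save.state_dict()}  # key name is illustrative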

JRosenkranz commented 3 months ago

> @JRosenkranz I wonder if this is something that should be included in fms adapters? I didn't realize compiled models may be saved in a different way?

@nairbv Yes, I didn't realize this either. If compiled models are saved in a different way, it might make sense to include an adapter for it, provided through the source in get_model. I do agree, though, that it may make sense to just save the models in a different way. Another option: we could implicitly determine this and run the adapter accordingly?
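
One way the "implicitly determine" option could look, purely as a sketch (is_compiled_checkpoint is a hypothetical helper, not an existing fms or torch function):

def is_compiled_checkpoint(state_dict: dict) -> bool:
    # Hypothetical helper: treat the checkpoint as compile-format if any key
    # carries torch.compile's "_orig_mod." prefix.
    return any(k.startswith("_orig_mod.") for k in state_dict)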

lchu-ibm commented 3 months ago

@nairbv If we save it while dropping the intermediate key, then loading it back becomes another problem.

It is a "cross join" scenario: if we save compiled and load as compiled, there is no issue, as both expect that intermediate key. If we save non-compiled and load as non-compiled, there is also no issue. But if we save and load in different modes, we need to massage the ckpt (add or drop that intermediate key).

torch already has some PRs that automatically massage the ckpt by identifying this extra intermediate key and adding/removing it as needed, sort of like the "implicit" way @JRosenkranz mentioned.
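
As an illustration of what that massaging amounts to, a hedged sketch: torch already ships consume_prefix_in_state_dict_if_present, which strips a given prefix from a state dict in place, while the add_compile_prefix helper below is hypothetical.

from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

def strip_compile_prefix(state_dict):
    # Compiled ckpt -> non-compiled model: drop the "_orig_mod." prefix in place.
    consume_prefix_in_state_dict_if_present(state_dict, "_orig_mod.")
    return state_dict

def add_compile_prefix(state_dict):
    # Non-compiled ckpt -> compiled model: add the "_orig_mod." prefix back.
    return {f"_orig_mod.{k}": v for k, v in state_dict.items()}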

daviswer commented 3 months ago

Does this code handle the cross-loading cases? I see that with a compiled model, the script generates a compile-version state dict, then calls the usual torch.distributed._shard.checkpoint.load_state_dict() to update it. Does this function handle checkpoints in either format? (I guess this might reduce to the same question as above)

lchu-ibm commented 3 months ago

@daviswer for this PR, we aim to provide compiled FSDP -> single cpu and non-compiled FSDP -> single gpu.

The compiled case is handled by overloading (adding the intermediate key when loading) and offloading (dropping the intermediate key when offloading).

Of course, there are many other scenarios. The full scope might be handled by FMS or elsewhere.