Closed lchu-ibm closed 3 months ago
@JRosenkranz I wonder if this is something that should be included in fms adapters? I didn't realize compiled models may be saved in a different way?
@lchu-ibm alternatively, would it be possible to just change the way we save FSDP checkpoints for compiled models? It seems we really only want the `_orig_mod`, so we could just save that as 'models' when writing them out?
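A minimal sketch of that idea, assuming the compiled wrapper exposes the original module as `_orig_mod` (as the wrapper returned by `torch.compile` does). The helper name `unwrap_compiled` and the two stand-in classes are hypothetical and only mimic the key layout; whether this plays well with FSDP sharded saving is not verified here.

```python
def unwrap_compiled(model):
    """Return the wrapped original module if `model` has `_orig_mod`, else `model` itself."""
    return getattr(model, "_orig_mod", model)


# Stand-ins for illustration only; these are not torch classes.
class PlainModel:
    def state_dict(self):
        return {"layer.weight": [1.0, 2.0]}


class CompiledWrapper:
    """Mimics a compiled wrapper: keys gain an `_orig_mod.` prefix."""

    def __init__(self, orig):
        self._orig_mod = orig

    def state_dict(self):
        inner = self._orig_mod.state_dict()
        return {f"_orig_mod.{k}": v for k, v in inner.items()}


compiled = CompiledWrapper(PlainModel())
plain = PlainModel()

print(sorted(compiled.state_dict()))                   # ['_orig_mod.layer.weight']
print(sorted(unwrap_compiled(compiled).state_dict()))  # ['layer.weight']
print(sorted(unwrap_compiled(plain).state_dict()))     # ['layer.weight']
```

Saving from the unwrapped module would give compiled and non-compiled runs the same key layout on disk, at the cost of the loading-side mismatch for compiled models.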
@nairbv Yes, I didn't realize this either. If compiled models are saved in a different way, it might make sense to include an adapter for it, provided through the source in `get_model`. I do agree, though, that it may make sense to just save the models in a different way. Another option is to determine this implicitly and run the adapter that way?
@nairbv If we save it without the intermediate key, then loading it back becomes another problem.
It is a "cross join" scenario: if we save compiled and load compiled, there is no issue, as both expect the intermediate key. If we save non-compiled and load non-compiled, there is also no issue. If we save and load in different modes, we need to massage the ckpt (add or drop the intermediate key).
torch already has some PRs on automatically massaging the ckpt by detecting this extra intermediate key and adding or removing it as needed, sort of like the "implicit" way @JRosenkranz mentioned.
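The four save/load combinations above can be sketched with plain dictionaries. This is an illustration only, assuming the intermediate key shows up as a leading `_orig_mod.` prefix on every state dict key; the helper names (`is_compiled_format`, `add_key`, `drop_key`, `match_format`) are hypothetical and are not the torch PRs mentioned.

```python
PREFIX = "_orig_mod."  # assumed position and spelling of the intermediate key


def is_compiled_format(state_dict):
    """Detect the compiled layout by looking for the prefix on any key."""
    return any(k.startswith(PREFIX) for k in state_dict)


def add_key(state_dict):
    """Add the prefix to every key that does not already carry it."""
    return {k if k.startswith(PREFIX) else PREFIX + k: v for k, v in state_dict.items()}


def drop_key(state_dict):
    """Remove the prefix from every key that carries it."""
    return {k[len(PREFIX):] if k.startswith(PREFIX) else k: v for k, v in state_dict.items()}


def match_format(state_dict, want_compiled):
    """Massage a checkpoint into the layout the loading model expects."""
    return add_key(state_dict) if want_compiled else drop_key(state_dict)


saved_compiled = {"_orig_mod.layer.weight": 1, "_orig_mod.layer.bias": 2}
saved_plain = {"layer.weight": 1, "layer.bias": 2}

# Same mode on both sides: nothing changes.
print(match_format(saved_compiled, want_compiled=True) == saved_compiled)  # True
print(match_format(saved_plain, want_compiled=False) == saved_plain)       # True
# Different modes: the key is dropped or added.
print(match_format(saved_compiled, want_compiled=False) == saved_plain)    # True
print(match_format(saved_plain, want_compiled=True) == saved_compiled)     # True
```

Because both helpers leave already-matching keys alone, `match_format` can be applied unconditionally, which is roughly what the "implicit" option amounts to.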
Does this code handle the cross-loading cases? I see that with a compiled model, the script generates a compile-version state dict, then calls the usual `torch.distributed._shard.checkpoint.load_state_dict()` to update it. Does this function handle checkpoints in either format? (I guess this might reduce to the same question as above)
@daviswer for this PR, we aim at providing `compile FSDP -> single cpu` and `non-compile FSDP -> single gpu`.
The compiled case is handled by overloading (adding the intermediate key when loading) and offloading (dropping the intermediate key when offloading).
Of course, there are many other scenarios. For the full scope, it might be handled by FMS or somewhere else.
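As a rough illustration of the add-on-load, drop-on-offload flow described above, the sketch below uses plain lists as stand-ins for tensors. `load_in_place` is a hypothetical stand-in for an in-place checkpoint loader that requires matching keys; it is not the torch API, and the `_orig_mod.` prefix position is an assumption.

```python
KEY = "_orig_mod."  # assumed intermediate key, as a leading prefix


def load_in_place(target, checkpoint):
    """Hypothetical in-place loader: keys must match, values are copied into target buffers."""
    if target.keys() != checkpoint.keys():
        raise KeyError("checkpoint keys do not match the target state dict")
    for name, buffer in target.items():
        buffer[:] = checkpoint[name]


# A checkpoint written by a compiled run, and the buffers of a plain model.
checkpoint = {"_orig_mod.layer.weight": [0.5, 0.25]}
model_state = {"layer.weight": [0.0, 0.0]}

# Loading: add the key so the template matches the checkpoint layout.
# The template aliases the model's buffers, so the in-place load updates the model too.
template = {KEY + name: buffer for name, buffer in model_state.items()}
load_in_place(template, checkpoint)

# Offloading: drop the key again before writing a single-file state dict.
single_file_state = {name[len(KEY):]: buffer for name, buffer in template.items()}

print(single_file_state)  # {'layer.weight': [0.5, 0.25]}
print(model_state)        # {'layer.weight': [0.5, 0.25]}
```

Loading the plain `model_state` directly would fail the key check here, which is the mismatch the intermediate key handling works around.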
To address https://github.com/foundation-model-stack/fms-fsdp/issues/59
To convert a compiled model, we can run
To convert a non-compiled model, we can run