Closed YanjunChen329 closed 1 year ago
This pull request was exported from Phabricator. Differential Revision: D48813697
This pull request has been merged in facebookresearch/d2go@477629d0f8e019ed631ee9d9f65202da9fdd7670.
Summary:
The previous FSDP EMA checkpointing logic handles `EMAState` directly: it manually calls `FSDP.summon_full_params()` to gather the full model params, then reconstructs/loads an `EMAState` for checkpointing. This logic has two drawbacks:

1. `FSDP.summon_full_params()` gathers all model weights at the same time, which can cause OOM issues if the model can't fit into a single GPU. This is quite common for FSDP workloads.
2. Handling `EMAState` directly is error-prone. The EMA state dict has different semantics and behaviors than `model.state_dict()`, yet users often expect it to function seamlessly like the model state dict.

This diff modifies the save/load logic of EMA to use `model.state_dict()` directly, which addresses both pain points.

Differential Revision: D48813697
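
For illustration only, here is a minimal sketch of one way to obtain an EMA checkpoint through `model.state_dict()`: temporarily load the EMA values into the model, take its state dict, then restore the live weights. This is not necessarily how this diff implements it. The helper names `ema_weights_applied` and `ema_state_dict` and the plain `ema_params` dict are hypothetical, and the example uses an unwrapped `nn.Linear` on CPU rather than an FSDP-wrapped model.

```python
import contextlib

import torch
from torch import nn


@contextlib.contextmanager
def ema_weights_applied(model: nn.Module, ema_params: dict):
    """Hypothetical helper: temporarily load EMA values into `model`."""
    # Back up the live values of every entry the EMA dict will overwrite.
    backup = {
        k: v.detach().clone()
        for k, v in model.state_dict().items()
        if k in ema_params
    }
    model.load_state_dict(ema_params, strict=False)
    try:
        yield model
    finally:
        # Restore the live training weights.
        model.load_state_dict(backup, strict=False)


def ema_state_dict(model: nn.Module, ema_params: dict) -> dict:
    """Hypothetical helper: return the EMA weights via model.state_dict()."""
    with ema_weights_applied(model, ema_params):
        # Clone, because state_dict() tensors alias the parameters that
        # are restored when the context manager exits.
        return {k: v.detach().clone() for k, v in model.state_dict().items()}


# Toy usage: the EMA weights are all zeros, the live weights are random.
model = nn.Linear(2, 1)
live = {k: v.detach().clone() for k, v in model.state_dict().items()}
ema_params = {k: torch.zeros_like(v) for k, v in live.items()}
ema_sd = ema_state_dict(model, ema_params)
```

Because the checkpoint comes from `model.state_dict()`, it has the same keys and semantics as a regular model checkpoint. With an FSDP-wrapped model, the same call would go through FSDP's own state dict handling, which (depending on the configured state dict type) need not materialize all weights on one GPU.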