facebookresearch / d2go

D2Go is a toolkit for efficient deep learning
Apache License 2.0

Make EMA checkpointing with FSDP more robust #615

Closed: YanjunChen329 closed this pull request 1 year ago

YanjunChen329 commented 1 year ago

Summary: The previous FSDP EMA checkpointing logic handled EMAState directly: it manually called FSDP.summon_full_params() to gather the full model parameters, then reconstructed or loaded an EMAState for checkpointing. This approach has two drawbacks:

  1. FSDP.summon_full_params() gathers all model weights at once, which can cause out-of-memory (OOM) errors when the full model does not fit on a single GPU, a common situation for FSDP workloads.
  2. Saving and loading EMAState directly is error-prone. The EMA state dict has different semantics and behavior from model.state_dict(), yet users often expect it to work as seamlessly as the model state dict.
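To put drawback 1 in rough numbers, the sketch below estimates parameter memory per rank with and without a full gather. It is a back-of-the-envelope illustration only: the function name `param_bytes_per_rank` is hypothetical, and it ignores gradients, optimizer state, activations, FSDP padding, and any separate EMA copy.

```python
def param_bytes_per_rank(num_params: int, world_size: int,
                         bytes_per_param: int = 4, gathered: bool = False) -> float:
    """Rough per-rank parameter memory in bytes (hypothetical helper).

    Only contrasts sharded vs. fully gathered weights; gradients, optimizer
    state, activations, padding, and EMA copies are deliberately ignored.
    """
    total = num_params * bytes_per_param
    if gathered:
        # A full gather (as with FSDP.summon_full_params()) materializes
        # every parameter on the rank at the same time.
        return float(total)
    # With full sharding, each rank holds roughly 1/world_size of the parameters.
    return total / world_size


if __name__ == "__main__":
    n = 7_000_000_000  # hypothetical 7B-parameter model stored in fp32
    print(param_bytes_per_rank(n, 8) / 1e9, "GB per rank, sharded across 8 ranks")
    print(param_bytes_per_rank(n, 8, gathered=True) / 1e9, "GB per rank, fully gathered")
```

For this hypothetical 7B-parameter fp32 model, sharding across 8 ranks needs about 3.5 GB of parameter memory per rank, while a full gather needs about 28 GB per rank, which is where the OOM risk comes from.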

This diff changes the EMA save/load logic to use model.state_dict() directly, addressing both pain points above.
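One way to picture "use model.state_dict() for EMA save/load" is to temporarily swap the EMA weights into the model, go through the model's own state_dict()/load_state_dict(), and then restore the live weights. The sketch below is a dependency-free illustration under stated assumptions, not the merged d2go implementation: `TinyModel`, `SimpleEMA`, `ema_weights_applied`, `save_ema`, and `load_ema` are all hypothetical names, and a plain dict of floats stands in for a PyTorch module's state dict. FSDP itself is not modeled.

```python
from contextlib import contextmanager
from typing import Dict, Iterator


class TinyModel:
    """Hypothetical stand-in for an nn.Module: a flat dict of named float weights."""

    def __init__(self, weights: Dict[str, float]) -> None:
        self._weights = dict(weights)

    def state_dict(self) -> Dict[str, float]:
        return dict(self._weights)

    def load_state_dict(self, sd: Dict[str, float]) -> None:
        # Strict loading: keys must match exactly (mirrors strict=True in PyTorch).
        if set(sd) != set(self._weights):
            raise KeyError("state dict keys do not match the model")
        self._weights = dict(sd)


class SimpleEMA:
    """Hypothetical EMA tracker keyed by the same names as model.state_dict()."""

    def __init__(self, model: TinyModel, decay: float = 0.9) -> None:
        self.decay = decay
        self.shadow = model.state_dict()

    def update(self, model: TinyModel) -> None:
        # shadow <- decay * shadow + (1 - decay) * current weights
        for name, value in model.state_dict().items():
            self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * value


@contextmanager
def ema_weights_applied(model: TinyModel, ema: SimpleEMA) -> Iterator[TinyModel]:
    """Temporarily load EMA weights into the model; always restore the live weights."""
    backup = model.state_dict()
    model.load_state_dict(ema.shadow)
    try:
        yield model
    finally:
        model.load_state_dict(backup)


def save_ema(model: TinyModel, ema: SimpleEMA) -> Dict[str, float]:
    # The EMA checkpoint comes from the model's own state_dict(), so it has
    # exactly the same keys and format as a regular model checkpoint.
    with ema_weights_applied(model, ema):
        return model.state_dict()


def load_ema(model: TinyModel, ema: SimpleEMA, ema_checkpoint: Dict[str, float]) -> None:
    # Loading goes through model.load_state_dict() (including its key checks),
    # then the loaded weights are copied back into the EMA tracker.
    with ema_weights_applied(model, ema):
        model.load_state_dict(ema_checkpoint)
        ema.shadow = model.state_dict()
```

The design choice this illustrates is that the EMA checkpoint never has its own serialization path: in a real FSDP setup, the same pattern would presumably inherit whatever state-dict configuration (full or sharded) the model already uses, instead of requiring a separate full gather.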

Differential Revision: D48813697

facebook-github-bot commented 1 year ago

This pull request was exported from Phabricator. Differential Revision: D48813697

facebook-github-bot commented 1 year ago

This pull request has been merged in facebookresearch/d2go@477629d0f8e019ed631ee9d9f65202da9fdd7670.