facebookresearch / d2go

D2Go is a toolkit for efficient deep learning
Apache License 2.0

Propagate include_frozen/buffers to EMAState in FSDP FULL_STATE_DICT checkpoints #620

Closed edpizzi closed 9 months ago

edpizzi commented 9 months ago

Summary: EMA can be configured to exclude frozen (requires_grad=False) parameters and buffers, reducing memory use and checkpoint size.

However, FULL_STATE_DICT FSDP + EMA checkpoints construct an inner EMAState after unsharding the FSDP parameters. This inner EMAState uses the default include_frozen and include_buffers settings, so checkpoints contain frozen parameters and buffers regardless of the configured settings.

Propagate include_frozen and include_buffers settings to the inner EMAState when gathering FULL_STATE_DICT FSDP EMA state.
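A minimal sketch of the intended behavior is below. The helper name `gather_full_ema_state`, its arguments, and the attribute/constructor names on `EMAState` are assumptions based on the summary above, not the exact d2go API:

```python
# Hypothetical sketch: propagate the outer EMAState's settings into the inner
# EMAState that is built from the unsharded model for FULL_STATE_DICT checkpoints.
from d2go.modeling.ema import EMAState


def gather_full_ema_state(ema_state, unsharded_model, device):
    # Before this change: the inner EMAState used default settings, so frozen
    # parameters and buffers were always written to the checkpoint.
    # After: the configured include_frozen / include_buffers are propagated,
    # so the gathered state matches the EMA configuration.
    return EMAState.FromModel(
        unsharded_model,
        device=device,
        include_frozen=ema_state.include_frozen,
        include_buffers=ema_state.include_buffers,
    )
```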

For frozen parameters, this change only takes effect together with a parallel fix to PyTorch FSDP that propagates requires_grad across parameter sharding/unsharding: https://github.com/pytorch/pytorch/pull/109892.
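The reason the PyTorch-side fix matters: EMA's frozen-parameter filtering keys off requires_grad, so the unsharded parameters seen while gathering the FULL_STATE_DICT must report the same requires_grad as the original sharded parameters. The helper below is an illustration only; the function name is hypothetical:

```python
# Illustration of the filtering that depends on requires_grad surviving
# FSDP unsharding (hypothetical helper, not the d2go implementation).
def iter_ema_tensors(model, include_frozen: bool, include_buffers: bool):
    for name, param in model.named_parameters():
        # Without the FSDP fix, unsharded params may report requires_grad=True
        # even for frozen parameters, which defeats this filter.
        if include_frozen or param.requires_grad:
            yield name, param
    if include_buffers:
        yield from model.named_buffers()
```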

Differential Revision: D49517178

facebook-github-bot commented 9 months ago

This pull request was exported from Phabricator. Differential Revision: D49517178

facebook-github-bot commented 9 months ago

This pull request has been merged in facebookresearch/d2go@206a05c69f32e8000963c96a495eb7558c375ae9.