Summary:
EMA can be configured to exclude frozen (`requires_grad=False`) parameters and buffers, reducing memory use and checkpoint size.
However, `FULL_STATE_DICT` FSDP + EMA checkpoints construct an inner `EMAState` after unsharding FSDP parameters. This inner `EMAState` uses the default `include_frozen` and `include_buffers` settings, so checkpoints contain frozen parameters and buffers regardless of the configured settings.
Propagate the `include_frozen` and `include_buffers` settings to the inner `EMAState` when gathering `FULL_STATE_DICT` FSDP EMA state.
For frozen parameters, this change only takes effect in combination with a parallel PyTorch FSDP fix that propagates `requires_grad` across parameter sharding/unsharding: https://github.com/pytorch/pytorch/pull/109892.
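As a rough sketch of the fix (the simplified `EMAState` class and the `gather_fsdp_ema_state` helper below are hypothetical stand-ins for illustration, not the library's actual implementation), the change amounts to forwarding the outer EMA settings when the inner state is constructed after unsharding:

```python
# Illustrative sketch only: EMAState, include_frozen, and include_buffers
# mirror the names in this diff, but this class and gather_fsdp_ema_state
# are simplified stand-ins, not the real API.
import torch
import torch.nn as nn


class EMAState:
    def __init__(self, include_frozen: bool = True, include_buffers: bool = True):
        self.include_frozen = include_frozen
        self.include_buffers = include_buffers
        self.state: dict[str, torch.Tensor] = {}

    def save_from(self, model: nn.Module) -> None:
        # Skip frozen parameters when configured to exclude them. Note that
        # requires_grad on the unsharded parameters is only correct once the
        # parallel FSDP fix (pytorch/pytorch#109892) propagates it.
        for name, param in model.named_parameters():
            if not param.requires_grad and not self.include_frozen:
                continue
            self.state[name] = param.detach().clone()
        if self.include_buffers:
            for name, buf in model.named_buffers():
                self.state[name] = buf.detach().clone()


def gather_fsdp_ema_state(unsharded_model: nn.Module, outer: EMAState) -> EMAState:
    # Before the fix, the inner EMAState was built with default settings,
    # so frozen parameters and buffers were checkpointed regardless of the
    # outer configuration:
    #   inner = EMAState()
    # After the fix, the outer settings are propagated:
    inner = EMAState(
        include_frozen=outer.include_frozen,
        include_buffers=outer.include_buffers,
    )
    inner.save_from(unsharded_model)
    return inner
```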
Differential Revision: D49517178