YanjunChen329 closed this pull request 1 year ago.
This pull request was exported from Phabricator. Differential Revision: D46815031
This pull request has been merged in facebookresearch/d2go@5c23bee8c2190c6b99aba88bb7bc8814d3e51710.
Summary:
Problem:
d2go EMA uses `named_parameters()` to traverse model states and save EMA checkpoints, while model checkpoints are saved with `state_dict()`. This is brittle because `named_parameters()` and `state_dict()` go through two different sets of Python APIs and can return different names for the same parameters. With Activation Checkpointing (AC), we don't want the AC wrapper to affect checkpoint names, so PyTorch overrides `state_dict()` to remove the `_checkpoint_wrapped_module` prefix from fully qualified names (FQNs). However, `named_parameters()` has no such support, so the prefix is still present in the names it returns. If we change the AC wrapping strategy (very common during optimization), we can no longer load the previous EMA state back into the model. The same problem also occurs with FSDP.

Short-term hack:
This diff adds a short-term hack that manually removes the AC prefix in EMA. We can extend `IGNORED_FQN_PREFIX` to support more use cases.

Differential Revision: D46815031
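A minimal pure-Python sketch of the name mismatch and of the prefix-stripping idea described in the summary is shown below. It does not use the actual d2go or PyTorch code: `fake_named_parameters`, `strip_ignored_prefixes`, and `normalize_state_keys` are hypothetical helpers, and `IGNORED_FQN_PREFIX` is assumed here to be a tuple of substrings to remove (the real d2go constant may differ in type and contents). Extending the tuple with another wrapper segment, for example FSDP's `_fsdp_wrapped_module.`, is shown only as an assumed illustration of "supporting more use cases".

```python
# Assumption: IGNORED_FQN_PREFIX is modeled as a tuple of wrapper path segments
# to strip from parameter FQNs; the real d2go constant may be defined differently.
IGNORED_FQN_PREFIX = ("_checkpoint_wrapped_module.",)


def strip_ignored_prefixes(fqn, prefixes=IGNORED_FQN_PREFIX):
    """Remove every occurrence of each ignored wrapper segment from an FQN."""
    for prefix in prefixes:
        fqn = fqn.replace(prefix, "")
    return fqn


def normalize_state_keys(state, prefixes=IGNORED_FQN_PREFIX):
    """Return a copy of an EMA state dict keyed by wrapper-independent FQNs."""
    return {strip_ignored_prefixes(k, prefixes): v for k, v in state.items()}


def fake_named_parameters(layers, ac_wrapped):
    """Hypothetical stand-in for named_parameters(): an AC-wrapped layer keeps
    its original module under `_checkpoint_wrapped_module`, so that segment
    shows up in the parameter names of every wrapped layer."""
    names = []
    for layer in layers:
        segment = "_checkpoint_wrapped_module." if layer in ac_wrapped else ""
        names.append(f"{layer}.{segment}weight")
    return names


if __name__ == "__main__":
    layers = ["backbone.layer1", "backbone.layer2", "head"]
    # EMA names recorded while layer1 was AC-wrapped...
    old_names = fake_named_parameters(layers, ac_wrapped={"backbone.layer1"})
    # ...versus the names after the wrapping strategy moved AC to layer2.
    new_names = fake_named_parameters(layers, ac_wrapped={"backbone.layer2"})

    # Raw names no longer line up, so the old EMA state would not load.
    print(sorted(set(old_names) ^ set(new_names)))

    # After stripping the ignored segment, both runs agree on every name.
    old_clean = {strip_ignored_prefixes(n) for n in old_names}
    new_clean = {strip_ignored_prefixes(n) for n in new_names}
    print(sorted(old_clean ^ new_clean))  # -> []
```

Stripping the segment everywhere (rather than only at the start of the string) matters because the wrapper segment sits in the middle of the FQN, right after the name of the wrapped submodule.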