Closed — epwalsh closed this 3 months ago
Follow-up to #540. Fixes how we collect per-param optimizer metrics when using hybrid sharding. The process group we use here is the same one FSDP uses during hybrid sharding (e.g. when reducing grad norms), so it should be the right one. See https://github.com/pytorch/pytorch/blob/cb17721899d4d6a55d66d4f7188e36c20a078231/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1149.