Closed by vighneshbirodkar 1 day ago
When I run batchnorm with FSDP2 I get:

How do I support this? FSDP1 had an exception for batchnorm, as noted here. Is there a workaround that I can use?
This is not ideal because we want the batchnorm parameters to stay in float32 (they are updated via an exponential moving average).
You can wrap `BatchNorm` separately -- just call `fully_shard(batch_norm_module, **kwargs)` where `kwargs` has a different `MixedPrecisionPolicy` as needed, e.g. https://github.com/pytorch/pytorch/blob/ea0f60ecfabe0501485015841c4176d5a09c8247/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py#L543-L549
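For concreteness, here is a minimal sketch of that pattern. It assumes `torch.distributed` is already initialized (e.g. via `torchrun`); the model and the exact policy values (bf16 compute, fp32 reduction, `output_dtype`) are illustrative and not taken from the test linked above:

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

# Illustrative model; any module tree containing BatchNorm works the same way.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1),
)

# Policy for most parameters: bf16 for compute/communication, fp32 gradient reduction.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

# Policy for BatchNorm: keep its parameters in fp32, but hand bf16 activations
# back to the surrounding bf16 layers.
bn_policy = MixedPrecisionPolicy(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    output_dtype=torch.bfloat16,
)

# Wrap each BatchNorm first so it gets its own param group with its own policy.
for submodule in model.modules():
    if isinstance(submodule, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        fully_shard(submodule, mp_policy=bn_policy)

# Then wrap the parent. The BatchNorm parameters are already claimed, so they
# are excluded from the parent's group and keep the fp32 policy.
fully_shard(model, mp_policy=mp_policy)
```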
Thanks, that works.
Can you help me understand what's going on under the hood? Why do we have to call `fully_shard` on the constituent modules and also on the outer module?
When you call `fully_shard(module)`, this makes `module` into an `FSDPModule`, which corresponds to one `FSDPParamGroup` that gets communicated together (all-gather/reduce-scatter). All parameters in `module.parameters()` except those already assigned to a nested `FSDPModule` are assigned to `module`.
Each `FSDPModule` has a single config (i.e. one set of FSDP kwargs), including the mixed precision policy, since its assigned parameters are communicated together. If you want `BatchNorm` to have a separate mixed precision policy, then you should apply `fully_shard` to it separately. Then, when you apply `fully_shard` to a parent module, since `BatchNorm` already became an `FSDPModule`, its parameters are excluded from the parent `FSDPModule`.
Feel free to read https://github.com/pytorch/pytorch/issues/114299 for more details.
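As a quick illustrative check (building on the hypothetical sketch above; the import path may differ slightly across PyTorch versions), you can verify that the BatchNorm became its own `FSDPModule`, so its parameters sit in their own param group rather than the parent's:

```python
from torch.distributed._composable.fsdp import FSDPModule

bn = model[1]  # the BatchNorm2d from the sketch above

# fully_shard swaps each wrapped module's class to an FSDP subclass.
assert isinstance(bn, FSDPModule)
assert isinstance(model, FSDPModule)

# The BatchNorm parameters belong to bn's own param group, so the parent's
# all-gather/reduce-scatter covers every other parameter.
bn_param_ids = {id(p) for p in bn.parameters()}
parent_group_params = [p for p in model.parameters() if id(p) not in bn_param_ids]
```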