pytorch / torchtitan

A native PyTorch Library for large model training

Batchnorm support with FSDP2 #674

Closed: vighneshbirodkar closed this issue 1 day ago

vighneshbirodkar commented 3 days ago

When I run batchnorm with FSDP2, I get:

RuntimeError: Expected running_mean to have type BFloat16 but got Float                                                                                         
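
Roughly, a minimal sketch of the kind of setup that hits this (the model here is a placeholder, not my actual one; the composable-FSDP import path may differ across PyTorch versions, and it assumes a torchrun launch on GPUs):

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

# Assumes torch.distributed is already initialized (e.g. via torchrun).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
).cuda()

# With a bf16 policy, parameters and forward inputs are cast to bf16, but
# BatchNorm's running_mean/running_var buffers stay in fp32, so the
# batch_norm kernel sees mismatched dtypes and raises the error above.
fully_shard(model, mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16))
out = model(torch.randn(4, 3, 16, 16, device="cuda"))
```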

How do I support this?

FSDP1 had an exception for batchnorm as noted here.

Is there a workaround that I can use?

awgu commented 3 days ago

See here: https://github.com/pytorch/torchtitan/issues/671#issuecomment-2457398462

vighneshbirodkar commented 3 days ago

This is not ideal, because we want the batchnorm running statistics to stay in float32 (they are exponential moving averages).

awgu commented 3 days ago

You can wrap BatchNorm separately: call fully_shard(batch_norm_module, **kwargs), where kwargs carries a different MixedPrecisionPolicy as needed. See, e.g., https://github.com/pytorch/pytorch/blob/ea0f60ecfabe0501485015841c4176d5a09c8247/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py#L543-L549
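
For example, a minimal sketch of that pattern (assuming a torchrun launch; the import path and the exact dtypes in the policies may differ across PyTorch versions):

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

# Assumes torch.distributed is already initialized (e.g. via torchrun).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
).cuda()

# Wrap each BatchNorm first with an fp32 policy so its parameters and its
# forward compute (including the running-stat updates) stay in float32.
# output_dtype casts its output back to bf16 for the surrounding bf16 layers.
bn_policy = MixedPrecisionPolicy(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    output_dtype=torch.bfloat16,
)
for submodule in model.modules():
    if isinstance(submodule, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        fully_shard(submodule, mp_policy=bn_policy)

# Then wrap the root with the usual bf16 policy; the BatchNorm parameters are
# already owned by their own FSDP param group, so they are excluded here.
fully_shard(model, mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16))
```

The idea is that each BatchNorm module gets its own FSDP param group with an fp32 policy, while everything else communicates and computes in bf16.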

vighneshbirodkar commented 1 day ago

Thanks. That works.

Can you help me understand what's going on under the hood? Why do we have to call fully_shard on the constituent modules as well as on the outer module?

awgu commented 1 day ago

When you call fully_shard(module), module becomes an FSDPModule, which corresponds to one FSDPParamGroup whose parameters are communicated together (all-gathered/reduce-scattered as one group). Every parameter in module.parameters() is assigned to module, except those already assigned to a nested FSDPModule (i.e. a submodule on which fully_shard was already called).

All parameters assigned to one FSDPModule share the same config (the FSDP kwargs), including the mixed precision policy, since they are communicated together. If you want BatchNorm to have a separate mixed precision policy, you should apply fully_shard to it separately. Then, when you apply fully_shard to a parent module, BatchNorm is already an FSDPModule, so its parameters are excluded from the parent FSDPModule.
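
A tiny sketch of that grouping rule (the module names are just illustrative, and this assumes torch.distributed is already initialized, e.g. via torchrun):

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

block = nn.Sequential(nn.Linear(16, 16), nn.BatchNorm1d(16))

fully_shard(block[1])  # BatchNorm becomes its own FSDPModule / param group first
fully_shard(block)     # the parent's param group then holds only the Linear's
                       # parameters; BatchNorm's were already claimed above
```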

Feel free to read https://github.com/pytorch/pytorch/issues/114299 for more details.