microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License

FSDP support? #59

Open platers opened 1 year ago

platers commented 1 year ago

Hi, does mup support training with FSDP? I got a model training with DDP, but when switching to FSDP I get the following assertion error.

assert hasattr(self.weight, 'infshape'), (
AssertionError: Please call set_base_shapes(...). If using torch.nn.DataParallel, switch to distributed training with torch.nn.parallel.DistributedDataParallel instead
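For context, a minimal sketch of the kind of setup that hits this (model names and sizes are made up; it assumes the distributed process group is already initialized, e.g. via torchrun):

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from mup import MuReadout, set_base_shapes

class Net(nn.Module):
    def __init__(self, width):
        super().__init__()
        self.body = nn.Linear(32, width)
        self.readout = MuReadout(width, 10)   # µP readout layer

    def forward(self, x):
        return self.readout(torch.relu(self.body(x)))

model = Net(width=256)
set_base_shapes(model, Net(width=64))  # attaches .infshape to every parameter
model = FSDP(model)                    # FSDP flattens/replaces the parameters ...
out = model(torch.randn(8, 32))        # ... so this forward pass raises the
                                       # AssertionError in MuReadout.width_mult()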

edwardjhu commented 11 months ago

We haven't added support yet. Happy to review changes if someone is interested in adding support!

janEbert commented 5 months ago

With a recent PyTorch version, you can set FSDP(..., use_orig_params=True) and it should work out of the box again.
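For reference, a minimal sketch of that suggestion (model and base_model are placeholders; assumes the process group is already initialized):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from mup import set_base_shapes

set_base_shapes(model, base_model)         # attach infshapes before wrapping
model = FSDP(model, use_orig_params=True)  # keep the original parameters so
                                           # weight.infshape remains accessible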

janEbert commented 5 months ago

Actually, this only solves one part of the issue. Inside the FSDP-sharded MuReadout, you can no longer access the original weight torch.nn.Parameter, and accordingly its infshape attribute is no longer available, because FSDP replaces the parameters.

The following workaround worked for me. At the bottom, I attached a patch that changes MuReadout so that, in addition to weight.infshape, it optionally looks for weight_infshape. Then it is only a matter of manually setting model.mu_readout_layer.weight_infshape = model.mu_readout_layer.weight.infshape, which can probably be automated nicely using torch.nn.Module.apply (see the sketch below).
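For illustration, a rough sketch of that Module.apply idea (untested; assumes set_base_shapes has already been called on the unwrapped model and that the patch below is applied):

from mup.layer import MuReadout

def stash_infshape(module):
    # Copy the infshape onto the module itself so the patched width_mult
    # below can still find it after FSDP replaces the weight parameter.
    if isinstance(module, MuReadout) and hasattr(module.weight, 'infshape'):
        module.weight_infshape = module.weight.infshape

model.apply(stash_infshape)  # apply() visits every submodule; run before wrapping in FSDP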

diff --git a/mup/layer.py b/mup/layer.py
index 518a33b..8bb07af 100644
--- a/mup/layer.py
+++ b/mup/layer.py
@@ -25,12 +25,18 @@ class MuReadout(Linear):
             super().reset_parameters()

     def width_mult(self):
-        assert hasattr(self.weight, 'infshape'), (
-            'Please call set_base_shapes(...). If using torch.nn.DataParallel, '
-            'switch to distributed training with '
-            'torch.nn.parallel.DistributedDataParallel instead'
-        )
-        return self.weight.infshape.width_mult()
+        if not hasattr(self.weight, 'infshape'):
+            if not hasattr(self, 'weight_infshape'):
+                raise AssertionError(
+                    'Please call set_base_shapes(...). If using torch.nn.DataParallel, '
+                    'switch to distributed training with '
+                    'torch.nn.parallel.DistributedDataParallel instead'
+                )
+            else:
+                width_mult = self.weight_infshape.width_mult()
+        else:
+            width_mult = self.weight.infshape.width_mult()
+        return width_mult

     def _rescale_parameters(self):
         '''Rescale parameters to convert SP initialization to μP initialization.