platers opened this issue 1 year ago

Hi, does mup support training with FSDP? I got a model training with DDP, but when switching to FSDP I get the following assertion:

```
assert hasattr(self.weight, 'infshape'), (
AssertionError: Please call set_base_shapes(...). If using torch.nn.DataParallel, switch to distributed training with torch.nn.parallel.DistributedDataParallel instead
```
We haven't added support yet. Happy to review changes if someone is interested in adding support!
With a recent PyTorch version, you can set `FSDP(..., use_orig_params=True)` and it should work OOTB again.
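For context, the wrapping would look roughly like this (`MyMuModel` and `base_shapes.bsh` are placeholders; FSDP also needs an initialized process group, omitted here):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from mup import set_base_shapes

# Placeholder model and base-shapes file; adapt to your setup.
# Assumes torch.distributed is already initialized.
model = MyMuModel()
set_base_shapes(model, 'base_shapes.bsh')  # attaches .infshape to parameters

# use_orig_params=True keeps the original nn.Parameter objects exposed,
# so attributes like .infshape survive the wrapping.
model = FSDP(model, use_orig_params=True)
```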
Actually this only solves one part of the issue. Inside the FSDP-sharded `MuReadout`, you won't be able to access the original weight `torch.nn.Parameter`, and accordingly `infshape` will no longer be set, because FSDP replaces the parameters.

The following workaround worked for me. At the bottom, I attached a patch that changes `MuReadout` so that, in addition to `weight.infshape`, it optionally looks for `weight_infshape`. Then it's only a matter of manually setting `model.mu_readout_layer.weight_infshape = model.mu_readout_layer.weight.infshape`, which can probably be nicely automated using `torch.nn.Module.apply`.
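For instance, a minimal sketch of that automation (the helper name is mine, not part of the patch), to be run after `set_base_shapes(...)` and before wrapping with FSDP:

```python
import torch.nn as nn
from mup import MuReadout

def stash_infshape(module: nn.Module) -> None:
    # Copy infshape onto the module itself before FSDP replaces the
    # original weight Parameter with its flattened shard.
    if isinstance(module, MuReadout) and hasattr(module.weight, 'infshape'):
        module.weight_infshape = module.weight.infshape

model.apply(stash_infshape)
```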
```diff
diff --git a/mup/layer.py b/mup/layer.py
index 518a33b..8bb07af 100644
--- a/mup/layer.py
+++ b/mup/layer.py
@@ -25,12 +25,18 @@ class MuReadout(Linear):
             super().reset_parameters()
 
     def width_mult(self):
-        assert hasattr(self.weight, 'infshape'), (
-            'Please call set_base_shapes(...). If using torch.nn.DataParallel, '
-            'switch to distributed training with '
-            'torch.nn.parallel.DistributedDataParallel instead'
-        )
-        return self.weight.infshape.width_mult()
+        if not hasattr(self.weight, 'infshape'):
+            if not hasattr(self, 'weight_infshape'):
+                raise AssertionError(
+                    'Please call set_base_shapes(...). If using torch.nn.DataParallel, '
+                    'switch to distributed training with '
+                    'torch.nn.parallel.DistributedDataParallel instead'
+                )
+            else:
+                width_mult = self.weight_infshape.width_mult()
+        else:
+            width_mult = self.weight.infshape.width_mult()
+        return width_mult
 
     def _rescale_parameters(self):
         '''Rescale parameters to convert SP initialization to μP initialization.
```
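With the patch applied, the end-to-end order of operations would be roughly this (a sketch, reusing the `stash_infshape` helper and placeholders from above):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from mup import set_base_shapes

model = MyMuModel()                        # placeholder model
set_base_shapes(model, 'base_shapes.bsh')  # sets weight.infshape
model.apply(stash_infshape)                # stash weight_infshape
model = FSDP(model)                        # may hide .infshape...
# ...but width_mult() now falls back to the stashed weight_infshape.
```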