facebookresearch / fairscale

PyTorch extensions for high performance and large scale training.

[FSDP] Flatten parameters by group #697

Open QuentinDuval opened 3 years ago

QuentinDuval commented 3 years ago

🚀 Feature

FSDP should offer the option to flatten parameters by group, for instance to flatten all biases separately from the other weights.
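A minimal sketch of what "flatten by group" could mean, written in plain PyTorch rather than any FSDP API: parameters are partitioned by a user-supplied rule, and each group is concatenated into its own 1-D buffer, with offsets recorded so per-parameter views can be recovered. The helper `flatten_by_group` and the name-based rule (biases vs. everything else) are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn


def flatten_by_group(module, group_fn):
    """Hypothetical helper: build one flat 1-D tensor per parameter group.

    group_fn maps a parameter name to a group key. Returns
    {key: (flat_tensor, [(name, offset, numel, shape), ...])}.
    """
    groups = {}
    for name, p in module.named_parameters():
        groups.setdefault(group_fn(name), []).append((name, p))

    result = {}
    for key, members in groups.items():
        flat = torch.cat([p.detach().reshape(-1) for _, p in members])
        meta, offset = [], 0
        for name, p in members:
            meta.append((name, offset, p.numel(), tuple(p.shape)))
            offset += p.numel()
        result[key] = (flat, meta)
    return result


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
flat = flatten_by_group(model, lambda n: "bias" if n.endswith("bias") else "weight")
# Weights: 4*3 + 3*2 = 18 elements; biases: 3 + 2 = 5 elements.
print({k: v[0].numel() for k, v in flat.items()})  # {'weight': 18, 'bias': 5}

# Recover a per-parameter view from its slot in the flat buffer.
w_flat, w_meta = flat["weight"]
name, off, n, shape = w_meta[0]
first_weight = w_flat[off:off + n].view(shape)
```

A real implementation would also turn each flat buffer into an `nn.Parameter` and make the module's parameters views into it; this sketch only shows the grouping and bookkeeping.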

Motivation

Following issue https://github.com/facebookresearch/fairscale/issues/644 and the attempt to solve it in PR https://github.com/facebookresearch/fairscale/pull/692, providing views of the individual parameters inside the flattened parameter is too low-level and does not offer the best user experience, as there are plenty of ways to trip and fall (see the PR for the limitations: https://github.com/facebookresearch/fairscale/pull/692).

Following a discussion with @min-xu-ai, the best approach is to go back to the use cases motivating https://github.com/facebookresearch/fairscale/issues/644:

  1. be able to easily compute the weight and gradient norms, as needed by LARC-like optimisers, when flatten_parameters=True
  2. be able to set a separate LR and regularisation for each parameter (for instance, regularise only the bias and not the weight, or the other way around)
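For item 2, the natural target is the standard `torch.optim` parameter-group mechanism. The sketch below assumes per-group flat buffers of the kind a grouped FSDP could produce (here built by hand, not by FSDP) and gives each one its own param group with its own `lr` and `weight_decay`; the numeric values are arbitrary.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))

# Hand-built stand-ins for two flat groups: all weights, and all biases.
flat_weight = nn.Parameter(torch.cat([m.weight.detach().reshape(-1) for m in model]))
flat_bias = nn.Parameter(torch.cat([m.bias.detach().reshape(-1) for m in model]))

# One optimizer param group per flat group, each with its own LR and weight decay.
optimizer = torch.optim.SGD(
    [
        {"params": [flat_weight], "lr": 0.1, "weight_decay": 0.01},
        {"params": [flat_bias], "lr": 0.2, "weight_decay": 0.0},
    ]
)

# With zero gradients, only the weight-decayed group should move.
flat_weight.grad = torch.zeros_like(flat_weight)
flat_bias.grad = torch.zeros_like(flat_bias)
before_w = flat_weight.detach().clone()
before_b = flat_bias.detach().clone()
optimizer.step()
print(torch.equal(flat_bias.detach(), before_b))    # True: no decay on biases
print(torch.equal(flat_weight.detach(), before_w))  # False: weights shrank
```

With a single flat parameter per wrapped module, only one such param group is possible, which is exactly the limitation this issue describes.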

This issue proposes to solve item 2.

Workarounds

There is currently no workaround when flatten_parameters=True: using different wrappers for different modules does not offer the granularity required to, for instance, flatten the biases and weights of an nn.Linear layer separately.
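A quick check of why module-level wrapping is too coarse: an `nn.Linear` registers both `weight` and `bias` directly on itself and has no sub-modules, so any wrapper placed around that layer necessarily sees both tensors together.

```python
import torch.nn as nn

layer = nn.Linear(16, 8)

# Both tensors are registered directly on the layer itself...
direct = [name for name, _ in layer.named_parameters(recurse=False)]
# ...and the layer has no sub-modules that could be wrapped on their own.
children = list(layer.children())

print(direct)    # ['weight', 'bias']
print(children)  # []
```

Any split between the two therefore has to be expressed per parameter (for example by name), not per module.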

Interested parties

CC: @min-xu-ai @myleott @prigoyal

min-xu-ai commented 3 years ago

Do you need this first or #696 first?

QuentinDuval commented 3 years ago

> Do you need this first or #696 first?

Since this could affect the evaluation, which comes next, I think this issue is more important.

jramapuram commented 2 years ago

Any update on this? LAMB/LARC/LARS/AGC all rely on being able to set layer-wise scaling.
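To illustrate why flattening gets in the way of layer-wise scaling, the sketch below computes a LARC-style trust ratio, `trust_coefficient * ||w|| / (||g|| + weight_decay * ||w||)`, once per parameter tensor and once over a single concatenated buffer. The coefficient values are illustrative assumptions, not taken from any particular optimizer configuration.

```python
import torch
import torch.nn as nn


def trust_ratio(param, grad, trust_coefficient=0.02, weight_decay=0.0, eps=1e-8):
    """LARC-style layer-wise scaling factor (illustrative constants)."""
    w_norm = param.norm()
    g_norm = grad.norm()
    return trust_coefficient * w_norm / (g_norm + weight_decay * w_norm + eps)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 1))
loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()

# One ratio per parameter tensor: the granularity layer-wise optimizers expect.
per_param = {
    name: float(trust_ratio(p.detach(), p.grad))
    for name, p in model.named_parameters()
}

# With everything in one flat buffer, only a single global ratio is available.
flat_w = torch.cat([p.detach().reshape(-1) for p in model.parameters()])
flat_g = torch.cat([p.grad.reshape(-1) for p in model.parameters()])
flat_ratio = float(trust_ratio(flat_w, flat_g))

for name, ratio in per_param.items():
    print(f"{name}: {ratio:.4f}")
print(f"flat buffer: {flat_ratio:.4f}")
```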

min-xu-ai commented 2 years ago

Thanks for pinging. I haven't had time to work on this lately, but I will try to find some time for it again. The underlying code for multiple flatten groups is already there, but the API to expose it in terms of re-flattening isn't there yet.

jramapuram commented 2 years ago

@min-xu-ai: I just want to add that this feature is also necessary for applying different weight decay to different subsets of parameters (important for scaling transformers, for example).
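A common convention in transformer training is to decay weight matrices but exclude biases and normalization parameters. The sketch below selects those two subsets in plain PyTorch; the rule itself is a convention, not something FSDP prescribes, and with flattening each subset would need its own flat group for the optimizer to treat them differently.

```python
import torch
import torch.nn as nn

model = nn.TransformerEncoderLayer(d_model=16, nhead=2, dim_feedforward=32)

decay, no_decay = [], []
for module in model.modules():
    # recurse=False: each parameter is visited once, on the module that owns it.
    for name, p in module.named_parameters(recurse=False):
        # Biases and LayerNorm parameters are commonly excluded from decay.
        if name.endswith("bias") or isinstance(module, nn.LayerNorm):
            no_decay.append(p)
        else:
            decay.append(p)

optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)
```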

min-xu-ai commented 2 years ago

@jramapuram, thanks for checking. I haven't had a chance to work on this further, but @tmarkstrum had a private branch that can apply weight decay differently to weights and biases, though it is not general enough. CC @anupambhatnagar in case he has time to look at this.

anupambhatnagar commented 2 years ago

Quick update - this work is not planned for the near future.