huggingface / nanotron

Minimalistic large language model 3D-parallelism training
Apache License 2.0

MoE in src and load balancing losses #159

Open haeggee opened 4 months ago

haeggee commented 4 months ago

Hey,

First of all, thanks for the really nice codebase! I just wanted to open a discussion on the MoE implementation and on adding load balancing losses.

I see that you deliberately put MoEs in an example instead of in src. It's simple to include the features in the main modeling code, and I think it's worth doing so. In particular, the example setup requires starting runs from within the examples folder (or wherever the MoE code is); our team's setup starts training runs outside of nanotron and therefore has to import everything through the pip-installed nanotron package.
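To make the packaging point concrete, here is a rough sketch of the difference (paths and module names are illustrative, not the actual nanotron layout):

```python
# Today: MoE code lives under examples/, so an external training script has to
# manipulate sys.path before it can import anything (path/name illustrative).
import sys
sys.path.append("/path/to/nanotron/examples/moe")
import moe  # example-local module, invisible to the installed package

# With the MoE code moved into src/, a plain import from the pip-installed
# package would suffice, e.g. (hypothetical module path):
# from nanotron.models import moe
```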

I've implemented the changes for that here: https://github.com/swiss-ai/nanotron/tree/moe

There are also some other minor fixes, for example a correct cost estimation for pipeline splitting, actually using the activation function arguments, and the correct SwiGLU for expert parallelism (last time I checked, your example code used only 2 of the 3 weight matrices when experts per rank = 1); see the sketch below.
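For reference, a minimal sketch of SwiGLU with all three weight matrices (names follow the common LLaMA-style convention, not necessarily nanotron's):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # silu(gate(x)) * up(x), then project back down: three matrices,
        # not two -- dropping up_proj changes the architecture entirely.
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
```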

More importantly, I added load balancing losses for expert balance. In my experience, they can make an important difference, especially at larger scale, for instance for GPU utilization or training stability (see e.g. "[...] a lot of weird errors turn out to be symptoms of expert load imbalance", link).

How I implemented the losses may be suboptimal -- it's very similar to how Megablocks did it originally, where the balancing losses are local to each pipeline rank. They are still tracked for the backward pass in the pipeline engine (commit here). However, with multiple pipeline-parallel ranks, the logging at the last rank does not see the previous ranks' losses, so e.g. the wandb logs show lower loss values than the true total. I guess the only way to log correctly would be to pass the losses through the network (just like the inputs, via TensorPointers etc.).
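For context, this is the kind of auxiliary loss meant here, in the style of the Switch Transformer / Megablocks formulation (a minimal sketch assuming top-1 routing; names are illustrative, not taken from my branch):

```python
import torch

def load_balancing_loss(router_logits: torch.Tensor,
                        expert_indices: torch.Tensor,
                        num_experts: int) -> torch.Tensor:
    """Auxiliary loss = num_experts * sum_e f_e * P_e, where f_e is the
    fraction of tokens routed to expert e and P_e is the mean router
    probability assigned to e. It is minimized by a uniform assignment."""
    probs = torch.softmax(router_logits, dim=-1)  # (num_tokens, num_experts)
    # f_e: fraction of tokens dispatched to each expert (top-1 routing)
    tokens_per_expert = torch.bincount(expert_indices.flatten(),
                                       minlength=num_experts)
    f = tokens_per_expert.float() / expert_indices.numel()
    # P_e: mean router probability per expert
    p = probs.mean(dim=0)
    return num_experts * torch.sum(f * p)
```

With the losses computed per pipeline rank like this, each rank only ever sees its own term, which is exactly why the last-rank logging described above undercounts the total.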

I would be very happy to hear your thoughts and input on this, in particular the load balancing implementation. If desired, I could also open a PR and we continue discussing there :)

sjelassi commented 3 months ago

I agree with the points raised by @haeggee . Are there any plans to integrate his commits into the codebase?

haeggee commented 3 months ago

Thanks a lot for your comment! I'm currently working on PR #192 for this :) I've also fixed the correct-logging issue there. Waiting for more input.