axolotl-ai-cloud / axolotl

Parallelize and optimize Mixtral MoE #930

Open · casper-hansen opened this issue 9 months ago

casper-hansen commented 9 months ago

⚠️ Please check that this feature request hasn't been suggested before.

🔖 Feature description

A core part of training a Mixture of Experts model is not losing the parallelism we already had in dense models. The current implementation is a naive, simple one that does not use the SOTA methods present in MegaBlocks. MegaBlocks dMoEs use a reformulation of MoEs in terms of block-sparse operations, which allows us to avoid token dropping without sacrificing hardware efficiency.

More details can be found in the MegaBlocks paper.
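To make the dropless idea concrete, here is a minimal, self-contained sketch (illustrative only, not MegaBlocks code; all names are made up for the example) of the difference between capacity-based routing, which drops tokens once an expert is full, and dropless grouping, where tokens are simply sorted by expert and processed in variable-sized groups:

```python
import torch

# Illustration of "dropless": compare how many tokens a capacity-based MoE
# would drop versus grouping every token by its assigned expert.
num_tokens, num_experts, capacity_factor = 4096, 8, 1.25
expert_indices = torch.randint(num_experts, (num_tokens,))  # top-1 routing decisions

# Capacity-based MoE: each expert processes at most `capacity` tokens; the rest
# are dropped (or sent down a residual path).
capacity = int(capacity_factor * num_tokens / num_experts)
dropped = sum(max(0, (expert_indices == e).sum().item() - capacity)
              for e in range(num_experts))

# Dropless MoE (the MegaBlocks formulation): tokens are permuted so that each
# expert's tokens are contiguous, and the expert FFNs run as block-sparse /
# grouped matmuls over these variable-sized groups, so nothing is dropped.
order = torch.argsort(expert_indices)  # permutation that groups tokens by expert
group_sizes = torch.bincount(expert_indices, minlength=num_experts)

print(f"capacity-based routing would drop {dropped} of {num_tokens} tokens; "
      f"dropless group sizes per expert: {group_sizes.tolist()}")
```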

✔️ Solution

We should look into using dMoE (dropless MoE), which is the efficient core part of MegaBlocks.
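As a rough, hedged sketch of what that wiring could look like (the import paths and `Arguments` fields below are assumptions based on a reading of the MegaBlocks repo and the HF `MixtralConfig`, and would need to be verified against both code bases):

```python
# Hypothetical integration sketch: swap the naive per-expert loop in Mixtral's
# MoE block for MegaBlocks' dropless MoE layer. The import paths and field
# names below are assumptions, not verified API.
from megablocks.layers.arguments import Arguments  # assumed import path
from megablocks.layers.dmoe import dMoE            # assumed import path


def build_dmoe_from_mixtral_config(cfg):
    # cfg is assumed to look like a HF MixtralConfig.
    args = Arguments(
        hidden_size=cfg.hidden_size,
        ffn_hidden_size=cfg.intermediate_size,
        moe_num_experts=cfg.num_local_experts,
        moe_top_k=cfg.num_experts_per_tok,
    )
    return dMoE(args)


# The idea would then be to replace each MixtralSparseMoeBlock in the loaded
# model with this module (router hookup and weight copying omitted here).
```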

I believe the correct solution is to do the following:

❓ Alternatives

No response

📝 Additional Context

No response

Acknowledgements

casper-hansen commented 9 months ago

Example of parallelization: I got a 2.88% speedup, but that may be due to randomness. The snippet below forks the experts for better parallelization and to avoid CPU synchronization:

```python
# Split the flattened tokens by the expert they were routed to.
x_indices = [x[flat_expert_indices == i] for i in range(len(self.experts))]
# Fork one asynchronous task per expert so the expert MLPs can run concurrently.
futures = [torch.jit.fork(expert, x_i) for expert, x_i in zip(self.experts, x_indices)]
# Wait for every forked task and collect the per-expert outputs.
outputs = [torch.jit.wait(fut) for fut in futures]

# Scatter the per-expert outputs back into y in the original token order.
for i, output in enumerate(outputs):
    y[flat_expert_indices == i] = output
```
casper-hansen commented 9 months ago

One interesting part of MegaBlocks is the initialization of the linear layers. They have an additional experts_per_rank that can be used to implement fully parallel execution of multiple experts at the same time.

https://github.com/stanford-futuredata/megablocks/blob/main/megablocks/layers/mlp.py#L80-L85
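For intuition, here is a minimal, illustrative sketch (not MegaBlocks' actual implementation; the plain ReLU MLP and the padded, equal-sized token groups are simplifications) of how an experts_per_rank split lets each rank hold its local experts as one stacked weight tensor and apply all of them with a single batched matmul:

```python
import torch
import torch.nn as nn


class LocalExperts(nn.Module):
    """Hypothetical expert-parallel MLP: each rank owns num_experts // world_size experts."""

    def __init__(self, hidden_size, ffn_hidden_size, num_experts, world_size, rank):
        super().__init__()
        assert num_experts % world_size == 0
        self.experts_per_rank = num_experts // world_size
        # Global index of this rank's first expert (routing bookkeeping only).
        self.first_expert = rank * self.experts_per_rank
        # One stacked weight per projection for all local experts:
        # shapes (experts_per_rank, hidden, ffn) and (experts_per_rank, ffn, hidden).
        self.w1 = nn.Parameter(torch.empty(self.experts_per_rank, hidden_size, ffn_hidden_size))
        self.w2 = nn.Parameter(torch.empty(self.experts_per_rank, ffn_hidden_size, hidden_size))
        nn.init.normal_(self.w1, std=0.02)
        nn.init.normal_(self.w2, std=0.02)

    def forward(self, grouped_x):
        # grouped_x: (experts_per_rank, tokens_per_expert, hidden_size), i.e. the tokens
        # routed to this rank's experts, already grouped (and padded) per expert.
        hidden = torch.bmm(grouped_x, self.w1).relu()
        return torch.bmm(hidden, self.w2)
```

In MegaBlocks itself the local experts are applied with grouped / block-sparse matmuls over variable-sized token groups rather than a padded bmm, but the weight layout idea is the same: the experts_per_rank factor is baked into the parameter shapes, so all local experts execute together.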