Open casper-hansen opened 11 months ago
Example of parallelization. I measured a 2.88% speedup, though that may just be noise. For better parallelization and to avoid CPU synchronization:
```python
# Split the tokens by expert, fork one asynchronous task per expert,
# then wait on all of them and scatter the results back into y.
x_indices = [x[flat_expert_indices == i] for i in range(len(self.experts))]
futures = [torch.jit.fork(expert, x_i) for expert, x_i in zip(self.experts, x_indices)]
outputs = [torch.jit.wait(fut) for fut in futures]
# Assign the outputs to y in the correct sequence
for i, output in enumerate(outputs):
    y[flat_expert_indices == i] = output
```
One interesting part of MegaBlocks is how the Linear weights are initialized. They allocate an additional experts_per_rank dimension, which can be used to run multiple experts fully in parallel at the same time.
https://github.com/stanford-futuredata/megablocks/blob/main/megablocks/layers/mlp.py#L80-L85
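For illustration, here is a minimal sketch of that idea, not MegaBlocks' actual code: allocate all local experts' weights in a single tensor with an experts_per_rank leading dimension so every expert is computed in one batched matmul instead of a Python loop. The names (BatchedExpertMLP, ffn_hidden_size) and the grouped input layout are assumptions for the sketch:

```python
import torch
import torch.nn as nn


class BatchedExpertMLP(nn.Module):
    """Hypothetical sketch: hold the weights of all experts on this rank in one
    tensor (experts_per_rank as the leading dim) so the experts can be computed
    with a single batched matmul rather than a loop over expert modules."""

    def __init__(self, experts_per_rank: int, hidden_size: int, ffn_hidden_size: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(experts_per_rank, hidden_size, ffn_hidden_size))
        self.w2 = nn.Parameter(torch.empty(experts_per_rank, ffn_hidden_size, hidden_size))
        nn.init.normal_(self.w1, std=0.02)
        nn.init.normal_(self.w2, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [experts_per_rank, tokens_per_expert, hidden_size]
        h = torch.bmm(x, self.w1)                 # one batched matmul covers every expert
        h = torch.nn.functional.gelu(h)
        return torch.bmm(h, self.w2)              # [experts_per_rank, tokens_per_expert, hidden_size]
```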
🔖 Feature description
A core part of training a Mixture of Experts model is not losing the parallelism we already rely on in dense models. The current implementation is a naive and simple one that does not make use of the SOTA methods present in MegaBlocks. MegaBlocks dMoEs use a reformulation of MoEs in terms of block-sparse operations, which allows us to avoid token dropping without sacrificing hardware efficiency.
More details can be found in the MegaBlocks paper.
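To make the token-dropping point concrete, here is a toy sketch (not MegaBlocks code) contrasting capacity-based dispatch, where tokens routed past a fixed per-expert capacity are simply dropped, with a dropless grouping that keeps every token; MegaBlocks then handles the resulting variable-sized expert batches with block-sparse matmuls. All names here are hypothetical:

```python
import torch


def capacity_dispatch(expert_indices: torch.Tensor, num_experts: int, capacity: int):
    """Capacity-factor routing: tokens routed beyond the per-expert capacity are dropped."""
    kept = []
    counts = torch.zeros(num_experts, dtype=torch.long)
    for tok, e in enumerate(expert_indices.tolist()):
        if counts[e] < capacity:
            counts[e] += 1
            kept.append(tok)
    return kept  # every token not in `kept` is dropped


def dropless_groups(expert_indices: torch.Tensor, num_experts: int):
    """Dropless alternative: group every token by expert with no capacity cap."""
    return [torch.nonzero(expert_indices == e).flatten() for e in range(num_experts)]


expert_indices = torch.tensor([0, 0, 0, 1, 2, 2])        # 6 tokens routed to 3 experts
print(capacity_dispatch(expert_indices, 3, capacity=2))  # [0, 1, 3, 4, 5] -> token 2 is dropped
print(dropless_groups(expert_indices, 3))                # all tokens kept, grouped per expert
```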
✔️ Solution
We should look into using dMoE (dropless MoE), which is the efficient core of MegaBlocks.
I believe the correct solution is to do the following (a sketch of the router split follows below):
- Split the gate (router) into a separate class using LearnedRouter.
- Use sparse_permute_and_compute from MegaBlocks for the expert computation.
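As a starting point, here is a minimal sketch of what splitting the gate out into its own router class could look like. It is a simplified stand-in for MegaBlocks' LearnedRouter, not its actual implementation; top_k, the softmax placement, and the return values are assumptions:

```python
import torch
import torch.nn as nn


class SimpleLearnedRouter(nn.Module):
    """Simplified stand-in for a MegaBlocks-style LearnedRouter: a linear layer
    scores each token against every expert and returns the top-k assignments."""

    def __init__(self, hidden_size: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.layer = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, x: torch.Tensor):
        # x: [num_tokens, hidden_size]
        scores = self.layer(x).softmax(dim=-1)                          # routing probabilities
        expert_weights, expert_indices = scores.topk(self.top_k, dim=-1)
        return scores, expert_weights, expert_indices
```

Keeping routing in its own module would let the expert computation be swapped out (naive loop, batched matmul, or MegaBlocks' block-sparse path) without touching the gate.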
❓ Alternatives
No response
📝 Additional Context
No response