xrsrke / pipegoose

Large scale 4D parallelism pre-training for 🤗 transformers in Mixture of Experts *(still work in progress)*
MIT License

[Bug Fix] Balance transformer blocks across shards #41

Closed (abourramouss closed this issue 9 months ago)

abourramouss commented 9 months ago

As we were discussing, the current implementation works like this:

  1. It first gives an equal number of parameters to each shard.
  2. If a transformer block is going to be split across different shards, prevent it, and keep the whole block in the current shard.

This way, we can guarantee that each shard/partition gets an equal number of whole transformer blocks (see the sketch below).
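
For reference, here is a minimal sketch of that behaviour. It is not pipegoose's actual partitioner; the function name and inputs are made up for illustration. It targets an equal parameter budget per shard and never splits a transformer block:

```python
from typing import List


def naive_partition(block_params: List[int], n_shards: int) -> List[List[int]]:
    """Hypothetical illustration: assign whole transformer blocks to shards,
    targeting an equal parameter budget per shard.

    `block_params[i]` is the parameter count of block i. Non-block layers
    (embeddings, LM head, ...) are ignored to keep the sketch short.
    Returns, for each shard, the list of block indices it receives.
    """
    budget = sum(block_params) / n_shards
    shards: List[List[int]] = [[] for _ in range(n_shards)]
    shard_idx, current = 0, 0
    for i, size in enumerate(block_params):
        # The whole block stays on the current shard, even if that pushes the
        # shard past its parameter budget (a block is never split).
        shards[shard_idx].append(i)
        current += size
        # Once the budget is reached, start filling the next shard.
        if current >= budget and shard_idx < n_shards - 1:
            shard_idx += 1
            current = 0
    return shards
```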

But there is an edge case: if we specify 5 shards and the model has 6 transformer blocks, then shards 1 to 3 get 2 transformer blocks each, shard 4 gets only the final layers, and shard 5 gets nothing.
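
Running the sketch above on that configuration reproduces the imbalance (block sizes assumed equal; the final non-block layers are not modelled):

```python
# 6 equally sized transformer blocks, 5 shards
print(naive_partition([1] * 6, 5))
# [[0, 1], [2, 3], [4, 5], [], []]  -> the last two shards get no blocks
```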

To prevent this, if parameter balance is not that important, we could shard based on transformer blocks instead: with 5 shards and 6 blocks, shard 1 would get 2 transformer blocks and shards 2 to 5 would get 1 transformer block each (sketched below).
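
A minimal sketch of that block-count-based split (again just an illustration under the same assumptions, not an API proposal):

```python
from typing import List


def shard_by_block_count(n_blocks: int, n_shards: int) -> List[int]:
    """Distribute transformer blocks as evenly as possible across shards.

    The first (n_blocks % n_shards) shards get one extra block, so every
    shard receives at least one block whenever n_blocks >= n_shards.
    """
    base, extra = divmod(n_blocks, n_shards)
    return [base + 1 if i < extra else base for i in range(n_shards)]


print(shard_by_block_count(6, 5))
# [2, 1, 1, 1, 1]
```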