mosaicml / llm-foundry

LLM training code for Databricks foundation models
https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm
Apache License 2.0

Added torch_dmoe defaults, bug fixes for 2D inputs #1210

Closed: snarayan21 closed this 4 months ago

snarayan21 commented 4 months ago

Users are hitting issues where changing their ffn_type from mb_dmoe to torch_dmoe breaks their config, because torch_dmoe does not have any defaults set. This PR ensures that torch_dmoe has the same defaults as mb_dmoe, matching the Arguments dataclass in the Megablocks repo. Fixed typing in some places as well, and added a test verifying that mb_dmoe and torch_dmoe produce the same results in forward and backward for these default values.
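As a rough sketch of the defaulting fix (the helper name and the specific keys/values below are illustrative assumptions, not the PR's actual code or the authoritative Megablocks defaults):

```python
# Hypothetical sketch: give torch_dmoe the same fallbacks that mb_dmoe
# inherits from Megablocks' Arguments dataclass, so switching ffn_type
# does not fail on unset keys. Keys and values here are illustrative.
_DMOE_DEFAULTS = {
    'moe_num_experts': 1,
    'moe_top_k': 1,
}

def resolve_ffn_config(ffn_config: dict) -> dict:
    """Return a copy of ffn_config with missing dMoE arguments filled in."""
    resolved = dict(ffn_config)
    for key, default in _DMOE_DEFAULTS.items():
        # setdefault only fills keys the user did not set explicitly
        resolved.setdefault(key, default)
    return resolved

cfg = resolve_ffn_config({'ffn_type': 'torch_dmoe', 'moe_top_k': 2})
```

User-supplied values win; only missing keys fall back to the shared defaults.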

There was also a bug for some inputs (top_k = 1) where the outputs of the max function were 1-dimensional instead of 2-dimensional. This bug has also been addressed.
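The shape mismatch can be reproduced in isolation (this is an illustration of the general torch.max vs. torch.topk behavior, not the repo's router code):

```python
import torch

# torch.max over a dim drops that dim, while torch.topk keeps it as a
# size-k dim. A router that uses max for the top_k=1 case therefore
# gets 1-D outputs where the top_k>1 path gets 2-D outputs.
scores = torch.randn(4, 8)  # (tokens, experts)

top1_vals, top1_idx = scores.max(dim=-1)        # shape (4,)   -- 1-D
topk_vals, topk_idx = scores.topk(2, dim=-1)    # shape (4, 2) -- 2-D

# One possible fix: restore the dropped dim so both paths are (tokens, k).
top1_vals = top1_vals.unsqueeze(-1)             # shape (4, 1)
top1_idx = top1_idx.unsqueeze(-1)
```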

There's a separate bug in Megablocks that is triggered when hidden_size and ffn_hidden_size are both 128; this is why hidden_size has been bumped to 256 in the test. Additionally, when hidden_size or ffn_hidden_size is not divisible by 128, an unhelpful error is thrown. Will open a separate PR on the Megablocks side to address this.
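A friendlier up-front check is one way the unhelpful error could be avoided. This is a hypothetical sketch, assuming the divisibility-by-128 constraint described above (the block size and function name are assumptions, not Megablocks' API):

```python
BLOCK = 128  # assumed block size behind the divisibility constraint

def validate_dmoe_dims(hidden_size: int, ffn_hidden_size: int) -> None:
    """Raise a clear error instead of a cryptic one when dMoE dims
    violate the assumed divisibility-by-128 requirement."""
    for name, value in (('hidden_size', hidden_size),
                        ('ffn_hidden_size', ffn_hidden_size)):
        if value % BLOCK != 0:
            raise ValueError(
                f'{name}={value} must be divisible by {BLOCK} for dMoE.')

validate_dmoe_dims(256, 512)  # OK: both divisible by 128
```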

mvpatel2000 commented 4 months ago

LGTM but want @bigning 's signoff