Users are hitting issues where changing their `ffn_type` from `mb_dmoe` to `torch_dmoe` results in config errors. This is because `torch_dmoe` does not set any defaults. This PR gives `torch_dmoe` the same defaults as `mb_dmoe`, following the `Arguments` dataclass in the Megablocks repo. It also fixes typing in a few places and adds a test verifying that `mb_dmoe` and `torch_dmoe` match in forward and backward passes under these default values.
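The defaults fix can be sketched roughly as follows: resolve one shared defaults dict so both `ffn_type` branches see identical values. This is a minimal illustration, not the actual repo code; the names `DMOE_DEFAULTS` and `resolve_ffn_config` are hypothetical, and the default values shown mirror the Megablocks `Arguments` dataclass only approximately.

```python
# Hypothetical sketch: one shared defaults dict keeps mb_dmoe and
# torch_dmoe configs in sync instead of each branch setting its own.
DMOE_DEFAULTS = {
    'moe_num_experts': 1,   # illustrative values; see the Megablocks
    'moe_top_k': 1,         # Arguments dataclass for the real defaults
}

def resolve_ffn_config(user_cfg: dict) -> dict:
    # User-supplied keys override the shared defaults.
    return {**DMOE_DEFAULTS, **user_cfg}

# Unspecified keys fall back to the shared defaults...
assert resolve_ffn_config({})['moe_top_k'] == 1
# ...while explicit user settings win.
assert resolve_ffn_config({'moe_top_k': 4})['moe_top_k'] == 4
```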
There was also a bug for some inputs (`top_k = 1`) where the outputs of the `max` function were 1-dimensional instead of 2-dimensional. This PR addresses that as well.
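The shape issue can be reproduced with plain `torch` ops (a sketch; the actual router code may differ): `torch.max` over a dim drops that dim, while `torch.topk` keeps it, so a `top_k = 1` path built on `max` yields 1-D outputs where the `top_k > 1` path yields 2-D ones.

```python
import torch

x = torch.randn(4, 8)  # e.g. (tokens, experts) router scores

# torch.max over the expert dim drops it: values/indices are 1-D, shape (4,)
vals, idx = x.max(dim=-1)
assert vals.shape == (4,)

# torch.topk with k=1 keeps the k dim, matching the top_k > 1 path: (4, 1)
vals2, idx2 = x.topk(1, dim=-1)
assert vals2.shape == (4, 1)

# unsqueeze restores the expected 2-D shape from the max outputs
assert vals.unsqueeze(-1).shape == vals2.shape
```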
There's a separate bug in Megablocks where setting both `hidden_size` and `ffn_hidden_size` to 128 produces incorrect behavior -- this is why `hidden_size` has been bumped to 256. And when `hidden_size` and `ffn_hidden_size` are not divisible by 128, an unhelpful error is thrown. Will open a separate PR on the Megablocks side to address this.