databricks / megablocks


How to integrate with transformers-based Mixtral #84

Open nxphi47 opened 6 months ago

nxphi47 commented 6 months ago

Hi, this is awesome work. I'm wondering if there is a minimal way to integrate megablocks into the transformers codebase for the Mixtral architecture?

Would simply replacing MixtralSparseMoeBlock with dmoe.dMoE (with the proper configuration) work? For reference, here is the relevant part of the decoder layer:

```python
# Excerpt from transformers' modeling_mixtral.py
class MixtralDecoderLayer(nn.Module):
    def __init__(self, config: MixtralConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        # The MoE block I'd like to swap out for megablocks' dmoe.dMoE
        self.block_sparse_moe = MixtralSparseMoeBlock(config)
        ....
```
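
To make the question concrete, this is the kind of thin wrapper I have in mind (completely untested; MegablocksMoeBlock is just a name I made up, and the Arguments mapping below is my guess at the API), so that the decoder layer only changes to `self.block_sparse_moe = MegablocksMoeBlock(config)`:

```python
import torch
import torch.nn as nn
from megablocks.layers.arguments import Arguments
from megablocks.layers.dmoe import dMoE


class MegablocksMoeBlock(nn.Module):
    """Guess at a drop-in for MixtralSparseMoeBlock that keeps its
    (hidden_states, router_logits) return shape."""

    def __init__(self, config):
        super().__init__()
        # Guessing at a minimal MixtralConfig -> Arguments mapping here;
        # this mapping is exactly the "proper configuration" I'm asking about.
        self.moe = dMoE(Arguments(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            moe_num_experts=config.num_local_experts,
            moe_top_k=config.num_experts_per_tok,
            bias=False,  # Mixtral's expert MLPs have no bias
        ))

    def forward(self, hidden_states: torch.Tensor):
        out = self.moe(hidden_states)
        # dMoE doesn't return router logits, so I'm not sure yet how the
        # auxiliary load-balancing loss should flow back to the trainer.
        return out, None
```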

Thanks!

tgale96 commented 5 months ago

Hi! I believe that would work. In addition to configuring the expert count/top-k appropriately, you'll want to set moe_normalize_expert_weights to 1.0 to match their post-top-k expert weight normalization. You'll have to handle any differences in how the load balancing loss is computed/returned as well.

Please let us know if you encounter any issues and we'd be more than happy to help debug :)
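
Putting that advice together, a rough and untested sketch of the mapping: the field names are from my reading of megablocks/layers/arguments.py and the loss helpers in megablocks/layers/moe.py (both may differ across versions), and everything beyond the expert count, top-k, and moe_normalize_expert_weights=1.0 mentioned above is my guess at matching Mixtral:

```python
# Untested sketch: Arguments field names follow my reading of
# megablocks/layers/arguments.py and may differ between versions.
import torch.nn.functional as F
from megablocks.layers.arguments import Arguments
from megablocks.layers.moe import batched_load_balancing_loss, clear_load_balancing_loss


def mixtral_moe_args(config) -> Arguments:
    """Map a transformers MixtralConfig onto megablocks Arguments (best guess)."""
    return Arguments(
        hidden_size=config.hidden_size,
        ffn_hidden_size=config.intermediate_size,
        moe_num_experts=config.num_local_experts,      # 8 for Mixtral-8x7B
        moe_top_k=config.num_experts_per_tok,           # 2 for Mixtral-8x7B
        moe_normalize_expert_weights=1.0,                # match post-top-k renormalization
        moe_loss_weight=config.router_aux_loss_coef,     # load-balancing loss scale
        mlp_type="glu",                                  # Mixtral experts are gated (SwiGLU-style)
        activation_fn=F.silu,
        bias=False,                                      # Mixtral MLPs have no bias
        num_layers=config.num_hidden_layers,             # batched_load_balancing_loss checks this
        # fp16/bf16/device/init_method left at defaults here; set as needed for your setup.
    )


def lm_loss_with_aux(lm_loss, args: Arguments):
    """megablocks accumulates the load-balancing loss internally (it does not
    return router logits per layer), so pull it out once per training step
    instead of using transformers' load_balancing_loss_func over router logits."""
    aux = batched_load_balancing_loss(args)
    clear_load_balancing_loss()
    return lm_loss + aux
```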