mlfoundations / open_lm

A repository for research on medium sized language models.
MIT License

MoE performs worse than equivalent dense model? #253

Closed. Muennighoff closed this issue 1 month ago.

Muennighoff commented 1 month ago

As far as I can tell, for the numbers reported in https://github.com/mlfoundations/open_lm/pull/115#issuecomment-1858719588, the "1 expert" model is still an MoE, correct?

I also find that the 8-expert MoE is better than the 1-expert one; however, both are worse than a dense model. In the graph below, OpenLM-41M is a 41M dense model, and the two above it are 8-expert and 1-expert models with 41M active parameters.

[Screenshot, 2024-04-17: OpenLM-41M (dense) compared with the 8-expert and 1-expert MoE models]

I would expect the 1-expert model to roughly match the dense one and the 8-expert model to be better than both, but maybe I am missing something? @kernelmachine @sagadre
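
The expectation that a 1-expert model should match the dense one can be checked on a toy layer. Below is a minimal sketch, assuming a simple top-1 router that scales each token's expert output by its softmax probability and has no capacity limit. The `Top1MoE` class is hypothetical and is not megablocks' implementation, which additionally applies a capacity factor (possible token dropping) and a load-balancing loss. With a single expert, the softmax over one logit is exactly 1, so the layer reduces to its expert's FFN.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Top1MoE(nn.Module):
    """Hypothetical toy top-1 MoE layer (not megablocks' implementation).

    Each token goes to its highest-probability expert, and that expert's output
    is scaled by the router probability. There is no capacity limit, so no
    tokens are dropped.
    """

    def __init__(self, dim: int, hidden_dim: int, num_experts: int):
        super().__init__()
        self.router = nn.Linear(dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(dim, hidden_dim, bias=False),
                nn.GELU(approximate="tanh"),
                nn.Linear(hidden_dim, dim, bias=False),
            )
            for _ in range(num_experts)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(self.router(x), dim=-1)  # (tokens, num_experts)
        weight, index = probs.max(dim=-1)          # top-1 probability and expert id
        out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            mask = index == e
            if mask.any():
                out[mask] = weight[mask].unsqueeze(-1) * expert(x[mask])
        return out


torch.manual_seed(0)
dim, hidden = 16, 64
moe = Top1MoE(dim, hidden, num_experts=1)
dense = moe.experts[0]  # reuse the single expert's weights as a dense FFN
x = torch.randn(32, dim)
# With one expert, the softmax over a single logit is exactly 1.0 for every token.
print(torch.allclose(moe(x), dense(x)))
```

Under these assumptions, routing alone cannot explain a gap between a 1-expert MoE and a dense model; any remaining difference has to come from other settings such as the FFN type, the init, capacity-based dropping, or the auxiliary loss.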

(My setup follows the main README & https://github.com/mlfoundations/open_lm/blob/main/MOE.md)

achalddave commented 1 month ago

Great catch, thanks @Muennighoff! I think this is because the MoE defaults from megablocks differ from our default dense model in at least two ways: their FFN uses GELU without gating (w2 @ gelu(w1 @ x)), while ours uses SwiGLU (w3 @ (silu(w2 @ x) * (w1 @ x)), where * is elementwise multiplication), and we use a different parameter init function.
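
For a concrete comparison, here is a minimal PyTorch sketch of the two FFN variants described above. The class names, `dim=512`, and the 2/3 hidden-size scaling used to roughly equalize parameter counts between the two-matrix and three-matrix designs are illustrative assumptions; open_lm's actual hidden-size rounding may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GeluFFN(nn.Module):
    """Ungated FFN as described for megablocks: w2 @ gelu(w1 @ x)."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.gelu(self.w1(x), approximate="tanh"))


class SwiGLUFFN(nn.Module):
    """Gated FFN: w3 @ (silu(w2 @ x) * (w1 @ x)), with * elementwise."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w3(F.silu(self.w2(x)) * self.w1(x))


def num_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())


dim = 512
gelu_ffn = GeluFFN(dim, 4 * dim)               # 2 matrices of size dim x 4*dim
swiglu_ffn = SwiGLUFFN(dim, 4 * dim * 2 // 3)  # 3 matrices of size dim x ~(8/3)*dim
x = torch.randn(4, dim)
print(gelu_ffn(x).shape, swiglu_ffn(x).shape)
print(num_params(gelu_ffn), num_params(swiglu_ffn))
```

With these sizes the two blocks have nearly the same parameter count (about 2.1M each), so the difference between them lies in the gating and activation rather than in capacity.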

I'm not sure what the easiest way to reconcile these is. We probably want a custom version of the MLP class here https://github.com/databricks/megablocks/blob/2724ff6775ee7e2a41001a7979c0ec84c417cd84/megablocks/layers/mlp.py#L81-L137 that implements SwiGLU and our init function.
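
A custom replacement would need to compute SwiGLU per expert and apply our own init. The sketch below shows only that computation, for tokens already routed to one expert, using weights stacked over experts. `StackedSwiGLUExperts` is hypothetical: it does not use or subclass megablocks' MLP API, and the truncated-normal init with std 1/sqrt(fan_in) is a placeholder assumption, not open_lm's actual init function.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class StackedSwiGLUExperts(nn.Module):
    """Hypothetical SwiGLU expert MLP with weights stacked over experts.

    NOT megablocks' MLP class; it only illustrates the per-expert computation
    and init a custom replacement would need.
    """

    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.w2 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.w3 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Placeholder init (assumption): truncated normal, std = 1/sqrt(fan_in),
        # cut off at +/- 3 std. Swap in open_lm's own init function here.
        dim, hidden_dim = self.w1.shape[1], self.w1.shape[2]
        for w, fan_in in ((self.w1, dim), (self.w2, dim), (self.w3, hidden_dim)):
            std = 1.0 / math.sqrt(fan_in)
            nn.init.trunc_normal_(w, mean=0.0, std=std, a=-3 * std, b=3 * std)

    def forward(self, x: torch.Tensor, expert: int) -> torch.Tensor:
        # x: (tokens, dim), all tokens routed to the same expert.
        gate = F.silu(x @ self.w2[expert])
        up = x @ self.w1[expert]
        return (gate * up) @ self.w3[expert]


torch.manual_seed(0)
experts = StackedSwiGLUExperts(num_experts=8, dim=64, hidden_dim=128)
tokens = torch.randn(10, 64)
out = experts(tokens, expert=3)
print(out.shape)
```

Storing expert weights as `(num_experts, dim, hidden_dim)` tensors is one common layout; it keeps all experts' init in one place. Wiring this into megablocks' routing is deliberately left out.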

Muennighoff commented 1 month ago

I trained an OpenLM model with the tanh-approximate GELU used in megablocks and a plain normal init, by adding the following two snippets to model.py:

        elif args.ffn_type == "gelutanh":
            # megablocks-style ungated FFN: w2 @ gelu_tanh(w1 @ x), hidden size 4 * dim
            self.hidden_dim = args.dim * 4
            self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False)
            self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False)
            self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="tanh"), self._ff_w2)

&

        elif self._ffn_type == "gelutanh":
            # plain normal init (std=0.02) for both FFN projections
            torch.nn.init.normal_(self._ff_w1.weight, mean=0.0, std=0.02)
            torch.nn.init.normal_(self._ff_w2.weight, mean=0.0, std=0.02)

It does indeed perform slightly worse than the regular 41M model, but still far better than the MoE variants. (For the 8-expert model, the parameter count increased from 69M to 97M because I increased the MoE frequency to every layer instead of every second layer.) Do you have any other ideas about where the gap could come from?
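
The parameter counts are consistent with a simple back-of-the-envelope model. This sketch assumes the dense model has about 41M parameters, each expert is the same size as the dense FFN it replaces, each MoE layer adds `num_experts - 1` extra FFN copies, and router weights are negligible; the helper names are hypothetical. Under those assumptions, 69M with MoE on every second layer implies about 8M FFN parameters across all layers, which predicts 97M with MoE on every layer.

```python
def moe_total_params(dense_total: float, ffn_per_model: float,
                     num_experts: int, moe_layer_fraction: float) -> float:
    """Approximate total parameters (millions) of a top-1 MoE variant.

    dense_total: parameters of the dense model.
    ffn_per_model: parameters of all FFN blocks of the dense model combined.
    Each MoE layer keeps its FFN and adds (num_experts - 1) same-size copies;
    router weights are ignored.
    """
    return dense_total + (num_experts - 1) * ffn_per_model * moe_layer_fraction


def infer_ffn_params(dense_total: float, moe_total: float,
                     num_experts: int, moe_layer_fraction: float) -> float:
    """Back out ffn_per_model from an observed MoE total."""
    return (moe_total - dense_total) / ((num_experts - 1) * moe_layer_fraction)


# MoE on every 2nd layer -> half of the layers get experts.
ffn = infer_ffn_params(41, 69, num_experts=8, moe_layer_fraction=0.5)
print(ffn)  # → 8.0
# MoE on every layer.
print(moe_total_params(41, ffn, num_experts=8, moe_layer_fraction=1.0))  # → 97.0
```

With top-1 routing each token still passes through one FFN per layer, so the active count stays at roughly 41M in both configurations.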

[Screenshot, 2024-04-18: the gelutanh dense model compared with the regular 41M model and the MoE variants]

Muennighoff commented 1 month ago

For reference, these are the commands I am running:

No MoE:

torchrun --nproc-per-node 8 -m open_lm.main \
--train-data "/data/niklas/openlm/preproc/2048-v1/0/shard-{0000000..0000099}.tar" \
--train-num-samples 10000000000 \
--precision amp_bfloat16 \
--global-batch-size 64 \
--accum-freq 4 \
--log-every-n-steps 20 \
--grad-clip-norm 1 \
--lr 5e-4 \
--warmup 200 \
--model open_lm_41m \
--wd 0.1 \
--beta2 0.95 \
--epochs 50 \
--report-to wandb \
--wandb-project-name olmoe \
--name test$RANDOM \
--logs /data/niklas/openlm/moe \
--resume latest \
--seed 124 \
--data-key 'txt' \
--fsdp \
--fsdp-amp \
--model-norm gain_only_layer_norm \
--lr-scheduler cosine \
--lr-cooldown-end 0.00001 \
--ffn-type gelutanh

MoE with 8 experts:

torchrun --nproc-per-node 8 -m open_lm.main \
--train-data "/data/niklas/openlm/preproc/2048-v1/0/shard-{0000000..0000099}.tar" \
--train-num-samples 10000000000 \
--precision amp_bfloat16 \
--global-batch-size 64 \
--accum-freq 4 \
--log-every-n-steps 20 \
--grad-clip-norm 1 \
--lr 5e-4 \
--warmup 200 \
--model open_lm_41m \
--wd 0.1 \
--beta2 0.95 \
--epochs 50 \
--report-to wandb \
--moe-freq 1 \
--moe-num-experts 8 \
--moe-top-k 1 \
--moe-capacity-factor 1.25 \
--moe-loss-weight 0.1 \
--wandb-project-name olmoe \
--name test$RANDOM \
--logs /data/niklas/openlm/moe \
--resume latest \
--seed 124 \
--data-key 'txt' \
--fsdp \
--fsdp-amp \
--model-norm gain_only_layer_norm \
--lr-scheduler cosine \
--lr-cooldown-end 0.00001
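
To confirm which settings actually differ between the two runs, a small sketch can parse each command's `--flag value` pairs and list the mismatches. It is a heuristic that assumes each flag takes at most one value; `parse_flags` and `diff_flags` are hypothetical helpers, and the commands in the usage example are abbreviated.

```python
import shlex

MISSING = "<absent>"


def parse_flags(command: str) -> dict:
    """Map each --flag in a shell command to its value (None for bare flags)."""
    # Join backslash-continued lines before tokenizing.
    tokens = shlex.split(command.replace("\\\n", " "))
    flags, i = {}, 0
    while i < len(tokens):
        tok = tokens[i]
        if tok.startswith("--"):
            nxt = tokens[i + 1] if i + 1 < len(tokens) else None
            if nxt is not None and not nxt.startswith("--"):
                flags[tok] = nxt
                i += 2
                continue
            flags[tok] = None  # bare switch such as --fsdp
        i += 1
    return flags


def diff_flags(a: str, b: str) -> dict:
    """Return {flag: (value_in_a, value_in_b)} for flags that differ."""
    fa, fb = parse_flags(a), parse_flags(b)
    out = {}
    for k in sorted(fa.keys() | fb.keys()):
        va, vb = fa.get(k, MISSING), fb.get(k, MISSING)
        if va != vb:
            out[k] = (va, vb)
    return out


dense_cmd = """torchrun --nproc-per-node 8 -m open_lm.main \\
--model open_lm_41m --lr 5e-4 --fsdp --ffn-type gelutanh"""
moe_cmd = """torchrun --nproc-per-node 8 -m open_lm.main \\
--model open_lm_41m --lr 5e-4 --fsdp --moe-freq 1 --moe-num-experts 8"""
diff = diff_flags(dense_cmd, moe_cmd)
for flag, (dense_val, moe_val) in diff.items():
    print(f"{flag}: dense={dense_val} moe={moe_val}")
```

Run on the two full commands above, it reports `--ffn-type` (set only in the dense run) plus the five `--moe-*` flags (set only in the MoE run); everything else, including `--name test$RANDOM` compared as an unexpanded string, is identical.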