facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

Inference on MoE models #5497

Open meenakshi-mittal opened 1 month ago

meenakshi-mittal commented 1 month ago

The inference command given in the moe_lm README fails both on the provided pre-trained MoE models and on MoE models I trained myself.

This is the command that I am trying to use:

DATA_PATH=/path/to/data-bin
MODEL_PATH=/path/to/model.pt
python -m fairseq_cli.eval_lm \
  $DATA_PATH \
  --path $MODEL_PATH \
  --gen-subset valid \
  --sample-break-mode none \
  --tokens-per-sample 2048 \
  --batch-size 1 \
  --fp16 \
  --output-word-probs \
  --is-moe \
  --distributed-world-size 8 \
  --model-overrides "{'world_size': 8, 'moe_eval_capacity_token_fraction': 0.05}"

I downloaded the pre-trained moe_15b model from the moe_lm README onto my machine, and it unzips into a directory that looks like this:

en_moe_lm_15b/
  model-rank-0.pt
  model-rank-1.pt
  ...
  model-rank-63.pt
  model-shared.pt

I ran the command above with MODEL_PATH=/path/to/en_moe_lm_15b/model.pt and DATA_PATH=/path/to/data-bin/wikitext-103 and got the following error:

Traceback (most recent call last):
  File "/data/meenakshi/miniconda3/envs/fairseq_moe/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/data/meenakshi/MoE/fairseq/fairseq/distributed/utils.py", line 335, in distributed_main
    main(cfg, **kwargs)
  File "/data/meenakshi/MoE/fairseq/fairseq_cli/eval_lm.py", line 384, in main
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
  File "/data/meenakshi/MoE/fairseq/fairseq/checkpoint_utils.py", line 478, in load_model_ensemble_and_task
    model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model)
  File "/data/meenakshi/MoE/fairseq/fairseq/models/fairseq_model.py", line 126, in load_state_dict
    return super().load_state_dict(new_state_dict, strict)
  File "/data/meenakshi/miniconda3/envs/fairseq_moe/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for TransformerLanguageModel:
    size mismatch for decoder.embed_tokens.weight: copying a param with shape torch.Size([51200, 768]) from checkpoint, the shape in current model is torch.Size([267744, 768]).
    size mismatch for decoder.output_projection.weight: copying a param with shape torch.Size([51200, 768]) from checkpoint, the shape in current model is torch.Size([267744, 768]).

I understand that this happens because the wikitext-103 dict.txt has a different size from the dictionary used to train the moe_15b model, but how do I fix it? I cannot find any information about the dict.txt used for the moe_15b model.
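One way to narrow down a mismatch like this is to compare the vocabulary size implied by a dict.txt with the first dimension of decoder.embed_tokens.weight reported in the traceback (51200 rows in the checkpoint versus 267744 in the model built from the wikitext-103 data). Below is a minimal sketch, assuming fairseq's default Dictionary, which reserves four special symbols (<s>, <pad>, </s>, <unk>), and no extra special symbols. `expected_vocab_size` and `check_against_checkpoint` are hypothetical helpers, not fairseq APIs, and the optional `pad_to_multiple` rounding is an assumption for the case where the checkpoint's vocabulary was padded.

```python
from pathlib import Path

# Assumption: fairseq's default Dictionary prepends exactly four special
# symbols (<s>, <pad>, </s>, <unk>) and no extra special symbols were added.
NUM_DEFAULT_SPECIALS = 4


def expected_vocab_size(dict_path, pad_to_multiple=1):
    """Estimate the embedding row count implied by a fairseq-style dict.txt.

    Each non-empty line is "<symbol> <count>". The vocabulary is those symbols
    plus the default specials, optionally rounded up to a multiple
    (hypothetical knob; whether any padding applies depends on the model).
    """
    lines = Path(dict_path).read_text(encoding="utf-8").splitlines()
    n_symbols = sum(1 for line in lines if line.strip())
    size = n_symbols + NUM_DEFAULT_SPECIALS
    if pad_to_multiple > 1:
        # Round up to the next multiple of pad_to_multiple.
        size = ((size + pad_to_multiple - 1) // pad_to_multiple) * pad_to_multiple
    return size


def check_against_checkpoint(dict_path, checkpoint_rows, pad_to_multiple=1):
    """Return (expected_rows, checkpoint_rows, sizes_match)."""
    expected = expected_vocab_size(dict_path, pad_to_multiple)
    return expected, checkpoint_rows, expected == checkpoint_rows
```

For example, `check_against_checkpoint("/path/to/data-bin/wikitext-103/dict.txt", 51200)` reports whether that dictionary could have produced the checkpoint's embedding table. A mismatch only confirms that a different dictionary was used; it does not identify the right one.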

---

I have also tried to train my own MoE models with the following command:

NUM_EXPERTS=8
TOKENS_PER_SAMPLE=1024

fairseq-train --task language_modeling \
  data-bin/wikitext-103 \
  --save-dir checkpoints/moe_wikitext-103 \
  --tokens-per-sample $TOKENS_PER_SAMPLE \
  --ddp-backend fully_sharded --memory-efficient-fp16 --checkpoint-activations \
  --arch transformer_lm_gpt --share-decoder-input-output-embed \
  --decoder-layers 24 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 \
  --decoder-attention-heads 16 \
  --moe-expert-count $NUM_EXPERTS --moe-freq 2 \
  --moe-gating-use-fp32 --moe-second-expert-policy all \
  --moe-normalize-expert-grad sqrt_world_size \
  --moe-eval-capacity-token-fraction -1.0 \
  --max-sentences-valid 1 --num-workers-valid 0 \
  --criterion moe_cross_entropy --moe-gate-loss-wt 0.01 --moe-gate-loss-combine-method sum \
  --optimizer adam --fp16 --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
  --lr 0.0005 --warmup-updates 750 \
  --dropout 0.2 --attention-dropout 0.2 \
  --batch-size 2 --update-freq 2 \
  --max-update 250 --disable-validation \
  --log-format json --log-interval 10

After training, the model directory looks like this:

moe_wikitext-103/
  checkpoint_last-rank-0-shard0.pt
  checkpoint_last-rank-1-shard1.pt
  ...
  checkpoint_last-rank-7-shard7.pt
  checkpoint_last-shared-shard0.pt
  ...
  checkpoint_last-shared-shard7.pt

Running inference on this model with a command similar to the one above initially fails with errors like this:

Model file not found: checkpoints/moe_wikitext-103/checkpoint_last-rank-7.pt

So I edited eval_lm.py to add "-shard{rank}" to the checkpoint file names. With that change, I get this error:

Traceback (most recent call last):
  File "/data/meenakshi/miniconda3/envs/fairseq_moe/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/data/meenakshi/MoE/fairseq/fairseq/distributed/utils.py", line 335, in distributed_main
    main(cfg, **kwargs)
  File "/data/meenakshi/MoE/fairseq/fairseq_cli/eval_lm.py", line 384, in main
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
  File "/data/meenakshi/MoE/fairseq/fairseq/checkpoint_utils.py", line 478, in load_model_ensemble_and_task
    model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model)
  File "/data/meenakshi/MoE/fairseq/fairseq/models/fairseq_model.py", line 126, in load_state_dict
    return super().load_state_dict(new_state_dict, strict)
  File "/data/meenakshi/miniconda3/envs/fairseq_moe/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for TransformerLanguageModel:
    size mismatch for decoder.embed_tokens.weight: copying a param with shape torch.Size([50000, 1024]) from checkpoint, the shape in current model is torch.Size([267744, 1024]).
    size mismatch for decoder.output_projection.weight: copying a param with shape torch.Size([50000, 1024]) from checkpoint, the shape in current model is torch.Size([267744, 1024]).

This is similar to the previous error, but it makes no sense to me: I trained the model on the same dataset I am evaluating it on.
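The names eval_lm looked for ("checkpoint_last-rank-7.pt") and the names the fully_sharded run wrote ("checkpoint_last-rank-7-shard7.pt", "checkpoint_last-shared-shard7.pt") differ only by a "-shard{rank}" suffix. A minimal sketch of that mapping, as an alternative to hard-coding the suffix in eval_lm.py, is below. `per_rank_candidates`, `shared_candidates` and `resolve_existing` are hypothetical helpers, not fairseq functions, and the patterns are inferred only from the file names in this issue. Resolving the names only gets past the "Model file not found" error; as the traceback shows, loading can still fail afterwards.

```python
from pathlib import Path


def per_rank_candidates(checkpoint_path, rank):
    """Candidate expert-checkpoint names for one rank (hypothetical helper).

    The first pattern matches the name eval_lm reported as missing; the
    second matches what the fully_sharded training run wrote to disk.
    """
    path = Path(checkpoint_path)
    stem, suffix = path.stem, path.suffix  # e.g. "checkpoint_last", ".pt"
    return [
        path.with_name(f"{stem}-rank-{rank}{suffix}"),
        path.with_name(f"{stem}-rank-{rank}-shard{rank}{suffix}"),
    ]


def shared_candidates(checkpoint_path, rank):
    """Candidate names for the shared (non-expert) checkpoint (hypothetical)."""
    path = Path(checkpoint_path)
    stem, suffix = path.stem, path.suffix
    return [
        path.with_name(f"{stem}-shared{suffix}"),  # e.g. model-shared.pt
        path.with_name(f"{stem}-shared-shard{rank}{suffix}"),
    ]


def resolve_existing(candidates):
    """Return the first candidate path that exists on disk, or None."""
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None
```

For example, `resolve_existing(per_rank_candidates("checkpoints/moe_wikitext-103/checkpoint_last.pt", 7))` returns the "-shard7" file when only that one is present. Whether each "-shared-shardN" file can be loaded on its own, rather than needing to be combined first, is not established by this sketch.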

Environment details:

sunrainyg commented 3 weeks ago

I ran into the same issue. Are there any solutions?