pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

Support Mixtral-8x7B #71

Closed: yanboliang closed this 7 months ago

yanboliang commented 9 months ago

This is based on #57. Please check out https://github.com/yanboliang/gpt-fast/tree/mixtral-moe to try it.

Performance numbers (tokens/second):

|                      | 1 GPU | 2 GPUs | 8 GPUs |
|----------------------|-------|--------|--------|
| baseline (bfloat16)  | OOM   | 78.75  | 203.69 |
| int8                 | 56.04 | 99.91  | 218.48 |

Note: Benchmarks were run on an 8x A100-80GB node, power-limited to 330 W, with a hybrid cube mesh topology.

How to reproduce it:

```bash
export MODEL_REPO=mistralai/Mixtral-8x7B-v0.1
# Download model weights
python scripts/download.py --repo_id $MODEL_REPO
# Convert to the gpt-fast checkpoint format
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
# Generate int8-quantized model weights
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8
# Test tensor parallelism (tp=8)
ENABLE_INTRA_NODE_COMM=1 torchrun --standalone --nproc_per_node=8 generate.py --compile --compile_prefill --checkpoint_path checkpoints/$MODEL_REPO/model.pth
# Test single GPU + int8 model
python generate.py --compile --compile_prefill --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth
```
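
As a quick sanity check after the quantization step, you can summarize which dtypes ended up in the checkpoint. This is a sketch, not part of the repo; it assumes the checkpoint is a flat state dict of tensors, and the path should be adjusted for your MODEL_REPO:

```python
import torch
from collections import Counter

# Count tensor dtypes in the quantized checkpoint; most weights should be
# int8, with some tensors left in bfloat16.
state_dict = torch.load(
    "checkpoints/mistralai/Mixtral-8x7B-v0.1/model_int8.pth",
    map_location="cpu", mmap=True,
)
print(Counter(str(v.dtype) for v in state_dict.values() if torch.is_tensor(v)))
```
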
raghukiran1224 commented 9 months ago

Does the model compile without any graph breaks?

yanboliang commented 9 months ago

@raghukiran1224 Yes, no graph breaks!
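
For anyone who wants to verify this themselves, recent PyTorch versions can report graph breaks directly. A minimal sketch, where `model` and `example_tokens` are hypothetical stand-ins for the gpt-fast Transformer and an input batch:

```python
import torch
import torch._dynamo as dynamo

def report_graph_breaks(model, example_tokens):
    # dynamo.explain traces the model and summarizes compilation results,
    # including how many graph breaks occurred and why.
    explanation = dynamo.explain(model)(example_tokens)
    print(f"graphs: {explanation.graph_count}, "
          f"graph breaks: {explanation.graph_break_count}")
    for reason in explanation.break_reasons:
        print(reason)
```

Alternatively, running generate.py with the `TORCH_LOGS="graph_breaks"` environment variable set should log any graph breaks without code changes.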

chauhang commented 7 months ago

@yanboliang Great to see this PR. What work remains before merging? It would also help to update the main README benchmarks to include this model.

yanboliang commented 7 months ago

@chauhang I think we need to figure out how to structure this under gpt-fast; it probably needs a separate folder. There are no other blockers, so I'll prioritize this work, and hopefully we can merge it in a few days.

yanboliang commented 7 months ago

Closing this, as it has been merged via https://github.com/pytorch-labs/gpt-fast/pull/105

guangy10 commented 7 months ago

@yanboliang It doesn't seem like `python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8` quantizes all the weights. I end up getting a mismatched-dtype error when lowering this model to ExecuTorch. After looking into model_int8.pth, I noticed that some weights are still in bfloat16. Is that expected?

yanboliang commented 7 months ago

@guangy10 Yes, that's expected! We don't quantize the gate networks, to preserve accuracy, since they are used to choose the experts. https://github.com/pytorch-labs/gpt-fast/blob/1c23b94fcaf3a59bc21dd6fa4791ddae9aa63f05/mixtral-moe/quantize.py#L56-L57
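
For context, here is an illustrative sketch of the pattern the linked lines implement, not the actual quantize.py code: per-row int8 weight-only quantization that skips any linear layer whose name marks it as a gate/router (the `skip_substring` filter here is an assumption for illustration):

```python
import torch
import torch.nn as nn

def quantize_weights_int8(model: nn.Module, skip_substring: str = "gate"):
    """Per-row int8 weight-only quantization, skipping router layers."""
    quantized = {}
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        if skip_substring in name:
            # Gate/router weights pick the experts; keep them in their
            # original dtype (bfloat16) so routing decisions stay accurate.
            continue
        w = module.weight.detach()
        # One scale per output row; clamp to avoid division by zero.
        scales = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        quantized[f"{name}.weight"] = torch.round(w / scales).to(torch.int8)
        quantized[f"{name}.scales"] = scales.squeeze(1)
    return quantized
```

This is why a dtype scan of model_int8.pth shows a mix of int8 and bfloat16 tensors: the remaining bfloat16 entries are the unquantized gate weights.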