pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

[example] Added (hacky) Grok1 support #171

[Open] Chillee opened this pull request 1 month ago

Chillee commented 1 month ago

Downloading from https://huggingface.co/hpcai-tech/grok-1

# Grab the gpt-fast branch with the Grok-1 example and work from the MoE directory
git clone --branch grok1 git@github.com:pytorch-labs/gpt-fast.git && cd gpt-fast/mixtral-moe
export MODEL_REPO=hpcai-tech/grok-1
# Download the weights and convert the Hugging Face checkpoint to gpt-fast's format
python scripts/download.py --repo_id $MODEL_REPO
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
# Quantize the weights to int8, then run tensor-parallel generation across 8 GPUs
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8
TOKENIZERS_PARALLELISM=false ENABLE_INTRA_NODE_COMM=1 time torchrun --standalone --nproc_per_node=8 generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --compile --compile_prefill
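
For context, the int8 mode is weight-only quantization: roughly, each linear layer's weights are stored as int8 with per-channel scales and dequantized at matmul time, which helps because decoding is memory-bandwidth-bound. A rough illustrative sketch (not the actual quantize.py code; the function names below are made up for illustration):

import torch

def quantize_int8_per_channel(w):
    # Symmetric per-output-channel quantization: int8 weights plus a float scale.
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    w_int8 = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return w_int8, scale

def int8_linear(x, w_int8, scale):
    # Dequantize on the fly; smaller weights mean less memory traffic per token.
    return x @ (w_int8.to(x.dtype) * scale).t()

w = torch.randn(8, 16)
w_int8, scale = quantize_int8_per_channel(w)
x = torch.randn(2, 16)
print((x @ w.t() - int8_linear(x, w_int8, scale)).abs().max())  # small quantization error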

Run on 8xA100 80GB

Time for inference 5: 2.73 sec total, 73.15 tokens/sec
Bandwidth achieved: 3057.61 GB/s
Average tokens/sec: 73.04

ms per output token: 13.67ms
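(The per-token latency is just the reciprocal of the throughput: 1000 / 73.15 tokens/sec ≈ 13.67 ms per output token.)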
merrymercy commented 1 month ago

This is awesome! I found one small bug:

Grok uses gelu for the MLP block (https://github.com/xai-org/grok-1/blob/7050ed204b8206bb8645c7b7bbef7252f79561b0/model.py#L374), but Mixtral uses silu (https://github.com/pytorch-labs/gpt-fast/blob/de06b53a4f95c72cd3abd0a8e9fa2d6913676c1a/mixtral-moe/model.py#L214).

You should replace silu with gelu; otherwise the model still generates meaningful text, but its output quality is significantly degraded.
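
For anyone patching this locally, a minimal sketch of the suggested change, assuming a Mixtral-style gated MLP (the real gpt-fast ConditionalFeedForward batches experts, so the actual edit is just swapping the activation function):

import torch
import torch.nn.functional as F

def gated_mlp(x, w1, w2, w3):
    # Gated MLP: for Grok-1 the gate activation should be GELU instead of SiLU.
    # Check the reference implementation for which GELU variant (exact vs. tanh) it uses.
    gate = F.gelu(x @ w1.t())   # previously: F.silu(x @ w1.t())
    up = x @ w3.t()
    return (gate * up) @ w2.t()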

Chillee commented 3 weeks ago

@merrymercy ah that would explain my results haha. Thanks!