pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

Mistral support #38

Closed: Nikita-Sherstnev closed this issue 4 months ago

Nikita-Sherstnev commented 7 months ago

Would it be hard to adapt this code for Mistral? I tried the Open Orca version and set vocab_size in the config to 32002, but the shapes did not match:

File "/experiments/dev/nsherstnev/gpt-fast/scripts/convert_hf_checkpoint.py", line 61, in permute
    w.view(n_head, 2, config.head_dim // 2, dim)
RuntimeError: shape '[32, 2, 64, 4096]' is invalid for input of size 4194304
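The numbers in the traceback are consistent with Mistral-7B's grouped-query attention: the key/value projections carry only 8 heads rather than the full 32, so the weight tensor is a quarter of the size the Llama-style permute expects. A quick sanity check of the sizes (a sketch; n_kv_heads is an illustrative name for Mistral's num_key_value_heads):

```python
# Sanity check on the error message, assuming Mistral-7B's published shapes:
# 32 query heads, 8 key/value heads (grouped-query attention), head_dim 128, dim 4096.
n_head, n_kv_heads, head_dim, dim = 32, 8, 128, 4096

# The k-projection weight only covers the 8 KV heads...
wk_numel = n_kv_heads * head_dim * dim            # 8 * 128 * 4096 = 4_194_304
# ...but permute() asks for a view sized for all 32 heads:
view_numel = n_head * 2 * (head_dim // 2) * dim   # 32 * 2 * 64 * 4096 = 16_777_216
assert wk_numel == 4_194_304 and view_numel == 16_777_216
```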
bwasti commented 7 months ago

You'll need to change some more configuration params (e.g., n_local_heads should be 8).

I'd copy them from here https://huggingface.co/docs/transformers/main/model_doc/mistral#transformers.MistralConfig
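For reference, a Mistral-7B entry in gpt-fast's transformer_configs dictionary might look like the sketch below. The field names follow the existing Llama-style entries in model.py; the values are taken from the MistralConfig defaults linked above (for the Open Orca fine-tune, vocab_size would be 32002):

```python
# A sketch of a Mistral-7B entry for transformer_configs in model.py.
# Field names mirror the existing Llama-style configs; values follow the
# MistralConfig defaults (num_hidden_layers, num_attention_heads, etc.).
transformer_configs = {
    "Mistral-7B": dict(
        n_layer=32,               # num_hidden_layers
        n_head=32,                # num_attention_heads
        n_local_heads=8,          # num_key_value_heads (grouped-query attention)
        dim=4096,                 # hidden_size
        intermediate_size=14336,  # MLP width
        vocab_size=32000,         # 32002 for the Open Orca fine-tune
    ),
}
```

Setting n_local_heads=8 is what makes the failing view above line up: the k/v weights get reshaped with 8 heads instead of 32.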

Artyom17 commented 4 months ago

Done in #116. The issue can be closed now.