pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

llama3 8B support, tiktoken tokenizer #158

Closed Artyom17 closed 2 months ago

Artyom17 commented 2 months ago

Surprisingly, Llama 3 switched from SentencePiece to the Tiktoken tokenizer. This PR implements wrappers for both the Tiktoken and SentencePiece tokenizers, and adds params for the Llama-3-8B and -70B models.

As to scripts/convert_hf_checkpoint.py: Llama 3 on HF no longer ships the pytorch_model-xxxx-xxxx.bin files; instead, the PyTorch model lives in the 'original' sub-directory with a different naming pattern ('consolidated.XX.pth'). For the 8B model it is a single file that just needs to be copied as model.pth into the parent directory, with no need to touch the weight names. The original/tokenizer.model also just needs to be copied into its parent directory (and the Tiktoken tokenizer must be used instead of SentencePieceProcessor).
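
Concretely, the 8B preparation described above boils down to something like this (the checkpoints/ path is only an assumption following gpt-fast's usual layout):

```python
import shutil
from pathlib import Path

# Assumed local layout: the HF download keeps the PyTorch weights under original/.
checkpoint_dir = Path("checkpoints/meta-llama/Meta-Llama-3-8B")

# Single consolidated shard for 8B: copy it up as model.pth, no weight renaming needed.
shutil.copy(checkpoint_dir / "original" / "consolidated.00.pth", checkpoint_dir / "model.pth")
shutil.copy(checkpoint_dir / "original" / "tokenizer.model", checkpoint_dir / "tokenizer.model")
```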

As to the 70B model: it is not covered by this PR, since it is not clear to me how to handle multiple consolidated.XX.pth files with THE SAME weight names in each (unlike the pytorch_model_XXXX-of-XXXXX.bin files, where each .bin contains a distinct subset of the weights).
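
For context, each consolidated.XX.pth in those multi-file checkpoints carries the full set of weight names, but every tensor is only a slice of the full weight (tensor-parallel shards). A rough sketch of how they could in principle be merged, assuming a caller-supplied shard_dim_for(name) helper (hypothetical; figuring out which dimension each weight is sharded along, or that it is replicated, is exactly the open question here):

```python
import torch

def merge_consolidated_shards(shard_paths, shard_dim_for):
    """Merge tensor-parallel shards that all contain the same weight names.

    shard_dim_for(name) must return the dimension a weight is sharded along,
    or None for weights replicated across shards (e.g. norm weights).
    """
    shards = [torch.load(p, map_location="cpu") for p in shard_paths]
    merged = {}
    for name, tensor in shards[0].items():
        dim = shard_dim_for(name)
        if dim is None:
            merged[name] = tensor  # replicated: any shard's copy will do
        else:
            merged[name] = torch.cat([s[name] for s in shards], dim=dim)
    return merged
```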

Artyom17 commented 2 months ago

Note: while I was testing int4 quantization with the Llama-3-8B model I found this bug: #159

Muhtasham commented 2 months ago

Great job! How many tokens per second are you getting? Mind sharing some stats, @Artyom17?

Artyom17 commented 2 months ago

> Great job! How many tokens per second are you getting? Mind sharing some stats, @Artyom17?

It is a bit slower than Mistral or Llama 2. I got 165 tokens/s on an H100 with Llama 3, while Llama 2 gave me 185 tokens/s and Mistral-7B 175 tokens/s.

Artyom17 commented 2 months ago

Ping?

Chillee commented 2 months ago

cc: @yanboliang

yanboliang commented 2 months ago

I'll review it tomorrow!

Artyom17 commented 2 months ago

> Looks good to me! Can you update the corresponding benchmark number in README? Thank you!

Unfortunately, all the benchmarks in the README.md were run on 8xA100, but I only have access to 8xH100.

nivibilla commented 2 months ago

@Artyom17 I was looking into how to deal with the 70B and found this old script:

https://github.com/tloen/llama-int8/blob/main/example.py (particularly the load function)

Is this useful?

Chillee commented 2 months ago

@yanboliang Could we run the A100 numbers and add them to the README?

yanboliang commented 2 months ago

Yeah, I can run & update the A100 numbers. We can probably also do a small update on the README to split the benchmarks into A100/H100/AMD.

yanboliang commented 2 months ago

Perf numbers on A100: https://github.com/pytorch-labs/gpt-fast/pull/166

nivibilla commented 2 months ago

Hey @Artyom17, I've generalised the support for Llama 3; I'm able to convert both Llama 3 8B and 70B. Please see the PR to your fork here. By pre-converting the safetensors format to the PyTorch .bin format, the HF conversion script works as is, and all I needed to add were the model configs in the model.py file.
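
For reference, the model.py change being described amounts to a new entry in gpt-fast's transformer_configs dict, roughly along these lines (the hyperparameter values are Llama 3 8B's published architecture; the field names and entry name are a sketch, not the PR's actual diff):

```python
# Sketch of a Llama 3 8B entry for gpt-fast's transformer_configs in model.py
# (field names follow gpt-fast's ModelArgs; values are Llama 3 8B's published config).
transformer_configs = {
    "Meta-Llama-3-8B": dict(
        block_size=8192,          # Llama 3 context length
        n_layer=32,
        n_head=32,
        n_local_heads=8,          # grouped-query attention: 8 KV heads
        dim=4096,
        intermediate_size=14336,
        vocab_size=128256,
        rope_base=500000,         # Llama 3 uses a much larger RoPE theta than Llama 2
    ),
}
```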

nivibilla commented 2 months ago

Some performance numbers on 8xA10, running `python generate.py --compile --checkpoint_path ./llama-3-8b-instruct-hf-pt/model.pth`:

```
# 70b TP8
Average tokens/sec: 21.79
Memory used: 21.66 GB/GPU

# 8b TP8
Average tokens/sec: 112.74
Memory used: 4.19 GB/GPU

# 8b NO_TP
Average tokens/sec: 34.06
Memory used: 16.43 GB/GPU
```

Artyom17 commented 2 months ago

> Hey @Artyom17, I've generalised the support for Llama 3; I'm able to convert both Llama 3 8B and 70B. Please see the PR to your fork here. By pre-converting the safetensors format to the PyTorch .bin format, the HF conversion script works as is, and all I needed to add were the model configs in the model.py file.

Nice! Looking at it, thanks a lot!

Artyom17 commented 2 months ago

> Hey @Artyom17, I've generalised the support for Llama 3; I'm able to convert both Llama 3 8B and 70B. Please see the PR to your fork here. By pre-converting the safetensors format to the PyTorch .bin format, the HF conversion script works as is, and all I needed to add were the model configs in the model.py file.

Yeah, I've tested it and it works (with some misspelling caveats I mentioned in the PR). I am not sure we can integrate these changes at the moment, since it would create a dependency on third-party models (the eastwind/* ones), but the gpt-fast owners may correct me if I am wrong. The best outcome would be if HF adopts your conversion and releases those .bin files properly. Alternatively, gpt-fast users should be able to convert .safetensors to .bin themselves (I am not super familiar with this process; how hard is it?).

The right flow of events IMO should be as follows:

  1. This PR gets landed; I really hope it happens soon (@Chillee).
  2. Either HF adopts the .bin files or there is a way for gpt-fast users to convert .safetensors to .bin;
  3. You, @nivibilla, create a PR here in the gpt-fast repo that adds proper Llama-3-70B support (and unifies 8B support, like you did in that other PR).

nivibilla commented 2 months ago

@Artyom17 all I did was load the model into memory and do save_pretrained(dir, safe_serialization=False).

This is doable, but you would need enough memory to load the HF model and then save it into the "unsafe" version. If one is running these models then I guess it's safe to assume they have the resources to do this themselves, or we could have code to do it in here. But I don't want to ruin this repo and add a dependency on transformers lol.

Imo a better solution is to modify the existing code to work with safetensors, but idk how difficult that is.

Artyom17 commented 2 months ago

> @Artyom17 all I did was load the model into memory and do save_pretrained(dir, safe_serialization=False).
>
> This is doable, but you would need enough memory to load the HF model and then save it into the "unsafe" version. If one is running these models then I guess it's safe to assume they have the resources to do this themselves, or we could have code to do it in here. But I don't want to ruin this repo and add a dependency on transformers lol.
>
> Imo a better solution is to modify the existing code to work with safetensors, but idk how difficult that is.

Well, you can't use the 70B model anyway unless you have a beefy machine with 80GB A100/H100 GPUs. Adding safetensors support to gpt-fast sounds doable too; I found this article: https://medium.com/@mandalsouvik/safetensors-a-simple-and-safe-way-to-store-and-distribute-tensors-d9ba1931ba04

danieltmeta commented 2 months ago

> @Artyom17 all I did was load the model into memory and do save_pretrained(dir, safe_serialization=False).
>
> This is doable, but you would need enough memory to load the HF model and then save it into the "unsafe" version. If one is running these models then I guess it's safe to assume they have the resources to do this themselves, or we could have code to do it in here. But I don't want to ruin this repo and add a dependency on transformers lol.
>
> Imo a better solution is to modify the existing code to work with safetensors, but idk how difficult that is.

Hello, I managed to brute force my way to convert from .safetensors to .bin for Meta-Llama-3-70B by loading the model and using save_pretrained(dir, safe_serialization=False) with the following code:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained(checkpoint_dir)
model.save_pretrained(checkpoint_dir, safe_serialization=False)
```

I wonder if I may ask a few questions:

  1. Am I using the right library to implement save_pretrained for 70B? Or does it not matter?
  2. I ended up generating 61 .bin files, while this other person generated 30: https://huggingface.co/eastwind/llama-3-70b-instruct-hf-pt/tree/main
  3. For my method, I am missing 'lm_head.weight'; I was wondering if it is related to the previous two questions?

Thank you for your efforts!
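
One likely explanation for the missing lm_head.weight in question 3: transformers' AutoModel loads the base model (LlamaModel) without the LM head, whereas AutoModelForCausalLM keeps it. A variant of the snippet above using the causal-LM class (checkpoint_dir is a placeholder path):

```python
from transformers import AutoModelForCausalLM

checkpoint_dir = "path/to/Meta-Llama-3-70B"  # local HF snapshot directory

# AutoModelForCausalLM returns LlamaForCausalLM, which includes lm_head.weight.
model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)
model.save_pretrained(checkpoint_dir, safe_serialization=False)
```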

nivibilla commented 2 months ago

Hey @danieltmeta that's exactly what I did. Eastwind is my huggingface account btw 🤣.

Not sure why you are getting different results. We should get the exact same weights. Are you on the latest transformers and safetensors libraries?

Also, if you want to use that, please see my version of the Llama 3 integration in this PR: #169

jerrymannil commented 2 months ago

@Artyom17 @nivibilla

I think supporting safetensors is the simplest way to do this (minimal code changes). I actually tried this with the Llama 3 8B model and it seems to work:

  1. Detect safetensors or PyTorch .bin based on the "index.json" file
  2. Use safetensors.torch.load_file instead of torch.load() here (see the sketch below)
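
A minimal sketch of what that substitution could look like (the index-file names and the shard-merging loop are assumptions about the HF checkpoint layout; safetensors.torch.load_file is the actual safetensors API):

```python
import glob
import torch
from safetensors.torch import load_file as load_safetensors

checkpoint_dir = "checkpoints/meta-llama/Meta-Llama-3-8B"  # assumed local layout

def load_shard(path):
    # safetensors shards need the safetensors loader; .bin shards still go through torch.load.
    if path.endswith(".safetensors"):
        return load_safetensors(path, device="cpu")
    return torch.load(path, map_location="cpu", weights_only=True)

# Format detection could key off the index file, e.g. model.safetensors.index.json
# vs pytorch_model.bin.index.json; here we simply glob for whichever shards exist.
shard_paths = sorted(glob.glob(f"{checkpoint_dir}/*.safetensors")) or \
              sorted(glob.glob(f"{checkpoint_dir}/pytorch_model*.bin"))

merged_state_dict = {}
for shard in shard_paths:
    merged_state_dict.update(load_shard(shard))
```
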
nivibilla commented 2 months ago

@jerrymannil thank you! I will make this change and test it out on my other PR tomorrow.