pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
BSD 3-Clause "New" or "Revised" License

How to save a trained model so it can be loaded with HF `from_pretrained()`? #832

Open calmitchell617 opened 2 months ago

calmitchell617 commented 2 months ago

I'm finding this repo to be a user-friendly, extensible, memory-efficient solution for training and fine-tuning models. However, when it comes to inference, there is a usability gap that could be closed by converting the model into a format that can be loaded by HF's from_pretrained() function.

The specific thing I want to do is load a model fine-tuned with torchtune into a Gradio chatbot, complete with token streaming. I imagine many other downstream tasks would be made easier with this functionality as well.

Would it be reasonable to add the following options to the checkpointer?

If this seems like a valid addition, and isn't a huge lift, I would be happy to give it a try.

calmitchell617 commented 2 months ago

After some initial research, I found that the functionality I'm imagining is already implemented for Llama 2 in the convert_llama_weights_to_hf() function in HF Transformers.

I opened an issue to add support for Llama 3 to that function.

As described in that linked issue, I successfully converted a Llama 2 model saved with the Meta checkpointer to a format that can be loaded with from_pretrained().

calmitchell617 commented 2 months ago

Looks like someone at HF is already working on converting that script to support Llama 3.

kartikayk commented 2 months ago

@calmitchell617 thanks for opening up this issue! Actually, we do support HF formats directly in torchtune via this function. More details here. I believe this should work OOTB for the safetensors files available in the Llama 3 repo, but I can confirm with you in a bit.

Generally, we've designed torchtune to be state-dict invariant: a checkpoint is written back out in the same format it was read in. The reason is exactly what you mentioned above, i.e., better interop with other tools in the ecosystem. Let me know if this helps.

kartikayk commented 2 months ago

The point about lora adapters is a great one! I had an offline chat with @BenjaminBossan about this at some point. Let me follow up on this and see if we can build more interop here with peft.

calmitchell617 commented 2 months ago

@kartikayk, thanks for the quick follow up. A peft integration would be very handy, but the other thing is more pressing for me right now.

I did come across that function you mentioned, but am still not sure how to accomplish my use case. Maybe I am not understanding an existing way of doing things.

Instead of prescribing a solution, let me try to be really specific about my use case. Hopefully I'm just making things harder than they need to be, and you will have an easy solution. I want to:

  1. Download Llama 3 8B.
  2. Fine-tune it with torchtune.
  3. Write the resulting model in a format that can be read by from_pretrained().
  4. Run inference with the fine-tuned model in a Gradio chat app, loading it with from_pretrained().

Two issues I see with step 3 of that list are:

Again, thank you for your fast responses. Hopefully I'm just not seeing an existing solution.

kartikayk commented 2 months ago

@calmitchell617 I just tried what I had mentioned and realized the folly of what I said above :) I was under the impression that the "safetensors" files are compatible with from_pretrained, but that's not true, as you said. I do think this should be easy to accomplish. Let me quickly try something and share a code snippet with you.

kartikayk commented 2 months ago

@calmitchell617 ok I think I can convert the weights around correctly using the following code:

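import torch
from torchtune.models import convert_weights

# Load the original Meta checkpoint on CPU; mmap=True avoids reading the whole file into RAM up front
sd = torch.load('Meta-Llama-3-8B/original/consolidated.00.pth', mmap=True, map_location='cpu')
# Meta key names -> torchtune key names
sd_tune = convert_weights.meta_to_tune(sd)
# torchtune -> HF key names; the head counts are used to permute the q/k projections (Llama 3 8B: 32 query heads, 8 KV heads)
sd_hf = convert_weights.tune_to_hf(sd_tune, num_heads=32, num_kv_heads=8)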

But as I was looking into the HF code for this, I realized there might be additional piping needed here to actually get this up and running with from_pretrained. I don't understand that pipeline very well, but let me know if this helps you get started in the right direction, even if it's not the full solution.
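Why tune_to_hf needs num_heads and num_kv_heads: the query and key projections use different rotary-embedding layouts in the two formats. Meta's checkpoints rotate interleaved pairs of each head's dimensions, while HF's Llama implementation rotates the first and second halves of each head, so the rows of q_proj and k_proj have to be reordered per head. HF's convert_llama_weights_to_hf() does this with a `permute` helper; the NumPy sketch below mirrors that reshape/transpose. The function names are hypothetical (not torchtune's or HF's API), and the same operations apply to torch tensors.

```python
import numpy as np


def permute_for_hf(w: np.ndarray, n_heads: int) -> np.ndarray:
    """Reorder the rows of a q/k projection weight (out_dim x in_dim) from
    Meta's interleaved rotary layout to HF Llama's half-split layout."""
    out_dim, in_dim = w.shape
    head_dim = out_dim // n_heads
    return (
        w.reshape(n_heads, head_dim // 2, 2, in_dim)
        .transpose(0, 2, 1, 3)  # swap the pair axis with the half-dim axis
        .reshape(out_dim, in_dim)
    )


def unpermute_from_hf(w: np.ndarray, n_heads: int) -> np.ndarray:
    """Inverse of permute_for_hf: HF layout back to Meta's layout."""
    out_dim, in_dim = w.shape
    head_dim = out_dim // n_heads
    return (
        w.reshape(n_heads, 2, head_dim // 2, in_dim)
        .transpose(0, 2, 1, 3)
        .reshape(out_dim, in_dim)
    )


# One head of size 4: rows (0, 1) and (2, 3) are rotary pairs in Meta's layout;
# after permuting, the even rows come first and the odd rows second.
print(permute_for_hf(np.arange(4).reshape(4, 1), n_heads=1).ravel())  # [0 2 1 3]
```

For Llama 3 8B, the query weight would be permuted with n_heads=32 and the key weight with n_heads=8 (the KV head count), which is why both numbers are passed; the value and output projections are left as they are.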

calmitchell617 commented 2 months ago

Great! I will play around with the functions you mentioned to see if this is possible. If not, I will keep an eye on Hugging Face's PR to see if that functionality can be applied, even if only as an example script in the docs somewhere.

You may already be aware that loading a model with from_pretrained() is a very popular way to load an NLP model for inference, so supporting that use case would be well worth the effort.

kartikayk commented 2 months ago

@calmitchell617 this is great feedback! Let me take a closer look at this. Do you think we'll need to do something different from creating HF-format checkpoints (once these are available)? For Llama 2, I think this works OOTB; we did verify interop with llama.cpp, for example. But let me take a closer look at the inference support within HF as well. Again, thanks so much for the feedback on this!

calmitchell617 commented 2 months ago

I do think an extra processing step is required. Here are exactly the steps I took to download and fine-tune Llama 2 with torchtune, then process it to load successfully with from_pretrained().

Download Llama 2 in a non-HF format via the torchtune CLI:

tune download meta-llama/Llama-2-7b-chat --output-dir <checkpoint_path> --hf-token $HF_TOKEN

Fine-tune Llama 2 with torchtune, again using the CLI:

tune run \
    --nproc_per_node=4 \
    lora_finetune_distributed \
    --config llama2/7B_lora \
    batch_size=1 \
    seed=29 \
    tokenizer.path=<checkpoint_path> \
    checkpointer.checkpoint_dir=<checkpoint_path> \
    checkpointer.output_dir=<checkpoint_path> \
    dataset=torchtune.datasets.my_custom_dataset \
    checkpointer.checkpoint_files=['consolidated.00.pth'] \
    checkpointer=torchtune.utils.FullModelMetaCheckpointer \
    gradient_accumulation_steps=1 \
    lr_scheduler.num_warmup_steps=5 \
    enable_activation_checkpointing=False \
    dataset.max_rows=100 \
    epochs=1

Convert the fine-tuned model to a format that can be loaded with from_pretrained() using the convert_llama_weights_to_hf() function. You can simply copy and paste that function into a standalone script and call it.

Finally, I tested the conversion by running a Gradio chatbot on the fine-tuned, converted model. It worked as expected.

As mentioned before, the convert_llama_weights_to_hf() function does not yet support Llama 3, but Hugging Face is already working on adding support.

A crucial thing to note is that I have not been able to load a checkpoint saved with torchtune's HF checkpointer using from_pretrained().

calmitchell617 commented 2 months ago

So, the problem is already solved for Llama 2, but of course, everyone wants Llama 3 :-)

There may be some pre-release code in the Hugging Face repo that adds Llama 3 support to convert_llama_weights_to_hf(). I will give that a try tomorrow and report back here.

If that goes well, I will post a full reproducible example. From there, we can see if it's worth including the script as an example, baking into a helper function, or even into the checkpointer as an option.

kartikayk commented 2 months ago

This is awesome info! I'd love to discuss baking this into the checkpointer directly, since that was the intent of the HF checkpointer to begin with :)

calmitchell617 commented 2 months ago

Great. This issue is important to me, so I will work on it tomorrow with the goal of having it work OOTB with the checkpointer.

Happy to discuss via email or video call anytime. Thanks again for your attention on this.

kartikayk commented 2 months ago

Actually, I'd love to do a quick call on this and figure it out! Mind sharing your email or pinging me on Discord (@ KK on the Discord channel) so we can set this up? I really appreciate all of the effort in figuring this out; really awesome!

kartikayk commented 2 months ago

@calmitchell617 I took a look at the code pointer above, and it seems like there's just a little bit of JSON wrangling needed to make this work. The model checkpoint itself doesn't need any changes. Let me know if you disagree?
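Most of the JSON side is translating Meta's params.json into the config.json that from_pretrained() reads. A minimal sketch of that key mapping, assuming the Llama 2/3 params.json layout; the helper names are hypothetical, and fields that params.json does not carry (max_position_embeddings, BOS/EOS token IDs, torch_dtype, and so on) are deliberately left out and would have to be added by hand.

```python
import json


def compute_intermediate_size(dim, ffn_dim_multiplier=1.0, multiple_of=256):
    # Llama's SwiGLU hidden size: int(8 * dim / 3), scaled by the multiplier,
    # then rounded up to a multiple of `multiple_of`.
    hidden = int(ffn_dim_multiplier * int(8 * dim / 3))
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)


def meta_params_to_hf_config(params: dict) -> dict:
    """Translate Meta's params.json keys into a subset of HF LlamaConfig fields."""
    n_heads = params["n_heads"]
    return {
        "architectures": ["LlamaForCausalLM"],
        "model_type": "llama",
        "hidden_size": params["dim"],
        "intermediate_size": compute_intermediate_size(
            params["dim"],
            params.get("ffn_dim_multiplier", 1.0),
            params.get("multiple_of", 256),
        ),
        "num_hidden_layers": params["n_layers"],
        "num_attention_heads": n_heads,
        # Without grouped-query attention, KV heads == query heads.
        "num_key_value_heads": params.get("n_kv_heads", n_heads),
        "rms_norm_eps": params["norm_eps"],
        "rope_theta": params.get("rope_theta", 10000.0),
        # Llama 2's params.json stores -1 here; take the size from the tokenizer then.
        "vocab_size": params["vocab_size"],
    }


# Llama 3 70B values, as printed by the conversion script further down this thread.
params_70b = {"dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096,
              "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05,
              "vocab_size": 128256, "rope_theta": 500000.0}
print(json.dumps(meta_params_to_hf_config(params_70b), indent=2))
```

The rounding rule is what turns dim=8192 with ffn_dim_multiplier=1.3 and multiple_of=4096 into the intermediate size of 28672.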

calmitchell617 commented 2 months ago

> The model checkpoint itself doesn't need any changes. Let me know if you disagree?

In addition to JSON wrangling, I believe HF's model conversion code is:

I tested the code in HF's PR to add Llama 3 support to their model conversion function. After a few small alterations, it worked fine. So, it seems like we at least have a blueprint to follow to include that functionality in this repo.

apthagowda97 commented 2 months ago

Noob Question:

To convert meta_model_0.pt to GGUF, do we need to convert to HF first and then to GGUF (fp16), or can we convert directly from meta_model_0.pt?

calmitchell617 commented 2 months ago

@apthagowda97, you will most likely need to convert to HF first, as that is probably the format whatever tool you're using to convert is expecting.

kartikayk commented 2 months ago

So I do think llama.cpp's convert script supports the Meta format (I've done this many times for Llama 2, for example). Here's the code: https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L1364

The caveat is that they explicitly check for consolidated.00.pth, so @apthagowda97 you'll need to rename the checkpoint, which isn't too bad.

The only thing I'm not sure about is whether they support tiktoken. You can give it a whirl; just make sure you have the tokenizer model file in the same folder.
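The rename step can be scripted. A minimal sketch, assuming the torchtune output was written by FullModelMetaCheckpointer (so it is still in Meta's key format) and that the tokenizer file, plus any other files the llama.cpp script expects, are passed as companions; `stage_for_llama_cpp` is a hypothetical helper, not part of torchtune or llama.cpp.

```python
import shutil
from pathlib import Path


def stage_for_llama_cpp(checkpoint: Path, companions: list[Path], out_dir: Path) -> Path:
    """Copy a Meta-format checkpoint into out_dir under the file name that
    llama.cpp's convert script looks for, next to its companion files.
    Returns the path of the renamed checkpoint."""
    out_dir.mkdir(parents=True, exist_ok=True)
    target = out_dir / "consolidated.00.pth"
    # Copy rather than move so the original torchtune output stays untouched.
    shutil.copy2(checkpoint, target)
    for extra in companions:
        # e.g. tokenizer.model, kept under its original name
        shutil.copy2(extra, out_dir / extra.name)
    return target
```

Copying doubles the disk usage for large checkpoints; renaming the file in place is the cheaper option if the original name is not needed afterwards.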

calmitchell617 commented 2 months ago

@kartikayk, following our discussion, here is a fully reproducible series of steps I took to download Llama 3 and convert it to an HF format that can be read by from_pretrained(). Others may also find this useful as a temporary workaround.

@apthagowda97, would you give it a try and let me know if it works for your use case?

Prerequisites:

An empty directory:

mkdir ~/models

torchtune installed:

tune -h

transformers checked out to the branch of PR 30334, and installed:

cd /tmp
git clone https://github.com/huggingface/transformers.git
cd /tmp/transformers
gh pr checkout 30334
pip install .

Llama 3 8B Instruct downloaded:

tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir ~/models/meta-llama/Meta-Llama-3-8B-Instruct --hf-token $HF_TOKEN

Convert to HF format

Put the code in this gist (mostly copied from this PR, with a few changes to how paths are handled) in a file called ~/models/convert_model.py, then run it with the following command:

Command:

python ~/models/convert_model.py --input_dir ~/models/meta-llama/Meta-Llama-3-8B-Instruct/ --output_dir ~/models/converted-llama3  --model_size 8Bf --llama_version 3

Verify the conversion

Now you should be able to load the model with from_pretrained(). Create a file ~/models/load_model.py:

from transformers import AutoModelForCausalLM
from torch import bfloat16

# "converted-llama3" is a relative path, so run this from ~/models (the --output_dir used above)
model = AutoModelForCausalLM.from_pretrained(
    "converted-llama3",
    torch_dtype=bfloat16,
)

And run the file to show that we can now load the model with from_pretrained():

cd ~/models
python load_model.py
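
For the original goal of a Gradio chatbot with token streaming, the loaded model can be wrapped in a streaming generator. A minimal sketch, assuming `model` was loaded as above, a tokenizer was loaded from the same directory with AutoTokenizer, and the converted tokenizer carries a chat template (if it does not, the prompt has to be formatted by hand). It uses transformers' TextIteratorStreamer with generate() running in a background thread; `stream_reply` is a hypothetical name.

```python
from threading import Thread


def stream_reply(model, tokenizer, messages, max_new_tokens=256):
    """Yield the assistant reply incrementally as tokens are generated.

    `messages` is a list of {"role": ..., "content": ...} dicts.
    """
    # Imported here so the helper can be defined without transformers installed.
    from transformers import TextIteratorStreamer

    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks until it finishes, so run it in a thread and read the streamer here.
    thread = Thread(
        target=model.generate,
        kwargs={"input_ids": input_ids, "streamer": streamer, "max_new_tokens": max_new_tokens},
    )
    thread.start()
    reply = ""
    for chunk in streamer:
        reply += chunk
        yield reply  # yield the whole partial reply, which is what chat UIs re-render
    thread.join()
```

To hook this into Gradio, wrap it in a generator function such as `def chat(message, history): yield from stream_reply(model, tokenizer, [{"role": "user", "content": message}])` and pass `chat` to `gr.ChatInterface`; this ignores the chat history for brevity.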

kartikayk commented 2 months ago

@calmitchell617 this is AWESOME! Thanks so much for the detailed instructions.

apthagowda97 commented 2 months ago

@calmitchell617 It works! Now that the PR is merged, we can use it directly. The correct way to convert Meta to GGUF is Meta -> HF -> GGUF.

@kartikayk I tried converting directly from Meta to GGUF as you suggested, with the tokenizer in the same folder, but the answer quality is bad (I'm guessing the tokenizer isn't working properly).

monk1337 commented 2 months ago

It's not working for 70B model :/

hugebeanie commented 2 months ago

Thanks! It works for the Llama 3 8B Instruct conversion. However, when converting the 70B Instruct version, I got the following error:

% python ~/models/convert_model.py --input_dir ~/models/meta-llama/Meta-Llama-3-70B-Instruct/ --output_dir ~/models/converted-llama3 --model_size 70B --llama_version 3

Saving a LlamaTokenizerFast to /Users/models/converted-llama3.
{'dim': 8192, 'ffn_dim_multiplier': 1.3, 'multiple_of': 4096, 'n_heads': 64, 'n_kv_heads': 8, 'n_layers': 80, 'norm_eps': 1e-05, 'vocab_size': 128256, 'rope_theta': 500000.0}
Fetching all parameters from the checkpoint at "/Users/models/meta-llama/Meta-Llama-3-70B-Instruct/. "
zsh: killed python ~/models/convert_model.py --input_dir --output_dir --model_size 70B

Is this because I need more memory?

calmitchell617 commented 2 months ago

@hugebeanie, running convert_model.py with a 70B-parameter model (saved in bfloat16) will take ~140 GB of CPU RAM (not GPU RAM).

I observed the process using a peak of 143.2 GB CPU RAM on my machine. So, you will realistically need a few more GB than that.

One thing you can do is increase your system's swap space. I have done this before; you just need to Google (or ask an LLM) how to increase swap on your system. This will slow things down, but as long as you have a fast SSD, it should finish in a reasonable time.
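The ~140 GB figure is essentially the parameter count times 2 bytes per bfloat16 weight, since the script holds the full set of weights in CPU RAM. A small sketch that derives the count from the params.json values printed in the 70B log above; it assumes the Llama 2/3 architecture (bias-free projections, untied input and output embeddings, RMSNorm weights only), and `llama_param_count` is a hypothetical helper, not part of torchtune or transformers.

```python
def llama_param_count(params: dict) -> int:
    """Count the weights of a Llama 2/3-style model from Meta's params.json values."""
    dim, n_layers, n_heads = params["dim"], params["n_layers"], params["n_heads"]
    n_kv_heads = params.get("n_kv_heads", n_heads)
    head_dim = dim // n_heads
    # SwiGLU hidden size with the reference rounding rule.
    multiple_of = params.get("multiple_of", 256)
    hidden = int(params.get("ffn_dim_multiplier", 1.0) * int(8 * dim / 3))
    ffn = multiple_of * ((hidden + multiple_of - 1) // multiple_of)

    attn = 2 * dim * dim + 2 * dim * n_kv_heads * head_dim  # q and o, plus k and v
    mlp = 3 * dim * ffn  # gate, up and down projections
    norms = 2 * dim  # attention and FFN RMSNorm weights
    embeddings = 2 * params["vocab_size"] * dim  # token embedding + output head
    return n_layers * (attn + mlp + norms) + embeddings + dim  # + final norm


params_70b = {"dim": 8192, "ffn_dim_multiplier": 1.3, "multiple_of": 4096,
              "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05,
              "vocab_size": 128256, "rope_theta": 500000.0}
n = llama_param_count(params_70b)
print(f"{n / 1e9:.1f}B params, ~{2 * n / 1e9:.0f} GB in bfloat16")  # 70.6B params, ~141 GB in bfloat16
```

The observed peak of 143.2 GB is only slightly above this, which is consistent with the weights themselves dominating memory use.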

calmitchell617 commented 2 months ago

@monk1337, what part are you having an issue with? I just ran the steps above without any issue for the 70B model.

If you're having trouble running torchtune on the 70B model, you just have to alter the configs a bit to get it to work. I have done this, and can provide an example if that's your issue.

optimass commented 1 month ago

Hello, has anyone here managed to convert the fine-tuned checkpoints into something that can be loaded with HF's from_pretrained()?

tambulkar commented 1 month ago

Also curious here: I want to be able to take the final checkpoints from fine-tuning and run inference on them using from_pretrained().

ebsmothers commented 1 month ago

Hi folks, to follow up on this: we do now support integration with PEFT as of #933 (thanks to @BenjaminBossan for the help reviewing these changes). Whenever running a LoRA fine-tune with our HF checkpointer, we will save adapter weights and config in a format that can be loaded into PEFT. The example usage would be as follows:

Run LoRA fine-tune with torchtune CLI:

tune run lora_finetune_single_device --config llama2/7B_lora_single_device \
checkpointer.output_dir=/my/output/dir

Load fine-tuned adapter weights into PEFT model with same base model from hub:

from transformers import AutoModelForCausalLM
from peft import PeftModel

# hub ID of the base model from the above fine-tune
model_id = "meta-llama/Llama-2-7b-hf" 

# output_dir from tune command
checkpoint_dir = "/my/output/dir" 

model = AutoModelForCausalLM.from_pretrained(model_id)
peft_model = PeftModel.from_pretrained(model, checkpoint_dir)

To be clear, this is not quite the same as proper HF from_pretrained support in general (i.e. this is just for loading adapter weights from a LoRA fine-tune with a pretrained model on the hub). One other caveat is that we do not yet support phi-3 in this flow (won't get into the weeds here but there is a comment on #933 explaining why for interested parties).

Hopefully this helps unblock folks using LoRA, we are working to get proper transformers from_pretrained integration as soon as possible!
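One possible bridge for LoRA runs, which torchtune does not provide itself: once the adapter loads into PEFT as in the snippet above, PEFT's merge_and_unload() can fold the LoRA deltas into the base weights, and the result can be saved with save_pretrained() as a plain transformers checkpoint. A sketch under the same assumptions as that snippet (base model on the hub, adapter in torchtune's output_dir); the helper name is hypothetical, and loading in bfloat16 is a choice, not a requirement.

```python
def merge_lora_for_from_pretrained(base_model_id: str, adapter_dir: str, out_dir: str) -> None:
    """Merge a PEFT-format LoRA adapter into its base model and save a checkpoint
    that AutoModelForCausalLM.from_pretrained(out_dir) can load directly."""
    # Heavy dependencies are imported lazily so the helper can be defined without them.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.bfloat16)
    peft_model = PeftModel.from_pretrained(base, adapter_dir)
    # Adds the LoRA updates into the base weights and returns the plain
    # transformers model without the PEFT wrappers.
    merged = peft_model.merge_and_unload()
    merged.save_pretrained(out_dir)
    # Save the tokenizer alongside the weights so out_dir is self-contained.
    AutoTokenizer.from_pretrained(base_model_id).save_pretrained(out_dir)
```

For the example above, that would be `merge_lora_for_from_pretrained("meta-llama/Llama-2-7b-hf", "/my/output/dir", "/my/merged/dir")`, after which the merged directory can be passed to from_pretrained() like any other local checkpoint.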