pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

Model merging scripts? #1179

Open suraj-srinivas opened 1 month ago

suraj-srinivas commented 1 month ago

Hi,

For LoRA fine-tuning, is there a way to save only the adapter weights rather than the full model files? More importantly, what is the easiest way to merge a base model with an adapter?

I am working on a project involving a lot of analysis of fine-tuned models, and saving only the adapters + having on-the-fly model merging functionality would really help.

Thanks for the excellent library!

pbontrager commented 1 month ago

You can look inside our LoRA fine-tuning recipes to see how we handle checkpointing. The key function is get_merged_lora_ckpt, which merges your adapter weights into the base model weights. The recipe currently saves both the adapters and the merged checkpoint. If you want to save space, you can copy the recipe and modify its save_checkpoint method to save only the adapter weights; then, after training, you can merge any adapter you choose with get_merged_lora_ckpt. Let me know if you run into any issues with that.
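For readers unfamiliar with what the merge actually does: for each adapted layer, the LoRA update is folded into the base weight as W_merged = W + (alpha / rank) * (B @ A). Below is a minimal plain-Python sketch of that per-layer math (the function and variable names here are illustrative, not torchtune's actual API, which operates on a torch state dict):

```python
# Sketch of the per-layer LoRA merge: W_merged = W + (alpha / rank) * (B @ A).
# Plain lists-of-lists are used so the example is self-contained; in practice
# torchtune's get_merged_lora_ckpt does this on torch tensors in a state dict.

def matmul(x, y):
    """Naive matrix multiply for lists of lists."""
    rows, inner, cols = len(x), len(y), len(y[0])
    return [
        [sum(x[i][k] * y[k][j] for k in range(inner)) for j in range(cols)]
        for i in range(rows)
    ]

def merge_lora_weight(w, lora_a, lora_b, rank, alpha):
    """Fold a LoRA update into the base weight (hypothetical helper).

    w:      (out, in)  base weight
    lora_a: (rank, in) down-projection
    lora_b: (out, rank) up-projection
    """
    scale = alpha / rank
    delta = matmul(lora_b, lora_a)  # (out, in) low-rank update
    return [
        [w[i][j] + scale * delta[i][j] for j in range(len(w[0]))]
        for i in range(len(w))
    ]

# Tiny 2x2 example with a rank-1 adapter.
w = [[1.0, 0.0], [0.0, 1.0]]
lora_a = [[1.0, 1.0]]    # (rank=1, in=2)
lora_b = [[0.5], [0.5]]  # (out=2, rank=1)
merged = merge_lora_weight(w, lora_a, lora_b, rank=1, alpha=2.0)
print(merged)  # every entry gains (2/1) * 0.5 * 1.0 = 1.0 from the update
```

Because the update is additive, keeping only the (small) A and B matrices on disk is enough to reconstruct the merged model on the fly from the base checkpoint, which is exactly the adapter-only workflow described above.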

@ebsmothers I noticed that get_lora_module_names, validate_state_dict_for_lora, get_merged_lora_ckpt, disable_adapter, and validate_missing_and_unexpected_for_lora are not included in the peft init or in the documentation. Could you add those when you get a chance?