artidoro / qlora

QLoRA: Efficient Finetuning of Quantized LLMs
https://arxiv.org/abs/2305.14314
MIT License

Loading Lora Adapter weights into 4bit model to continue fine tuning #145

Closed: simsim314 closed this issue 1 year ago

simsim314 commented 1 year ago

In this colab you show how to load an adapter and merge it with the initial model. Note that it loads the model in 16-bit format, so for 65B parameters it would need something like 130 GB of RAM (roughly 2 bytes per parameter).

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

model_name = "decapoda-research/llama-7b-hf"
adapters_name = "timdettmers/guanaco-7b"

print(f"Starting to load the model {model_name} into memory")

m = AutoModelForCausalLM.from_pretrained(
    model_name,
    # load_in_4bit=True,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
m = PeftModel.from_pretrained(m, adapters_name)
m = m.merge_and_unload()

Is there a way to load the adapter as an adapter, just as it was at training time? Also, if someone wants to continue training from an existing adapter using this sample, they would have to merge the adapter into the model on a very large machine and then re-create the exact same adapter on a weaker machine, just to continue training from where they (or someone else) stopped. Is there a way to load the adapter onto a 4-bit base model? Somewhere here (inside the fine-tuning colab):

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8, 
    lora_alpha=32, 
    target_modules=["query_key_value"], 
    lora_dropout=0.05, 
    bias="none", 
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)
# Load the previously trained adapter weights somewhere here (one possible approach is sketched below)...
print_trainable_parameters(model)
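
One way to fill in that gap, sketched under the assumption that the adapter was previously saved with save_pretrained() to a local directory (./my_adapter is a placeholder, and adapter_model.bin is the file name older PEFT versions used; newer versions save adapter_model.safetensors):

import torch
from peft import set_peft_model_state_dict

# After model = get_peft_model(model, config), overwrite the freshly
# initialized LoRA matrices with the previously trained ones.
adapter_weights = torch.load("./my_adapter/adapter_model.bin", map_location="cpu")
set_peft_model_state_dict(model, adapter_weights)
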
simsim314 commented 1 year ago

Never mind... you can use the from_pretrained function, like this:

config = LoraConfig.from_pretrained("timdettmers/guanaco-65b")
config.inference_mode = False 

model = get_peft_model(model, config)
print_trainable_parameters(model)
abcbdf commented 1 year ago

Never mind... you can use the from_pretrained function, like this:

config = LoraConfig.from_pretrained("timdettmers/guanaco-65b")
config.inference_mode = False 

model = get_peft_model(model, config)
print_trainable_parameters(model)

Hi @simsim314, correct me if I'm wrong. I think by doing so you only loaded the config of "timdettmers/guanaco-65b" instead of the trained adapter's weights. m = PeftModel.from_pretrained(m, adapters_name) is the right way to load the config and the weights together.
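
To make the distinction concrete, here is a rough sketch of the two paths, assuming base_model is an already loaded (possibly 4-bit) base model and adapters_name points at a saved adapter such as "timdettmers/guanaco-65b"; run one or the other, not both:

from peft import LoraConfig, PeftModel, get_peft_model

# Path 1: get_peft_model builds a *new* adapter from the config alone; the
# LoRA matrices start from their default initialization, and no trained
# weights are loaded.
config = LoraConfig.from_pretrained(adapters_name)
config.inference_mode = False
model = get_peft_model(base_model, config)

# Path 2: PeftModel.from_pretrained reads adapter_config.json *and* the saved
# adapter weights, so training continues from the trained LoRA matrices.
# model = PeftModel.from_pretrained(base_model, adapters_name, is_trainable=True)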

abcbdf commented 1 year ago

Oh, maybe you should also set is_trainable=True.

simsim314 commented 1 year ago

@abcbdf No, the line that loads the model weights is: model = get_peft_model(model, config)

Your method should work too: model = PeftModel.from_pretrained(model, adapter_name, is_trainable=True)
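
Putting the thread's conclusion together, a minimal sketch of resuming LoRA training on a 4-bit base model (the model and adapter names reuse the ones above; the BitsAndBytesConfig settings are illustrative rather than taken from the repo's training script):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

model_name = "decapoda-research/llama-7b-hf"
adapters_name = "timdettmers/guanaco-7b"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the base model directly in 4-bit instead of bfloat16.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map={"": 0},
)

# Attach the trained adapter in trainable mode: the frozen base stays in 4-bit
# while the LoRA weights remain in higher precision and can keep training.
model = PeftModel.from_pretrained(base_model, adapters_name, is_trainable=True)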

justin4ai commented 1 month ago

@simsim314 Hello, I'm trying the same thing as you (PeftModel.from_pretrained or get_peft_model) with my own saved QLoRA adapter, but this error occurs:

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████| 2/2 [00:04<00:00,  2.26s/it]
Traceback (most recent call last):
  File "/home/nas4_user/seungillee/justin_intern/draw-your-day/LLMs/finetuned_llama.py", line 18, in <module>
    tokenizer = AutoTokenizer.from_pretrained(fine_tuned_model_path)  # Load the tokenizer from the same path
  File "/home/seungillee/anaconda3/envs/draw-your-day/lib/python3.9/site-packages/transformers/models/auto/tokenization_auto.py", line 864, in from_pretrained
    config = AutoConfig.from_pretrained(
  File "/home/seungillee/anaconda3/envs/draw-your-day/lib/python3.9/site-packages/transformers/models/auto/configuration_auto.py", line 1038, in from_pretrained
    raise ValueError(
ValueError: Unrecognized model in ./fine_tuned_model. Should have a `model_type` key in its config.json, or contain one of the following strings in its name: albert, align, altclip, audio-spectrogram-transformer, autoformer, bark, bart, beit, bert, bert-generation, big_bird, bigbird_pegasus, biogpt, bit, blenderbot, blenderbot-small, blip, blip-2, bloom, bridgetower, bros, camembert, canine, chameleon, chinese_clip, chinese_clip_vision_model, clap, clip, clip_text_model, clip_vision_model, clipseg, clvp, code_llama, codegen, cohere, conditional_detr, convbert, convnext, convnextv2, cpmant, ctrl, cvt, dac, data2vec-audio, data2vec-text, data2vec-vision, dbrx, deberta, deberta-v2, decision_transformer, deformable_detr, deit, depth_anything, deta, detr, dinat, dinov2, distilbert, donut-swin, dpr, dpt, efficientformer, efficientnet, electra, encodec, encoder-decoder, ernie, ernie_m, esm, falcon, falcon_mamba, fastspeech2_conformer, flaubert, flava, fnet, focalnet, fsmt, funnel, fuyu, gemma, gemma2, git, glpn, gpt-sw3, gpt2, gpt_bigcode, gpt_neo, gpt_neox, gpt_neox_japanese, gptj, gptsan-japanese, granite, granitemoe, graphormer, grounding-dino, groupvit, hiera, hubert, ibert, idefics, idefics2, imagegpt, informer, instructblip, instructblipvideo, jamba, jetmoe, jukebox, kosmos-2, layoutlm, layoutlmv2, layoutlmv3, led, levit, lilt, llama, llava, llava_next, llava_next_video, llava_onevision, longformer, longt5, luke, lxmert, m2m_100, mamba, mamba2, marian, markuplm, mask2former, maskformer, maskformer-swin, mbart, mctct, mega, megatron-bert, mgp-str, mimi, mistral, mixtral, mllama, mobilebert, mobilenet_v1, mobilenet_v2, mobilevit, mobilevitv2, mpnet, mpt, mra, mt5, musicgen, musicgen_melody, mvp, nat, nemotron, nezha, nllb-moe, nougat, nystromformer, olmo, olmoe, omdet-turbo, oneformer, open-llama, openai-gpt, opt, owlv2, owlvit, paligemma, patchtsmixer, patchtst, pegasus, pegasus_x, perceiver, persimmon, phi, phi3, pix2struct, pixtral, plbart, poolformer, pop2piano, prophetnet, pvt, pvt_v2, qdqbert, qwen2, qwen2_audio, qwen2_audio_encoder, qwen2_moe, qwen2_vl, rag, realm, recurrent_gemma, reformer, regnet, rembert, resnet, retribert, roberta, roberta-prelayernorm, roc_bert, roformer, rt_detr, rt_detr_resnet, rwkv, sam, seamless_m4t, seamless_m4t_v2, segformer, seggpt, sew, sew-d, siglip, siglip_vision_model, speech-encoder-decoder, speech_to_text, speech_to_text_2, speecht5, splinter, squeezebert, stablelm, starcoder2, superpoint, swiftformer, swin, swin2sr, swinv2, switch_transformers, t5, table-transformer, tapas, time_series_transformer, timesformer, timm_backbone, trajectory_transformer, transfo-xl, trocr, tvlt, tvp, udop, umt5, unispeech, unispeech-sat, univnet, upernet, van, video_llava, videomae, vilt, vipllava, vision-encoder-decoder, vision-text-dual-encoder, visual_bert, vit, vit_hybrid, vit_mae, vit_msn, vitdet, vitmatte, vits, vivit, wav2vec2, wav2vec2-bert, wav2vec2-conformer, wavlm, whisper, xclip, xglm, xlm, xlm-prophetnet, xlm-roberta, xlm-roberta-xl, xlnet, xmod, yolos, yoso, zoedepth

[image attachment]

Model_4bit is a 4-bit quantized LLaMA. Any idea why this is happening? I appreciate your help in advance!
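
For what it's worth, that ValueError is what AutoConfig raises when it is pointed at a directory that contains only PEFT adapter files (adapter_config.json has no model_type and there is no full config.json). A sketch of one possible fix, assuming ./fine_tuned_model holds only the adapter, Model_4bit is the already loaded 4-bit base model from the attached image, and base_model_name below is a placeholder for whichever base model was fine-tuned:

from transformers import AutoTokenizer
from peft import PeftModel

base_model_name = "decapoda-research/llama-7b-hf"  # placeholder for the actual base model

# The tokenizer (and base config) come from the base model, not the adapter directory.
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Only the adapter is loaded from the fine-tuned directory, on top of the
# already loaded 4-bit base model.
model = PeftModel.from_pretrained(Model_4bit, "./fine_tuned_model", is_trainable=True)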