GraphPKU / PiSSA

PiSSA: Principal Singular Values and Singular Vectors Adaptation of Large Language Models (NeurIPS 2024 Spotlight)
https://arxiv.org/abs/2404.02948

Mismatches in tensor sizes #10

poaboagye closed this issue 5 months ago

poaboagye commented 5 months ago

Hi, when I run pissa.sh with the settings below, I get an error that seems to be caused by mismatched tensor sizes.

--------pissa.sh-------
python pissa.py \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --output_dir ./output/pissa-llama-2-7b-r128 \
    --init_lora_weights pissa \
    --lora_r 128 \
    --data_path meta-math/MetaMathQA \
    --dataset_split "train[:100000]" \
    --dataset_field query response \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 128 \
    --save_strategy "steps" \
    --save_steps 100 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --bf16 True \
    --tf32 True \
    --report_to wandb

CUDA_VISIBLE_DEVICES=0 python merge_adapter_to_base_model.py --base_mode meta-llama/Llama-2-7b-hf --adapter ./output/pissa-llama-2-7b-r128/ft/ --output_path ./output/pissa-llama-2-7b-r128
CUDA_VISIBLE_DEVICES=0 python inference/gsm8k_inference.py --model ./output/pissa-llama-2-7b-r128
CUDA_VISIBLE_DEVICES=0 python inference/MATH_inference.py --model ./output/pissa-llama-2-7b-r128

--------Error-------

/home/Ubuntu/anaconda3/envs/embComp/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: resume_download is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use force_download=True.
  warnings.warn(
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:04<00:00, 2.34s/it]
Traceback (most recent call last):
  File "/home/Ubuntu/Documents/embComp/PiSSA/merge_adapter_to_base_model.py", line 16, in <module>
    model = PeftModel.from_pretrained(model, args.adapter, config=lora_config)
  File "/home/Ubuntu/anaconda3/envs/embComp/lib/python3.10/site-packages/peft/peft_model.py", line 430, in from_pretrained
    model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
  File "/home/Ubuntu/anaconda3/envs/embComp/lib/python3.10/site-packages/peft/peft_model.py", line 988, in load_adapter
    load_result = set_peft_model_state_dict(
  File "/home/Ubuntu/anaconda3/envs/embComp/lib/python3.10/site-packages/peft/utils/save_and_load.py", line 353, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False)
  File "/home/Ubuntu/anaconda3/envs/embComp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2189, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for PeftModelForCausalLM:
	size mismatch for base_model.model.model.embed_tokens.weight: copying a param with shape torch.Size([32001, 4096]) from checkpoint, the shape in current model is torch.Size([32000, 4096]).
	size mismatch for base_model.model.lm_head.weight: copying a param with shape torch.Size([32001, 4096]) from checkpoint, the shape in current model is torch.Size([32000, 4096]).
promt ===== Below is an instruction that describes a task. Write a response that appropriately completes the request.
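
The last frames of the traceback show the load failing even though PEFT calls load_state_dict with strict=False. A minimal sketch, using a plain nn.Embedding with toy sizes (5 vs. 4 rows standing in for 32001 vs. 32000) instead of the real Llama model, illustrates why: in PyTorch, strict=False only tolerates missing or unexpected keys, while a key present on both sides with a different shape is still an error. The helper name try_load is hypothetical.

```python
from torch import nn


def try_load(rows_in_checkpoint: int, rows_in_model: int, dim: int = 8) -> str:
    """Load an embedding saved with one vocabulary size into one with another.

    Toy sizes stand in for the 32001 vs. 32000 rows in the traceback.
    Returns "loaded" on success, otherwise the RuntimeError message.
    """
    saved = nn.Embedding(rows_in_checkpoint, dim).state_dict()
    model = nn.Embedding(rows_in_model, dim)
    try:
        # strict=False ignores missing/unexpected keys, but NOT shape mismatches
        model.load_state_dict(saved, strict=False)
        return "loaded"
    except RuntimeError as err:
        return str(err)


print(try_load(5, 4))  # prints the RuntimeError text ("... size mismatch for weight: ...")
print(try_load(4, 4))  # prints "loaded"
```

So the mismatch has to be avoided when the adapter is saved (or the base model's shapes must match), rather than relying on strict=False at load time.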

fxmeng commented 5 months ago

Thanks for reporting this. In some newer versions of transformers or PEFT, embed_tokens and lm_head are updated (and saved with the checkpoint) even when training with an adapter, which is why the checkpoint contains 32001-row embed_tokens and lm_head weights that do not fit the 32000-row base model. We now use the following code to freeze embed_tokens and lm_head, so that only the PiSSA adapter is updated:

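# `model` is the adapter-wrapped model being trained in pissa.py
print("<=======params.requires_grad=======>")
for name, params in model.named_parameters():
    # Newer transformers/PEFT versions may leave these layers trainable; freeze them
    if "embed_tokens" in name or "lm_head" in name:
        params.requires_grad = False
    # Print whatever is still trainable (should be the PiSSA adapter weights only)
    if params.requires_grad:
        print(name)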
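
To check the freezing rule in isolation, here is a minimal sketch on a hypothetical toy module rather than the real PEFT-wrapped Llama model. ToyCausalLM, adapter_A/adapter_B and freeze_embeddings are illustrative names; only the embed_tokens/lm_head names come from this thread. It applies the same substring-based rule and returns the parameter names that stay trainable.

```python
import torch
from torch import nn


class ToyCausalLM(nn.Module):
    """Hypothetical stand-in for a causal LM with a low-rank adapter.

    `embed_tokens` and `lm_head` mimic the Llama layer names from the traceback;
    `adapter_A`/`adapter_B` are made-up names for trainable low-rank factors.
    """

    def __init__(self, vocab: int = 10, dim: int = 4, rank: int = 2) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.base = nn.Linear(dim, dim, bias=False)
        self.base.weight.requires_grad = False  # base weight is frozen, as with an adapter
        self.adapter_A = nn.Parameter(torch.zeros(rank, dim))
        self.adapter_B = nn.Parameter(torch.zeros(dim, rank))
        self.lm_head = nn.Linear(dim, vocab, bias=False)


def freeze_embeddings(model: nn.Module, frozen=("embed_tokens", "lm_head")) -> list[str]:
    """Freeze parameters whose names contain any entry of `frozen`; return trainable names."""
    trainable = []
    for name, param in model.named_parameters():
        if any(key in name for key in frozen):
            param.requires_grad = False
        if param.requires_grad:
            trainable.append(name)
    return trainable


model = ToyCausalLM()
print(freeze_embeddings(model))  # ['adapter_A', 'adapter_B']
```

Substring matching is broad: any parameter whose name contains lm_head or embed_tokens is frozen, so printing the remaining trainable names, as the snippet above does, is a cheap way to confirm that only adapter weights are left.
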

poaboagye commented 5 months ago

Thank you very much for your response. I'll run the new pissa.py then.

poaboagye commented 5 months ago

Update:

Everything works now. Thank you!

fxmeng commented 5 months ago

You are welcome, enjoy 🍕.