huggingface / peft

🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning.
https://huggingface.co/docs/peft
Apache License 2.0

Zero Trainable parameters when using get_peft_model() and a custom adapter name #1346

Closed: psych0v0yager closed this issue 8 months ago

psych0v0yager commented 8 months ago

System Info

Package                  Version
------------------------ ------------
**accelerate               0.23.0**
aiohttp                  3.8.6
aiosignal                1.3.1
appdirs                  1.4.4
asttokens                2.4.0
async-timeout            4.0.3
attrs                    23.1.0
**auto-gptq                0.6.0**
backcall                 0.2.0
**bitsandbytes             0.42.0**
certifi                  2023.7.22
charset-normalizer       3.3.0
click                    8.1.7
coloredlogs              15.0.1
comm                     0.1.4
contourpy                1.2.0
cycler                   0.12.1
**datasets                 2.14.5**
debugpy                  1.8.0
decorator                5.1.1
dill                     0.3.7
docker-pycreds           0.4.0
docstring-parser         0.15
exceptiongroup           1.1.3
executing                2.0.0
filelock                 3.12.4
fonttools                4.47.0
frozenlist               1.4.0
fsspec                   2023.6.0
gekko                    1.0.6
gitdb                    4.0.11
GitPython                3.1.41
**huggingface-hub          0.20.2**
humanfriendly            10.0
idna                     3.4
ipykernel                6.25.2
ipython                  8.16.1
ipywidgets               8.1.1
jedi                     0.19.1
Jinja2                   3.1.2
jupyter_client           8.3.1
jupyter_core             5.3.2
jupyterlab-widgets       3.0.9
kiwisolver               1.4.5
lorax-client             0.1.0
markdown-it-py           3.0.0
MarkupSafe               2.1.3
matplotlib               3.8.2
matplotlib-inline        0.1.6
mdurl                    0.1.2
mpmath                   1.3.0
multidict                6.0.4
multiprocess             0.70.15
nest-asyncio             1.5.8
networkx                 3.1
numpy                    1.26.0
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu12         2.18.1
nvidia-nvjitlink-cu12    12.2.140
nvidia-nvtx-cu12         12.1.105
openai                   0.28.1
optimum                  1.16.1
packaging                23.2
pandas                   2.1.1
parso                    0.8.3
**peft                     0.7.1**
pexpect                  4.8.0
pickleshare              0.7.5
Pillow                   10.0.1
pip                      23.2.1
platformdirs             3.11.0
prompt-toolkit           3.0.39
protobuf                 4.24.4
psutil                   5.9.5
ptyprocess               0.7.0
pure-eval                0.2.2
pyarrow                  13.0.0
pydantic                 1.10.13
Pygments                 2.16.1
pyparsing                3.1.1
python-dateutil          2.8.2
pytz                     2023.3.post1
PyYAML                   6.0.1
pyzmq                    25.1.1
regex                    2023.10.3
requests                 2.31.0
rich                     13.7.0
rouge                    1.0.1
safetensors              0.4.0
scipy                    1.11.4
sentencepiece            0.1.99
sentry-sdk               1.39.2
setproctitle             1.3.3
setuptools               68.0.0
shtab                    1.6.5
six                      1.16.0
smmap                    5.0.1
stack-data               0.6.3
sympy                    1.12
**tokenizers               0.15.0**
**torch                    2.1.0**
**torchvision              0.16.0**
tornado                  6.3.3
tqdm                     4.66.1
traitlets                5.11.2
**transformers             4.36.2**
triton                   2.1.0
**trl                      0.7.9**
typing_extensions        4.8.0
tyro                     0.6.3
tzdata                   2023.3
urllib3                  2.0.6
wandb                    0.16.2
wcwidth                  0.2.8
wheel                    0.41.2
widgetsnbextension       4.0.9
xxhash                   3.4.1
yarl                     1.9.2

Who can help?

@pacman100 @younesbelkada

Reproduction

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
from peft import prepare_model_for_kbit_training
import torch

model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"
quantization_config_loading = GPTQConfig(bits=4, disable_exllama=True)
model = AutoModelForCausalLM.from_pretrained(model_id,quantization_config=quantization_config_loading, device_map="auto")

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

from peft import LoraConfig, get_peft_model
config = LoraConfig(
    r=8,
    lora_alpha=32,
    # target_modules=["k_proj","o_proj","q_proj","v_proj"],
    target_modules=[
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "down_proj",
        "up_proj",
        "gate_proj",
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Create the PEFT model with a custom adapter name instead of the default "default".
model = get_peft_model(model=model, peft_config=config, adapter_name="trainable")
model.print_trainable_parameters()

Output:
trainable params: 0 || all params: 283,381,760 || trainable%: 0.0
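
Since the all-params count is identical in both runs below, the LoRA weights for the custom adapter name appear to be injected but left frozen. A quick check right after the get_peft_model call above (plain PyTorch, nothing PEFT-specific) makes that visible:

# Diagnostic sketch: list every injected LoRA parameter and whether it is trainable.
# In peft 0.7.x the adapter name is part of the parameter name,
# e.g. "...self_attn.q_proj.lora_A.trainable.weight".
for name, param in model.named_parameters():
    if "lora_" in name:
        print(f"{name}  requires_grad={param.requires_grad}")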

When model = get_peft_model(model=model, peft_config=config, adapter_name="trainable") is changed to model = get_peft_model(model=model, peft_config=config) (i.e. the default adapter name), I get the following output:

Output:
trainable params: 20,971,520 || all params: 283,381,760 || trainable%: 7.400448074004481
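
Until a fix lands, one possible stopgap (a sketch only, it does not address the root cause) is to re-enable gradients for the custom adapter's LoRA parameters by hand:

# Stopgap sketch, assuming the LoRA weights are injected but frozen (see the check above):
# manually turn gradients back on for the custom adapter's parameters.
adapter_name = "trainable"  # the custom name passed to get_peft_model
for name, param in model.named_parameters():
    if "lora_" in name and f".{adapter_name}." in name:
        param.requires_grad = True

model.print_trainable_parameters()  # expected to report ~20.9M trainable params again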

Expected behavior

I would like custom adapter names to behave identically to the default name. This matters when using Option 3 with PEFT, as listed in the DPO docs (https://huggingface.co/docs/trl/main/en/dpo_trainer). Many model authors merge their adapter weights into the base model after SFT, which forces me to create a new adapter to use this option.
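
For context, this is roughly the Option 3 shape the custom adapter name is needed for. The model_adapter_name / ref_adapter_name arguments are the ones described in the linked trl docs; the second adapter, the toy dataset, and the training arguments below are placeholders rather than my actual setup:

# Rough sketch of DPO "Option 3" with named adapters, assuming the author's SFT adapter
# was already merged into the base weights, so fresh LoRA adapters are created locally.
from datasets import Dataset
from transformers import AutoTokenizer, TrainingArguments
from trl import DPOTrainer

# `model` is the PEFT model with the custom "trainable" adapter from the snippet above;
# a second, untrained adapter serves as the frozen reference.
model.add_adapter("reference", config)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Toy preference dataset with the prompt/chosen/rejected columns DPOTrainer expects.
dpo_dataset = Dataset.from_dict({
    "prompt": ["What is PEFT?"],
    "chosen": [" Parameter-Efficient Fine-Tuning."],
    "rejected": [" No idea."],
})

dpo_trainer = DPOTrainer(
    model,
    ref_model=None,                   # no separate reference model is loaded
    args=TrainingArguments(
        output_dir="dpo-out",
        per_device_train_batch_size=1,
        remove_unused_columns=False,  # keep the preference columns for the DPO collator
    ),
    beta=0.1,
    train_dataset=dpo_dataset,
    tokenizer=tokenizer,
    model_adapter_name="trainable",   # adapter that gets trained
    ref_adapter_name="reference",     # adapter used when computing reference log-probs
)
dpo_trainer.train()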

BenjaminBossan commented 8 months ago

Thanks a lot for reporting this issue, a PR to fix this is under way (#1347).