huggingface / peft

🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning.
https://huggingface.co/docs/peft
Apache License 2.0

Can prefix tuning be used for multi-query model like bigcode/starcoder? #713

Closed ainilian closed 1 year ago

ainilian commented 1 year ago

System Info

GPU: 2*V100(64GB) CPU: 16vCPUs 128GB

envs:

packages in environment at PyTorch-1.10.2:
Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       2_gnu    conda-forge
absl-py                   1.3.0                    pypi_0    pypi
accelerate                0.20.3                   pypi_0    pypi
addict                    2.4.0                    pypi_0    pypi
aiohttp                   3.8.4                    pypi_0    pypi
aiosignal                 1.3.1                    pypi_0    pypi
async-timeout             4.0.2                    pypi_0    pypi
asynctest                 0.13.0                   pypi_0    pypi
backcall                  0.2.0                    pypi_0    pypi
bayesian-optimization     1.0.1                    pypi_0    pypi
boto3                     1.4.4                    pypi_0    pypi
botocore                  1.5.95                   pypi_0    pypi
ca-certificates           2022.9.24            ha878542_0    conda-forge
cachetools                4.2.4                    pypi_0    pypi
certifi                   2022.9.24                pypi_0    pypi
charset-normalizer        2.0.12                   pypi_0    pypi
click                     8.1.3                    pypi_0    pypi
cycler                    0.11.0                   pypi_0    pypi
cython                    0.27.3                   pypi_0    pypi
datasets                  2.13.1                   pypi_0    pypi
debugpy                   1.6.3                    pypi_0    pypi
decorator                 5.1.1                    pypi_0    pypi
deepspeed                 0.9.5                    pypi_0    pypi
dill                      0.3.6                    pypi_0    pypi
dnspython                 2.2.1                    pypi_0    pypi
docutils                  0.19                     pypi_0    pypi
easydict                  1.9                      pypi_0    pypi
entrypoints               0.4                      pypi_0    pypi
filelock                  3.0.12                   pypi_0    pypi
flask                     2.0.1                    pypi_0    pypi
fonttools                 4.38.0                   pypi_0    pypi
frozenlist                1.3.3                    pypi_0    pypi
fsspec                    2023.1.0                 pypi_0    pypi
future                    0.18.2.post20200723173923          pypi_0    pypi
google-auth               1.35.0                   pypi_0    pypi
google-auth-oauthlib      0.4.6                    pypi_0    pypi
grpcio                    1.50.0                   pypi_0    pypi
gunicorn                  20.1.0                   pypi_0    pypi
hjson                     3.1.0                    pypi_0    pypi
huggingface-hub           0.16.4                   pypi_0    pypi
hyperopt                  0.1.2                    pypi_0    pypi
idna                      3.4                      pypi_0    pypi
importlib-metadata        5.0.0                    pypi_0    pypi
ipdb                      0.13.13                  pypi_0    pypi
ipykernel                 6.7.0                    pypi_0    pypi
ipython                   7.34.0                   pypi_0    pypi
ipython-genutils          0.2.0                    pypi_0    pypi
itsdangerous              2.1.2                    pypi_0    pypi
jedi                      0.18.1                   pypi_0    pypi
jinja2                    3.0.1                    pypi_0    pypi
jmespath                  0.10.0                   pypi_0    pypi
joblib                    1.2.0                    pypi_0    pypi
jupyter-client            7.4.4                    pypi_0    pypi
jupyter-core              4.11.2                   pypi_0    pypi
kiwisolver                1.4.4                    pypi_0    pypi
lazy-import               0.2.2                    pypi_0    pypi
ld_impl_linux-64          2.39                 hc81fddc_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-ng                 12.2.0              h65d4601_19    conda-forge
libgomp                   12.2.0              h65d4601_19    conda-forge
libnsl                    2.0.0                h7f98852_0    conda-forge
libsqlite                 3.39.4               h753d276_0    conda-forge
libstdcxx-ng              12.2.0              h46fd767_19    conda-forge
libzlib                   1.2.13               h166bdaf_4    conda-forge
lxml                      4.9.1                    pypi_0    pypi
ma-tensorboard            1.0.0                    pypi_0    pypi
markdown                  3.4.1                    pypi_0    pypi
markupsafe                2.1.1                    pypi_0    pypi
matplotlib                3.5.1                    pypi_0    pypi
matplotlib-inline         0.1.6                    pypi_0    pypi
mmcv                      1.2.7                    pypi_0    pypi
modelarts-pytorch-model-server 1.0.4                    pypi_0    pypi
moxing-framework          2.1.0.5d9c87c8           pypi_0    pypi
multidict                 6.0.4                    pypi_0    pypi
multiprocess              0.70.14                  pypi_0    pypi
ncurses                   6.3                  h27087fc_1    conda-forge
nest-asyncio              1.5.6                    pypi_0    pypi
networkx                  2.6.3                    pypi_0    pypi
ninja                     1.11.1                   pypi_0    pypi
numpy                     1.19.5                   pypi_0    pypi
nvidia-cublas-cu11        11.10.3.66               pypi_0    pypi
nvidia-cuda-nvrtc-cu11    11.7.99                  pypi_0    pypi
nvidia-cuda-runtime-cu11  11.7.99                  pypi_0    pypi
nvidia-cudnn-cu11         8.5.0.96                 pypi_0    pypi
nvidia-ml-py3             7.352.0                  pypi_0    pypi
oauthlib                  3.2.2                    pypi_0    pypi
opencv-python             4.1.2.30                 pypi_0    pypi
openssl                   3.0.5                h166bdaf_2    conda-forge
packaging                 21.3                     pypi_0    pypi
pandas                    1.1.5                    pypi_0    pypi
parso                     0.8.3                    pypi_0    pypi
pathlib2                  2.3.7.post1              pypi_0    pypi
peft                      0.3.0                    pypi_0    pypi
pexpect                   4.8.0                    pypi_0    pypi
pickleshare               0.7.5                    pypi_0    pypi
pillow                    9.3.0                    pypi_0    pypi
pip                       23.1.2                   pypi_0    pypi
prettytable               0.7.2                    pypi_0    pypi
prompt-toolkit            3.0.31                   pypi_0    pypi
protobuf                  3.20.1                   pypi_0    pypi
psutil                    5.8.0                    pypi_0    pypi
ptyprocess                0.7.0                    pypi_0    pypi
py-cpuinfo                9.0.0                    pypi_0    pypi
pyarrow                   12.0.1                   pypi_0    pypi
pyasn1                    0.4.8                    pypi_0    pypi
pyasn1-modules            0.2.8                    pypi_0    pypi
pydantic                  1.10.11                  pypi_0    pypi
pygments                  2.13.0                   pypi_0    pypi
pymongo                   4.3.2                    pypi_0    pypi
pyparsing                 3.0.9                    pypi_0    pypi
pytest-runner             5.3.0                    pypi_0    pypi
python                    3.7.10          hf930737_104_cpython    conda-forge
python-dateutil           2.8.2                    pypi_0    pypi
pytz                      2022.6                   pypi_0    pypi
pyyaml                    5.1                      pypi_0    pypi
pyzmq                     24.0.1                   pypi_0    pypi
readline                  8.1.2                h0f457ee_0    conda-forge
regex                     2023.6.3                 pypi_0    pypi
requests                  2.27.1                   pypi_0    pypi
requests-oauthlib         1.3.1                    pypi_0    pypi
rsa                       4.9                      pypi_0    pypi
s3transfer                0.1.13                   pypi_0    pypi
safetensors               0.3.1                    pypi_0    pypi
scikit-learn              0.22.1                   pypi_0    pypi
scipy                     1.5.2                    pypi_0    pypi
setuptools                65.5.0             pyhd8ed1ab_0    conda-forge
six                       1.16.0                   pypi_0    pypi
sklearn                   0.0                      pypi_0    pypi
sortedcontainers          2.2.2                    pypi_0    pypi
sqlite                    3.39.4               h4ff8645_0    conda-forge
statistics                1.0.3.5                  pypi_0    pypi
tensorboard               2.1.1                    pypi_0    pypi
tensorboardx              2.0                      pypi_0    pypi
tk                        8.6.12               h27826a3_0    conda-forge
tokenizers                0.13.3                   pypi_0    pypi
tomli                     2.0.1                    pypi_0    pypi
torch                     1.13.1                   pypi_0    pypi
torchaudio                0.10.2                   pypi_0    pypi
torchvision               0.11.3                   pypi_0    pypi
tornado                   6.2                      pypi_0    pypi
tqdm                      4.64.1                   pypi_0    pypi
traitlets                 5.5.0                    pypi_0    pypi
transformers              4.30.2                   pypi_0    pypi
typing-extensions         4.4.0                    pypi_0    pypi
urllib3                   1.26.12                  pypi_0    pypi
watchdog                  2.0.0                    pypi_0    pypi
wcwidth                   0.2.5                    pypi_0    pypi
werkzeug                  2.2.2                    pypi_0    pypi
wheel                     0.37.1             pyhd8ed1ab_0    conda-forge
xxhash                    3.2.0                    pypi_0    pypi
xz                        5.2.6                h166bdaf_0    conda-forge
yapf                      0.32.0                   pypi_0    pypi
yarl                      1.9.2                    pypi_0    pypi
zipp                      3.10.0                   pypi_0    pypi

Who can help?

@pacman100 @younesbelkada @sayakpaul

Reproduction

Error
Traceback (most recent call last):
  File "train.py", line 192, in <module>
    train()
  File "train.py", line 183, in train
    trainer.train()
  File "/***/lib/python3.7/site-packages/transformers/trainer.py", line 1649, in train
    ignore_keys_for_eval=ignore_keys_for_eval,
  File "/***/lib/python3.7/site-packages/transformers/trainer.py", line 1945, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/***/lib/python3.7/site-packages/transformers/trainer.py", line 2772, in training_step
    loss = self.compute_loss(model, inputs)
  File "/***/lib/python3.7/site-packages/transformers/trainer.py", line 2805, in compute_loss
    outputs = model(**inputs)
  File "/***/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/***/lib/python3.7/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/***/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 1735, in forward
    loss = self.module(*inputs, **kwargs)
  File "/***/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/***/lib/python3.7/site-packages/peft/peft_model.py", line 731, in forward
    return self.base_model(input_ids=input_ids, past_key_values=past_key_values, **kwargs)
  File "/***/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/***/lib/python3.7/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py", line 830, in forward
    return_dict=return_dict,
  File "/***/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/***/lib/python3.7/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py", line 690, in forward
    output_attentions=output_attentions,
  File "/***/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/***/lib/python3.7/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py", line 326, in forward
    output_attentions=output_attentions,
  File "/***/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/***/lib/python3.7/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py", line 247, in forward
    key_value = torch.cat((layer_past, key_value), dim=-2)
RuntimeError: Tensors must have same number of dimensions: got 5 and 3
  0%|                                                                                                                                             | 0/1250 [00:01<?, ?it/s]
[2023-07-17 11:34:10,654] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 2646522
[2023-07-17 11:34:10,789] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 2646523
[2023-07-17 11:34:10,789] [ERROR] [launch.py:321:sigkill_handler] ['/***/bin/python3.7', '-u', 'train.py', '--local_rank=1', '--train_url', '/data/zxl_obf/train/starcoder/output/', '--data_url', '/data/zxl_obf/train/starcoder/input/', '--train_type', 'prefix_tuning', '--model_sub_dir', 'starcoderbase', '--text_column', 'input', '--label_column', 'output', '--dataset_name', 'debug_utpair_200w_v1_filter', '--num_train_epochs=2', '--model_max_length=640', '--per_device_train_batch_size=2', '--per_device_eval_batch_size=1', '--gradient_accumulation_steps=4', '--evaluation_strategy=no', '--save_strategy', 'steps', '--save_steps', '10', '--save_total_limit', '100', '--learning_rate', '2e-5', '--warmup_steps', '30', '--logging_steps', '2', '--lr_scheduler_type', 'cosine', '--gradient_checkpointing', 'False', '--deepspeed', 'deepspeed_config.json', '--fp16', 'True'] exits with return code = 1
/***/lib/python3.7/site-packages/requests/__init__.py:104: RequestsDependencyWarning: urllib3 (1.26.12) or chardet (5.1.0)/charset_normalizer (2.0.12) doesn't match a supported version!
  RequestsDependencyWarning)
Additional print details for the error

File "/***/lib/python3.7/site-packages/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py", line 247, in forward
    key_value = torch.cat((layer_past, key_value), dim=-2)
RuntimeError: Tensors must have same number of dimensions: got 5 and 3

I added some print statements around "key_value = torch.cat((layer_past, key_value), dim=-2)":

    if layer_past is not None:
        print(f'key_value.shape : {key_value.shape}')
        print(f'layer_past size:: {layer_past.shape}')
        key_value = torch.cat((layer_past, key_value), dim=-2)

The printed output:

key_value.shape : torch.Size([2, 640, 256])
layer_past size:: torch.Size([2, 2, 48, 30, 128])
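
One way to read the two printed shapes is sketched below. This is not PEFT or transformers code: check_cat is a hypothetical plain-Python helper that only reproduces a simplified form of the shape rule torch.cat enforces (same number of dimensions, equal sizes everywhere except the concatenation axis). The axis labels in the comments are an interpretation, assuming num_virtual_tokens=30, a per-device batch size of 2 and model_max_length=640 as given in this report.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]


def check_cat(shapes: Sequence[Shape], dim: int) -> Shape:
    """Return the shape produced by concatenating `shapes` along `dim`.

    Hypothetical helper: a simplified, shape-only version of the rule that
    torch.cat applies. Raises ValueError when the inputs are incompatible.
    """
    first = shapes[0]
    for other in shapes[1:]:
        if len(other) != len(first):
            raise ValueError(
                "Tensors must have same number of dimensions: "
                f"got {len(first)} and {len(other)}"
            )
    axis = dim % len(first)  # normalise a negative axis such as -2
    for other in shapes[1:]:
        for i, (a, b) in enumerate(zip(first, other)):
            if i != axis and a != b:
                raise ValueError(f"Sizes differ on axis {i}: {a} vs {b}")
    total = sum(shape[axis] for shape in shapes)
    return first[:axis] + (total,) + first[axis + 1:]


# Shapes copied from the printout above.
layer_past = (2, 2, 48, 30, 128)  # assumed: (key/value, batch, heads, virtual_tokens, head_dim)
key_value = (2, 640, 256)         # assumed: (batch, seq_len, 2 * head_dim)

try:
    check_cat([layer_past, key_value], dim=-2)
    mismatch = None
except ValueError as err:
    mismatch = str(err)
print(mismatch)

# A prefix that concatenates cleanly would need the same 3-D layout as key_value.
# This target shape is an assumption drawn from the printout, not a documented API.
batch, seq_len, kv_dim = key_value
num_virtual_tokens = layer_past[3]
compatible_prefix = (batch, num_virtual_tokens, kv_dim)
print(check_cat([compatible_prefix, key_value], dim=-2))  # (2, 670, 256)
```

Under that reading, the prefix arrives in a per-head 5-D layout, while the attention layer concatenates along the sequence axis of a 3-D tensor, so the call fails before any sizes are compared.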
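Code

# Imports assumed by this excerpt; args, data_args, training_args, logger and
# preprocess_function are defined elsewhere in train.py and are not shown.
import os

from datasets import load_dataset
from peft import PrefixTuningConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer

dataset_name = data_args.dataset_name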
dataset = load_dataset('text', 
                       data_files={'train': os.path.join(args.data_url, 'dataset', dataset_name+'_train.jsonl'),
                                   'valid': os.path.join(args.data_url, 'dataset', dataset_name+'_valid.jsonl')})

model_name_or_path = os.path.join(args.data_url, 'checkpoint', args.model_sub_dir)
tokenizer_name_or_path = model_name_or_path

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

processed_datasets = dataset.map(
    preprocess_function,
    batched=True,
    fn_kwargs={"tokenizer": tokenizer})

train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["valid"]

logger.info('start load model')
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

logger.info('load peft prefix tuning model')
peft_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM, 
    num_virtual_tokens=30)
model = get_peft_model(model, peft_config)

trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset)
trainer.train()

logger.info('start to save model')
trainer.save_state()
trainer.save_model(output_dir=training_args.output_dir)

Hugging Face Trainer arguments:
--num_train_epochs=2 \
--model_max_length=640 \
--per_device_train_batch_size=2 \
--per_device_eval_batch_size=1 \
--gradient_accumulation_steps=4 \
--evaluation_strategy="no" \
--save_strategy "steps" \
--save_steps 10 \
--save_total_limit 100 \
--learning_rate 2e-5 \
--warmup_steps 30 \
--logging_steps 2 \
--lr_scheduler_type "cosine" \
--gradient_checkpointing False \
--deepspeed deepspeed_config.json \
--fp16 True

Training uses DeepSpeed ZeRO-3 with offload.

Expected behavior

Can prefix tuning be used for multi-query models like bigcode/starcoder? If prefix tuning supports bigcode/starcoder, how should the prefix tuning config be set?

ainilian commented 1 year ago

Looking forward to your reply! @pacman100 @younesbelkada @sayakpaul

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

dsoselia commented 11 months ago

Was wondering if this has been addressed.