bigcode-project / starcoder

Home of StarCoder: fine-tuning & inference!
Apache License 2.0
7.31k stars 521 forks source link

ValueError: Cannot merge LORA layers when the model is loaded in 8-bit mode #118

Closed mathav95raj closed 1 year ago

mathav95raj commented 1 year ago

After saving the fine-tuned model with the finetune.py script, running the inference script raises the above error on the line

model=model.merge_and_unload()

Inference script:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch

import os
import argparse
from chat.utils import print_time

def get_args(argv=None):
    """Parse the command-line options for the merge-and-generate script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), ``sys.argv[1:]`` is used, exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model_path", type=str, default="bigcode/starcoder")
    parser.add_argument("--peft_model_path", type=str, default="finetune/checkpoints/final_checkpoint")
    parser.add_argument("--prompt_path", type=str, default="prompt.txt")
    parser.add_argument("--seq_len", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=0.9)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    # A ``store_true`` flag with ``default=True`` can never become False.
    # BooleanOptionalAction keeps ``--savemerged`` working and adds
    # ``--no-savemerged`` so saving the merged checkpoint can be skipped.
    parser.add_argument("--savemerged", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--chat", action="store_true", default=False)
    parser.add_argument("--greedy", action="store_true", default=False)

    return parser.parse_args(argv)

def _build_generate_kwargs(args):
    """Return the decoding options for ``model.generate`` chosen by the CLI flags.

    With ``--greedy`` sampling is disabled. Otherwise the temperature,
    top-p and repetition penalty from the command line are applied.
    Length limits are left to the caller because the chat and plain-prompt
    paths use different ones.
    """
    if args.greedy:
        # No ``temperature=0``: sampling knobs have no meaning when
        # ``do_sample=False``.
        return {"do_sample": False}
    return {
        "do_sample": True,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "repetition_penalty": args.repetition_penalty,
    }


def main():
    """Merge a LoRA adapter into its base model, optionally save it, and generate.

    Steps:
      1. Load the base model in fp16.
      2. Apply the PEFT adapter at ``--peft_model_path`` and merge it.
      3. If ``--savemerged`` is set, write the merged model and tokenizer to
         ``<peft_model_path>-merged``.
      4. Generate a completion for the prompt in ``--prompt_path``.

    The sampling/greedy flags apply to both the chat and plain-prompt paths.
    """
    args = get_args()

    # merge_and_unload() raises "Cannot merge LORA layers when the model is
    # loaded in 8-bit mode", so the base weights are loaded unquantized (fp16).
    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_model_path,
        return_dict=True,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True,
    )

    print('Loading Peft Model')
    model = PeftModel.from_pretrained(base_model, args.peft_model_path)
    print('Merging Peft Model')
    model = model.merge_and_unload()

    tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)

    if args.savemerged:
        # normpath drops a trailing separator, so "ckpt/" becomes
        # "ckpt-merged" rather than a "-merged" folder inside the checkpoint.
        merged_dir = f"{os.path.normpath(args.peft_model_path)}-merged"
        model.save_pretrained(merged_dir)
        tokenizer.save_pretrained(merged_dir)
        print(f"Model saved to {merged_dir}")

    # NOTE(review): latin-1 decodes any byte sequence but garbles non-ASCII
    # UTF-8 prompts; confirm the expected prompt file encoding.
    with open(args.prompt_path, encoding='latin-1') as f:
        ip = f.read()

    # Tokenize once and move the tensors to the model's device instead of a
    # hard-coded "cuda" (the model is dispatched with device_map="auto").
    inputs = tokenizer(
        ip,
        return_tensors="pt",
        return_token_type_ids=False,
    ).to(model.device)

    generate_kwargs = _build_generate_kwargs(args)
    # Fixed seed for reproducible sampling. This replaces the former
    # ``seed=42`` entry, which generate() does not accept as an option.
    torch.manual_seed(42)

    print('Inferencing..')
    if args.chat:
        from chat.dialogues import get_dialogue_template
        dialogue_template = get_dialogue_template("no_system")
        with print_time("Inference after Fine tuning "):
            outputs = model.generate(
                **inputs,
                **generate_kwargs,
                max_new_tokens=1024,
                pad_token_id=tokenizer.eos_token_id,
                # Stop at the dialogue template's end-of-turn token.
                eos_token_id=tokenizer.convert_tokens_to_ids(dialogue_template.end_token),
            )
        print(f"=== SAMPLE OUTPUT ==\n\n{tokenizer.decode(outputs[0], skip_special_tokens=False)}")
    else:
        with print_time("Inference after Fine tuning "):
            outputs = model.generate(
                **inputs,
                **generate_kwargs,
                max_length=args.seq_len,
                pad_token_id=tokenizer.eos_token_id,
            )
        print(tokenizer.decode(outputs[0]))

if __name__ == "__main__":
    # Entry point: run the merge-and-generate pipeline only when executed as a script.
    main()

Command for inference:

python inf.py --base_model_path bigcode/starcoder \
    --peft_model_path finetune/checkpoints/final_checkpoint \
    --prompt_path prompt.txt \
    --seq_len 512

The peft_model_path directory contains two files: adapter_model.bin and adapter_config.json.

In finetune.py, I load the model as follows:

       model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            use_auth_token=True,
            use_cache=not args.no_gradient_checkpointing,
            load_in_8bit=True,
            # torch_dtype=torch.float16,
            device_map={"": Accelerator().process_index},
        )
    model = prepare_model_for_kbit_training(model)

    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules = ["c_proj", "c_attn", "q_attn"]
    )
    model = get_peft_model(model, lora_config)
ArmelRandy commented 1 year ago

Hi. It is difficult to see what is happening without seeing the trace and the contents of your checkpoint folder. In any case, if your checkpoint was obtained using finetune.py, you should be able to run the merge_peft_adapters script to have your PEFT model converted and saved locally or on the Hub. You will then be able to load it with AutoModelForCausalLM and run inference.

mathav95raj commented 1 year ago

Thanks for the quick reply, @ArmelRandy. Please find the error trace below.


===================================BUG REPORT===================================
Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
================================================================================
bin /opt/conda/envs/ctp/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda113.so
/opt/conda/envs/ctp/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:149: UserWarning: /opt/conda/envs/ctp did not contain ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] as expected! Searching further paths...
  warn(msg)
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so.11.0
CUDA SETUP: Highest compute capability among GPUs detected: 8.0
CUDA SETUP: Detected CUDA version 113
CUDA SETUP: Loading binary /opt/conda/envs/ctp/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...
Loading checkpoint shards: 100%|██████████████████| 7/7 [07:31<00:00, 64.52s/it]
Loading Peft Model
Merging Peft Model
Traceback (most recent call last):
  File "/home/research/ise_stc_tuner/starcoder/inf.py", line 119, in <module>
    main()
  File "/home/research/ise_stc_tuner/starcoder/inf.py", line 63, in main
    model = model.merge_and_unload()
  File "/opt/conda/envs/ctp/lib/python3.10/site-packages/peft/tuners/lora.py", line 435, in merge_and_unload
    raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode")
ValueError: Cannot merge LORA layers when the model is loaded in 8-bit mode

The checkpoints/final_checkpoint folder contains only two files: adapter_model.bin and adapter_config.json.

Even without merging, is it feasible to keep running inference directly with the PEFT model?

ArmelRandy commented 1 year ago

I tried your inference code inf.py. I ran the command you provided with one of the local checkpoints I have at my disposal, and I was not able to reproduce the error you got at model.merge_and_unload(): the model was merged and saved on my local computer. My versions are bitsandbytes==0.39.0, accelerate==0.20.3, transformers==4.31.0, torch==1.13.0, peft==0.4.0. As an alternative to this method, you can try loading the model with AutoPeftModelForCausalLM.

mathav95raj commented 1 year ago

> I tried your inference code inf.py. I ran the command you provided with one of the local checkpoints I have at my disposal, and I was not able to reproduce the error you got at model.merge_and_unload(): the model was merged and saved on my local computer. My versions are bitsandbytes==0.39.0, accelerate==0.20.3, transformers==4.31.0, torch==1.13.0, peft==0.4.0. As an alternative to this method, you can try loading the model with AutoPeftModelForCausalLM.

I was able to merge with the package versions recommended here. Thank you! Closing this issue. For reference, below is the pip freeze output from the environment that was giving me the error.

absl-py==1.4.0
accelerate @ git+https://github.com/huggingface/accelerate.git@7954a28a71d484c4182a6b1074c1b8cc51642fc9
aiohttp==3.8.4
aiosignal==1.3.1
anyio==3.7.1
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
async-timeout==4.0.2
attrs==23.1.0
backcall==0.2.0
beautifulsoup4==4.12.2
bitsandbytes==0.40.0.post4
bleach==6.0.0
cachetools==5.3.1
certifi==2023.5.7
cffi==1.15.1
charset-normalizer==3.2.0
comm==0.1.3
datasets==2.13.1
debugpy==1.6.7
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.6
exceptiongroup==1.1.2
executing==1.2.0
fastjsonschema==2.17.1
filelock==3.12.2
fqdn==1.5.1
frozenlist==1.3.3
fsspec==2023.6.0
google-auth==2.22.0
google-auth-oauthlib==1.0.0
grpcio==1.56.0
huggingface-hub @ git+https://github.com/huggingface/huggingface_hub@4f1e4bbb28b58ffbc2588d2c549c6be0d083bac8
idna==3.4
ipykernel==6.24.0
ipython==8.14.0
ipython-genutils==0.2.0
ipywidgets==8.0.7
isoduration==20.11.0
jedi==0.18.2
Jinja2==3.1.2
jsonpointer==2.4
jsonschema==4.18.0
jsonschema-specifications==2023.6.1
jupyter==1.0.0
jupyter-console==6.6.3
jupyter-events==0.6.3
jupyter_client==8.3.0
jupyter_core==5.3.1
jupyter_server==2.7.0
jupyter_server_terminals==0.4.4
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.8
Markdown==3.4.3
MarkupSafe==2.1.3
matplotlib-inline==0.1.6
mistune==3.0.1
multidict==6.0.4
multiprocess==0.70.14
nbclassic==1.0.0
nbclient==0.8.0
nbconvert==7.6.0
nbformat==5.9.1
nest-asyncio==1.5.6
notebook==6.5.4
notebook_shim==0.2.3
numpy==1.25.1
oauthlib==3.2.2
overrides==7.3.1
packaging==23.1
pandas==2.0.3
pandocfilters==1.5.0
parso==0.8.3
peft @ git+https://github.com/huggingface/peft.git@c46d76ae3a323fa208bd777b9eb60a383e094d49
pexpect==4.8.0
pickleshare==0.7.5
platformdirs==3.8.1
prometheus-client==0.17.1
prompt-toolkit==3.0.39
protobuf==4.23.4
psutil==5.9.5
ptyprocess==0.7.0
pure-eval==0.2.2
pyarrow==12.0.1
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycparser==2.21
Pygments==2.15.1
python-dateutil==2.8.2
python-json-logger==2.0.7
pytz==2023.3
PyYAML==6.0
pyzmq==25.1.0
qtconsole==5.4.3
QtPy==2.3.1
referencing==0.29.1
regex==2023.6.3
requests==2.31.0
requests-oauthlib==1.3.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.8.10
rsa==4.9
runipy==0.1.5
safetensors==0.3.1
scipy==1.11.1
Send2Trash==1.8.2
six==1.16.0
sniffio==1.3.0
soupsieve==2.4.1
stack-data==0.6.2
tensorboard==2.13.0
tensorboard-data-server==0.7.1
terminado==0.17.1
tinycss2==1.2.1
tokenizers==0.13.3
torch==1.13.1+cu116
tornado==6.3.2
tqdm==4.65.0
traitlets==5.9.0
typing_extensions==4.7.1
tzdata==2023.3
uri-template==1.3.0
urllib3==1.26.16
wcwidth==0.2.6
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.1
Werkzeug==2.3.6
widgetsnbextension==4.0.8
xxhash==3.2.0
yarl==1.9.2