zjunlp / IEPile

[ACL 2024] IEPile: A Large-Scale Information Extraction Corpus
http://oneke.openkg.cn/

Error when running llama-2-13b-chat-hf + llama2-13b-iepile-lora in 4-bit #11

Closed LeonNerd closed 4 months ago

LeonNerd commented 4 months ago

```python
import torch
from transformers import BitsAndBytesConfig
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '2'

from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig
)
from peft import PeftModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = '/IEPile/models/pretrain/llama-2-13b-chat-hf'
lora_path = '/IEPile/models/pretrain/llama2-13b-iepile-lora'
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Load the base model in 4-bit NF4 with double quantization.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    config=config,
    device_map="auto",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# Attach the IEPile LoRA adapter.
model = PeftModel.from_pretrained(
    model,
    lora_path,
)

model.to(device)  # not needed with device_map="auto"; kept as originally posted

model.eval()

sintruct = "{\"instruction\": \"You are an expert in named entity recognition. Please extract entities that match the schema definition from the input. Return an empty list if the entity type does not exist. Please respond in the format of a JSON string.\", \"schema\": [\"person\", \"organization\", \"else\", \"location\"], \"input\": \"284 Robert Allenby ( Australia ) 69 71 71 73 , Miguel Angel Martin ( Spain ) 75 70 71 68 ( Allenby won at first play-off hole )\"}"
# As the later comments note, this wrapper was the Baichuan template pasted by mistake;
# Llama 2 expects the [INST] format.
sintruct = '' + sintruct + ''

input_ids = tokenizer.encode(sintruct, return_tensors="pt").to(device)
input_length = input_ids.size(1)
print(input_ids)
print(input_length)
generation_output = model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_length=512, max_new_tokens=256, return_dict_in_generate=True))
generation_output = generation_output.sequences[0]
generation_output = generation_output[input_length:]  # keep only the newly generated tokens
output = tokenizer.decode(generation_output, skip_special_tokens=True)

print(output)
```

Error:

```
Traceback (most recent call last):
  File "/home/admin/wangsj/workspace/IEPile/script/infer_4bit.py", line 50, in <module>
    generation_output = model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_length=512, max_new_tokens=256, return_dict_in_generate=True))
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/peft/peft_model.py", line 977, in generate
    outputs = self.base_model.generate(**kwargs)
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/transformers/generation/utils.py", line 1602, in generate
    return self.greedy_search(
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/transformers/generation/utils.py", line 2450, in greedy_search
    outputs = self(
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 820, in forward
    outputs = self.model(
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 708, in forward
    layer_outputs = decoder_layer(
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 424, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 311, in forward
    query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
  File "/root/anaconda3/envs/IEPile/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 311, in <listcomp>
    query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
RuntimeError: mat1 and mat2 shapes cannot be multiplied (137x5120 and 1x2560)
```
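A little arithmetic shows where the two shapes in the last line come from. This is a minimal sketch that does not appear in the thread, and it rests on assumptions. It takes Llama-2-13B's hidden size as 5120 (40 heads of dimension 128) and assumes the checkpoint's config sets `pretraining_tp` to 2. It also assumes bitsandbytes stores each 4-bit weight packed two values per byte, as a column tensor of shape `(numel // 2, 1)`. When `pretraining_tp > 1`, `modeling_llama.py` splits `q_proj.weight` along dim 0 and calls `F.linear` on the raw slices (the failing line above). That path multiplies against the packed storage instead of going through the quantized layer. The helper names below are hypothetical.

```python
# Reproduce "mat1 and mat2 shapes cannot be multiplied (137x5120 and 1x2560)".
# Assumptions: hidden size 5120, pretraining_tp = 2, and a bitsandbytes 4-bit
# weight packed as a (numel // 2, 1) uint8 column tensor.

def packed_4bit_shape(out_features, in_features):
    """Assumed storage shape of a 4-bit weight: two values per byte, one column."""
    return (out_features * in_features // 2, 1)

def first_query_slice_shape(weight_shape, num_heads, head_dim, pretraining_tp):
    """Shape of the first chunk from weight.split(num_heads * head_dim // tp, dim=0)."""
    split_size = num_heads * head_dim // pretraining_tp
    return (min(split_size, weight_shape[0]), weight_shape[1])

def linear_operands(tokens, hidden, weight_shape):
    """F.linear(x, W) computes x @ W.T; return the (mat1, mat2) shapes multiplied.
    The batch dim of hidden_states (1, tokens, hidden) is flattened in the error."""
    return (tokens, hidden), (weight_shape[1], weight_shape[0])

hidden = 5120                 # Llama-2-13B hidden size
num_heads, head_dim = 40, 128
tp = 2                        # assumed original pretraining_tp
tokens = 137                  # prompt length from the error message

w = packed_4bit_shape(hidden, hidden)                   # (13107200, 1)
s = first_query_slice_shape(w, num_heads, head_dim, tp)  # (2560, 1)
mat1, mat2 = linear_operands(tokens, hidden, s)
print(mat1, mat2)             # (137, 5120) (1, 2560)
print(mat1[1] == mat2[0])     # False: inner dimensions disagree, hence the RuntimeError
```

With `pretraining_tp = 1`, the attention should instead call `self.q_proj(hidden_states)`, which lets the bitsandbytes layer handle the packed weight itself.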

```
Package                  Version
------------------------ -----------
accelerate               0.21.0
aiohttp                  3.9.5
aiosignal                1.3.1
async-timeout            4.0.3
attrs                    23.2.0
bitsandbytes             0.39.1
certifi                  2024.2.2
charset-normalizer       3.3.2
cmake                    3.29.2
datasets                 2.16.1
dill                     0.3.7
filelock                 3.14.0
frozenlist               1.4.1
fsspec                   2023.10.0
huggingface-hub          0.20.3
idna                     3.7
jieba                    0.42.1
Jinja2                   3.1.3
lit                      18.1.4
MarkupSafe               2.1.5
mpmath                   1.3.0
multidict                6.0.5
multiprocess             0.70.15
networkx                 3.2.1
numpy                    1.24.4
nvidia-cublas-cu11       11.10.3.66
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11        8.5.0.96
nvidia-cufft-cu11        10.9.0.58
nvidia-curand-cu11       10.2.10.91
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusparse-cu11     11.7.4.91
nvidia-nccl-cu11         2.14.3
nvidia-nvtx-cu11         11.7.91
packaging                24.0
pandas                   2.2.2
peft                     0.4.0
pip                      23.3.1
protobuf                 3.20.1
psutil                   5.9.8
pyarrow                  16.0.0
pyarrow-hotfix           0.6
pydantic                 1.10.7
python-dateutil          2.9.0.post0
pytz                     2024.1
PyYAML                   6.0.1
regex                    2024.4.28
requests                 2.31.0
rouge-chinese            1.0.3
safetensors              0.4.3
scipy                    1.9.1
sentencepiece            0.1.98
setuptools               68.2.2
six                      1.16.0
sympy                    1.12
tiktoken                 0.6.0
tokenizers               0.13.3
torch                    2.0.0
tqdm                     4.66.2
transformers             4.33.0
triton                   2.0.0
typing_extensions        4.11.0
tzdata                   2024.1
urllib3                  2.2.1
wheel                    0.41.2
xxhash                   3.4.1
yarl                     1.9.4
```

guihonghao commented 4 months ago

Hi, the prompt format for llama2-iepile is:

```python
system_prompt = "<<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
sintruct = "{\"instruction\": \"You are an expert in named entity recognition. Please extract entities that match the schema definition from the input. Return an empty list if the entity type does not exist. Please respond in the format of a JSON string.\", \"schema\": [\"person\", \"organization\", \"else\", \"location\"], \"input\": \"284 Robert Allenby ( Australia ) 69 71 71 73 , Miguel Angel Martin ( Spain ) 75 70 71 68 ( Allenby won at first play-off hole )\"}"
sintruct = '[INST] ' + system_prompt + sintruct + ' [/INST]'
```
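The same prompt can also be assembled with `json.dumps` instead of a hand-escaped string. This is a minimal sketch: the helper `build_llama2_prompt` is hypothetical and not part of IEPile. It assumes the query keys keep the order `instruction`, `schema`, `input` used above. Python dicts preserve insertion order, and `json.dumps`'s default separators (`", "` and `": "`) give the same layout as the hand-written string.

```python
import json

# Llama 2 chat system block, copied verbatim from the prompt format above.
SYSTEM_PROMPT = (
    "<<SYS>>\nYou are a helpful, respectful and honest assistant. "
    "Always answer as helpfully as possible, while being safe. "
    "Your answers should not include any harmful, unethical, racist, sexist, toxic, "
    "dangerous, or illegal content. Please ensure that your responses are socially "
    "unbiased and positive in nature.\n\nIf a question does not make any sense, or is "
    "not factually coherent, explain why instead of answering something not correct. "
    "If you don't know the answer to a question, please don't share false information."
    "\n<</SYS>>\n\n"
)

def build_llama2_prompt(instruction, schema, text, system_prompt=SYSTEM_PROMPT):
    """Wrap an IEPile-style JSON query in the Llama 2 [INST] template (hypothetical helper)."""
    # ensure_ascii=False keeps non-ASCII input (e.g. Chinese) readable instead of \uXXXX escapes.
    query = json.dumps(
        {"instruction": instruction, "schema": schema, "input": text},
        ensure_ascii=False,
    )
    return "[INST] " + system_prompt + query + " [/INST]"

prompt = build_llama2_prompt(
    "You are an expert in named entity recognition. Please extract entities that match "
    "the schema definition from the input. Return an empty list if the entity type does "
    "not exist. Please respond in the format of a JSON string.",
    ["person", "organization", "else", "location"],
    "284 Robert Allenby ( Australia ) 69 71 71 73 , Miguel Angel Martin ( Spain ) "
    "75 70 71 68 ( Allenby won at first play-off hole )",
)
print(prompt)
```

Generating the JSON this way avoids escaping mistakes when the schema or input text changes.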

Judging from the error `RuntimeError: mat1 and mat2 shapes cannot be multiplied (137x5120 and 1x2560)`, please check whether `pretraining_tp` in the `config.json` of llama-2-13b-chat-hf is 1. If it isn't, change it to 1.
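The `config.json` change can be scripted. This is a minimal sketch using only the standard library. The helper name `set_pretraining_tp` is hypothetical, and the helper rewrites the file in place with 2-space indentation, so keep a backup if the original formatting matters.

```python
import json
from pathlib import Path

def set_pretraining_tp(model_dir, value=1):
    """Set pretraining_tp in <model_dir>/config.json and return the previous value (hypothetical helper)."""
    cfg_path = Path(model_dir) / "config.json"
    cfg = json.loads(cfg_path.read_text(encoding="utf-8"))
    previous = cfg.get("pretraining_tp")  # None if the key is absent
    cfg["pretraining_tp"] = value
    cfg_path.write_text(json.dumps(cfg, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
    return previous

# Usage with the path from the issue:
# set_pretraining_tp("/IEPile/models/pretrain/llama-2-13b-chat-hf")
```

Since the original script passes `config=config` to `from_pretrained`, setting `config.pretraining_tp = 1` on the loaded `AutoConfig` object before loading the model should have the same effect without touching the file.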

LeonNerd commented 4 months ago


Yes, I wasn't careful when pasting just now and pasted the Baichuan format by mistake. Thanks a lot! Setting pretraining_tp=1 solved the problem.