yuanzhoulvpi2017 / zero_nlp

Chinese NLP solutions (large language models, data, models, training, inference)
MIT License

Model no longer generates answers after merging LoRA weights #84

Closed · heccxixi closed this issue 1 year ago

heccxixi commented 1 year ago

```python
from transformers import AutoTokenizer, AutoModel
import torch
from peft import get_peft_model, LoraConfig, TaskType

# Load the base ChatGLM-6B model in fp16 on the GPU.
base_model = '/home/hexinyu/.cache/huggingface/hub/models--THUDM--chatglm-6b/snapshots/thu_chatglm'
model = AutoModel.from_pretrained(base_model, trust_remote_code=True).half().cuda()

# Wrap the base model with a fresh LoRA adapter on the fused attention projection.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=['query_key_value'],
)
model = get_peft_model(model, peft_config)

# Load the trained LoRA weights from a raw state dict.
peft_path = "action.pt"
model.load_state_dict(torch.load(peft_path), strict=False)
model = model.cuda()
model.eval()

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

text = "为什么冰红茶和柠檬茶的味道一样?"  # "Why do iced black tea and lemon tea taste the same?"

with torch.autocast("cuda"):
    res, history = model.chat(tokenizer=tokenizer, query=text, max_length=300)
    print("--------------------------------------------------------------")
    print(res)
```
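Because `load_state_dict` is called with `strict=False`, any key-name mismatch between `action.pt` and the `PeftModel` is silently ignored, which would leave the adapter at its random initialization and could explain the empty answers. A minimal diagnostic sketch, assuming `action.pt` holds a plain state dict saved with `torch.save` (the `lora_` substring is how peft names its injected tensors):

```python
# load_state_dict returns an _IncompatibleKeys named tuple; with
# strict=False, mismatches show up here instead of raising an error.
result = model.load_state_dict(torch.load(peft_path), strict=False)

# Keys present in action.pt but absent from the model mean the saved
# names do not line up (e.g. a missing "base_model.model." prefix).
print("unexpected keys:", result.unexpected_keys)

# LoRA tensors the wrapped model expected but never received:
print("missing LoRA keys:",
      [k for k in result.missing_keys if "lora_" in k])
```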

Running this produces no answer, only the following warnings:

```
/home/hexinyu/miniconda3/envs/nlp/lib/python3.10/site-packages/peft/tuners/lora.py:173: UserWarning: fan_in_fan_out is set to True but the target module is not a Conv1D. Setting fan_in_fan_out to False.
  warnings.warn(
Explicitly passing a revision is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
The dtype of attention mask (torch.int64) is not bool
```
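For comparison, if the adapter was saved with `save_pretrained` rather than `torch.save`, peft can reconstruct the config and weights itself, avoiding manual key matching entirely. A minimal sketch, assuming a hypothetical adapter directory `lora_adapter` containing `adapter_config.json` and `adapter_model.bin`:

```python
import torch
from transformers import AutoModel, AutoTokenizer
from peft import PeftModel

base_model = '/home/hexinyu/.cache/huggingface/hub/models--THUDM--chatglm-6b/snapshots/thu_chatglm'
model = AutoModel.from_pretrained(base_model, trust_remote_code=True).half().cuda()
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

# "lora_adapter" is a hypothetical directory written by
# model.save_pretrained(...) after training; PeftModel.from_pretrained
# rebuilds the LoraConfig and loads the adapter weights in one step.
model = PeftModel.from_pretrained(model, "lora_adapter")
model.eval()

with torch.autocast("cuda"):
    res, _ = model.chat(tokenizer=tokenizer, query="为什么冰红茶和柠檬茶的味道一样?", max_length=300)
print(res)
```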