from transformers import AutoTokenizer, AutoModel
import torch
from peft import get_peft_model, LoraConfig, TaskType

# Load the local ChatGLM-6B snapshot in half precision on the GPU.
base_model = '/home/hexinyu/.cache/huggingface/hub/models--THUDM--chatglm-6b/snapshots/thu_chatglm'
model = AutoModel.from_pretrained(base_model, trust_remote_code=True).half().cuda()

# Wrap the base model with a LoRA adapter on the attention projection.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=['query_key_value'],
)
model = get_peft_model(model, peft_config)

# Load the fine-tuned LoRA weights; strict=False skips non-matching keys.
peft_path = "action.pt"
model.load_state_dict(torch.load(peft_path), strict=False)
model = model.cuda()
model.eval()

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

# "Why do iced black tea and lemon tea taste the same?"
text = "为什么冰红茶和柠檬茶的味道一样?"

with torch.autocast("cuda"):
    res, history = model.chat(tokenizer=tokenizer, query=text, max_length=300)
    print("--------------------------------------------------------------")
    print(res)
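Because the checkpoint is loaded with strict=False, any keys in action.pt that do not line up with the PEFT-wrapped module names are silently skipped, which can leave the LoRA layers at their random initialization. A minimal check, reusing the model and peft_path from the script above (the variable names below are only illustrative):

# load_state_dict returns a NamedTuple of missing_keys / unexpected_keys.
# With strict=False these are not raised as errors, so inspect them by hand.
state_dict = torch.load(peft_path, map_location="cpu")
load_result = model.load_state_dict(state_dict, strict=False)

# LoRA parameters that exist in the wrapped model but are absent from the
# checkpoint would mean the adapter weights were never actually applied.
lora_missing = [k for k in load_result.missing_keys if "lora_" in k]
print("unexpected keys:", len(load_result.unexpected_keys))
print("missing LoRA keys:", len(lora_missing))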
Running the script produces these warnings:

/home/hexinyu/miniconda3/envs/nlp/lib/python3.10/site-packages/peft/tuners/lora.py:173: UserWarning: fan_in_fan_out is set to True but the target module is not a Conv1D. Setting fan_in_fan_out to False.
  warnings.warn(
Explicitly passing a revision is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.
The dtype of attention mask (torch.int64) is not bool
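The first warning is raised by PEFT itself and, as the message says, it resets fan_in_fan_out to False for the non-Conv1D query_key_value layer. The revision notice comes from trust_remote_code=True: transformers suggests pinning the remote modeling code to a known revision, which only applies when loading by hub id rather than from the local snapshot path used above. A sketch under that assumption (the revision string is a placeholder, not a real ChatGLM-6B commit):

# Loading by hub id so that `revision` pins the remote modeling code to a
# specific commit; "REPLACE_WITH_TRUSTED_COMMIT" is a placeholder, not an
# actual ChatGLM-6B revision.
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b",
    trust_remote_code=True,
    revision="REPLACE_WITH_TRUSTED_COMMIT",
).half().cuda()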