[Open] Issue opened by kbwzy 1 year ago
# Issue report: after training finishes, the output/ directory contains the
# generated adapter files. Inference code follows.
from transformers import AutoModel
import torch
from transformers import AutoTokenizer
from peft import PeftModel
import json
from cover_alpaca2jsonl import format_example

# FIX: with load_in_8bit=False the half-precision ChatGLM weights hit CPU
# kernels that require a uniform dtype, raising:
#   RuntimeError: mixed dtype (CPU): expect input to have scalar type of BFloat16
# Loading in 8-bit (the fix confirmed later in this thread; NOTE(review): it
# requires bitsandbytes and a CUDA GPU) avoids the mixed-dtype path.
# For pure-CPU inference, `model = model.float()` is the alternative.
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True, load_in_8bit=True
)
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
# Attach the trained LoRA adapter saved under ./output.
model = PeftModel.from_pretrained(model, "./output")

# Context manager so the file handle is closed deterministically
# (the original used a bare `json.load(open(...))`).
with open("data/alpaca_data.json", encoding="utf-8") as f:
    instructions = json.load(f)
answers = []

with torch.no_grad():
    # Only the first three instructions are run, as in the original report.
    for idx, item in enumerate(instructions[:3]):
        feature = format_example(item)
        input_text = feature['context']
        ids = tokenizer.encode(input_text)
        input_ids = torch.LongTensor([ids])
        # do_sample=False means greedy decoding; the original also passed
        # temperature=0, which is ignored in greedy mode and is dropped here.
        out = model.generate(
            input_ids=input_ids,
            max_length=150,
            do_sample=False,
        )
        out_text = tokenizer.decode(out[0])
        # Strip the echoed prompt and the trailing END marker.
        answer = out_text.replace(input_text, "").replace("\nEND", "").strip()
        item['infer_answer'] = answer
        print(out_text)
        print(f"### {idx+1}.Answer:\n", item.get('output'), '\n\n')
        answers.append({'index': idx, **item})
# Error reported with the original load_in_8bit=False:
#   RuntimeError: mixed dtype (CPU): expect input to have scalar type of BFloat16
Changing `load_in_8bit=False` to `load_in_8bit=True` — this works for me.
训练完成后在output目录有如下生成内容
推理代码如下:
from transformers import AutoModel
import torch
from transformers import AutoTokenizer
from peft import PeftModel
import json
from cover_alpaca2jsonl import format_example
# FIX: with load_in_8bit=False the half-precision ChatGLM weights hit CPU
# kernels that require a uniform dtype, raising:
#   RuntimeError: mixed dtype (CPU): expect input to have scalar type of BFloat16
# Loading in 8-bit (the fix confirmed in this thread; NOTE(review): it requires
# bitsandbytes and a CUDA GPU) avoids the mixed-dtype path. For pure-CPU
# inference, `model = model.float()` is the alternative.
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True, load_in_8bit=True
)
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
# Attach the trained LoRA adapter saved under ./output.
model = PeftModel.from_pretrained(model, "./output")

# Context manager so the file handle is closed deterministically
# (the original used a bare `json.load(open(...))`).
with open("data/alpaca_data.json", encoding="utf-8") as f:
    instructions = json.load(f)
answers = []

with torch.no_grad():
    # Only the first three instructions are run, as in the original report.
    for idx, item in enumerate(instructions[:3]):
        feature = format_example(item)
        input_text = feature['context']
        ids = tokenizer.encode(input_text)
        input_ids = torch.LongTensor([ids])
        # do_sample=False means greedy decoding; the original also passed
        # temperature=0, which is ignored in greedy mode and is dropped here.
        out = model.generate(
            input_ids=input_ids,
            max_length=150,
            do_sample=False,
        )
        out_text = tokenizer.decode(out[0])
        # Strip the echoed prompt and the trailing END marker.
        answer = out_text.replace(input_text, "").replace("\nEND", "").strip()
        item['infer_answer'] = answer
        print(out_text)
        print(f"### {idx+1}.Answer:\n", item.get('output'), '\n\n')
        answers.append({'index': idx, **item})
# Error reported with the original load_in_8bit=False:
#   RuntimeError: mixed dtype (CPU): expect input to have scalar type of BFloat16