Closed cj401 closed 6 months ago
这个代码很不错 多谢
我想请问能否分享一下 /track1/train_valid.json 或是它的具体格式 我想跑一下整个流程。
不用纠结这个,不需要这个文件的。推理代码可以自己修改,
with open('../track1/train_valid.json','r') as f:
data=json.load(f)
ans_lst=[]
target_lst=[]
for p in data[:100]:
# run generation
prompt=p['question']
x=tokenizer.encode(prompt,add_special_tokens=False)+[tokenizer.special_tokens['<bos>']]
x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
target=p['answer']
target_lst.append(target)
with torch.no_grad():
with ctx:
y = model.generate(x, 2, max_new_tokens, temperature=temperature, top_k=top_k)
#
answer=tokenizer.decode(y[0].tolist())
answer=answer.replace(prompt,'')
ans_lst.append(answer)
print('[prompt]:',prompt)
print('[answer]:',answer)
print('---------------')
上面这段代码可以改成
data=['帮我写一首冬天的诗歌',
'帮我写一封邮件给李华',
'杭州有哪些旅游景点,帮我介绍一下',
'青霉素的作用是什么,有什么禁忌吗,哪些人应该禁用?',
'最近老是觉得肚子疼怎么办,晚上疼得更厉害']
for prompt in data:
# run generation
x=tokenizer.encode(prompt,add_special_tokens=False)+[tokenizer.special_tokens['<bos>']]
x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
with torch.no_grad():
with ctx:
y = model.generate(x, 2, max_new_tokens, temperature=temperature, top_k=top_k)
#
answer=tokenizer.decode(y[0].tolist())
answer=answer.replace(prompt,'')
print('[prompt]:',prompt)
print('[answer]:',answer)
print('---------------')
这个代码很不错 多谢
我想请问能否分享一下 /track1/train_valid.json 或是它的具体格式 我想跑一下整个流程。