yangjianxin1 / GPT2-chitchat

GPT2 for Chinese chitchat (implements the MMI idea from DialoGPT)

Poor generation results when using the 500k-corpus chitchat pretrained model in TensorFlow #83

Open · qiuxia-alone opened this issue 3 years ago

qiuxia-alone commented 3 years ago

I adapted this from interact.py, using the simplest probability sampling for generation. My single-turn chat TensorFlow code is as follows:

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer4 = AutoTokenizer.from_pretrained("vocab/", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
# model/ holds the PyTorch .bin checkpoint; the TF model class is needed here
# since the rest of the code works with tf tensors
model4 = TFAutoModelForCausalLM.from_pretrained("model/", from_pt=True)

sentence = "你好啊"
text_ids = tokenizer4.encode(sentence, add_special_tokens=False)
# build the input as [CLS] + text tokens + [SEP]
input_ids = [tokenizer4.cls_token_id]
input_ids.extend(text_ids)
input_ids.append(tokenizer4.sep_token_id)
input_ids = tf.constant(input_ids)             # (None,)
input_ids = tf.expand_dims(input_ids, axis=0)  # (1, None)

max_len = 50
response = []
for _ in range(max_len):
    outputs = model4(input_ids=input_ids)
    logits = outputs.logits  # (1, None, 13317)
    next_token_logits = tf.expand_dims(logits[0, -1, :], axis=0)  # (1, 13317)
    # sample the next token from the softmax distribution
    next_token = tf.random.categorical(tf.nn.softmax(next_token_logits, axis=-1), num_samples=1)  # (1, 1)
    next_token = tf.cast(next_token, dtype=tf.int32)
    if next_token == tokenizer4.sep_token_id:  # stop once [SEP] is generated
        break
    response.append(next_token.numpy().item())
    input_ids = tf.concat((input_ids, next_token), axis=1)

There are two problems:

  1. In the generation loop, next_token never comes out as sep_token_id, so the loop only ends once max_len is reached;
  2. The decoded response is rather absurd: tokenizer4.decode(response) Out[216]: '獒 鯉ますのて ni system 猕 視 艳 r9 揣 迈 臥ads 秦 apps 賊 eur 160 钳 ん 痺 俳 eb 荥 盆 one 擲 ef 181osa 琰 宅 羅 倚 穌llowument ◠ 1969 162 139 泷 碰 ║ensします 侄 114 懵キ'
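One thing worth noting about the sampling call in the TF loop: per the TensorFlow documentation, tf.random.categorical interprets each row of its first argument as unnormalized log-probabilities (i.e. raw logits), whereas torch.multinomial expects non-negative probability weights. Feeding tf.nn.softmax output into tf.random.categorical therefore samples from an almost flat distribution over the 13317-token vocabulary, which is consistent with both symptoms: garbled tokens, and [SEP] essentially never being drawn. A minimal sketch of the sampling step with raw logits instead (same variables as in the loop above):

# sketch: pass raw logits, which tf.random.categorical treats as unnormalized log-probabilities
next_token = tf.random.categorical(next_token_logits, num_samples=1)  # (1, 1)
# equivalently, via explicit log-probabilities:
# log_probs = tf.math.log(tf.nn.softmax(next_token_logits, axis=-1))
# next_token = tf.random.categorical(log_probs, num_samples=1)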

Under the PyTorch framework, however, it works fine; code as follows:

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("vocab/", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
# model/model_tf/ holds the TF checkpoint, so convert from the TF weights
model = AutoModelForCausalLM.from_pretrained("model/model_tf/", from_tf=True)

text = "你好啊"
text_ids = tokenizer.encode(text, add_special_tokens=False)  # (None,)
# build the input as [CLS] + text tokens + [SEP]
input_ids = [tokenizer.cls_token_id]
input_ids.extend(text_ids)
input_ids.append(tokenizer.sep_token_id)
input_ids = torch.tensor(input_ids).long()  # token ids should be int64
input_ids = input_ids.unsqueeze(0)          # (1, None)

max_len = 50
response = []
for _ in range(max_len):
    outputs = model(input_ids=input_ids)
    logits = outputs.logits               # (1, None, 13317)
    next_token_logits = logits[0, -1, :]  # (13317,)
    # torch.multinomial samples from the given probability weights
    next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)  # (1,)
    if next_token == tokenizer.sep_token_id:  # stop once [SEP] is generated
        break
    response.append(next_token.item())
    input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)  # (1, None)

Apart from some discrepancy in the logits the two models produce, everything else looks fine to me. Any pointers would be appreciated!
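Following up on the tf.random.categorical note above, here is a tiny self-contained demo (illustrative only; the toy logits are made up) of how much passing probabilities instead of log-probabilities flattens the sampled distribution:

import numpy as np
import tensorflow as tf

logits = tf.constant([[8.0, 2.0, 0.0, -1.0]])  # toy next-token logits
probs = tf.nn.softmax(logits, axis=-1)         # ~[0.997, 0.0025, 0.0003, 0.0001]

def empirical(dist, n=100000):
    # draw n samples and return the observed token frequencies
    samples = tf.random.categorical(dist, num_samples=n)
    return np.bincount(samples.numpy().ravel(), minlength=4) / n

print(empirical(tf.math.log(probs)))  # matches softmax: ~[0.997, 0.002, 0.000, 0.000]
print(empirical(probs))               # nearly uniform: ~[0.47, 0.18, 0.18, 0.18]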