qiuxia-alone opened 3 years ago
I adapted this from interact.py, using plain probability sampling for generation. The single-turn chat TF code is as follows:
```python
import tensorflow as tf
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer4 = AutoTokenizer.from_pretrained("vocab/", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
model4 = AutoModelForCausalLM.from_pretrained("model/", from_tf=True)  # model/ holds the saved .bin model file

sentence = "你好啊"
text_ids = tokenizer4.encode(sentence, add_special_tokens=False)
input_ids = [tokenizer4.cls_token_id]
input_ids.extend(text_ids)
input_ids.append(tokenizer4.sep_token_id)
input_ids = tf.constant(input_ids)             # (None,)
input_ids = tf.expand_dims(input_ids, axis=0)  # (1, None)

max_len = 50
response = []
for _ in range(max_len):
    outputs = model4(input_ids=input_ids)
    logits = outputs.logits                                       # (1, None, 13317)
    next_token_logits = tf.expand_dims(logits[0, -1, :], axis=0)  # (1, 13317)
    next_token = tf.random.categorical(tf.nn.softmax(next_token_logits, axis=-1), num_samples=1)  # (1, 1)
    next_token = tf.cast(next_token, dtype=tf.int32)
    if next_token == tokenizer4.sep_token_id:
        break
    response.append(next_token.numpy().item())
    input_ids = tf.concat((input_ids, next_token), axis=1)
```
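One possible source of the discrepancy (my reading, not something confirmed in the thread): `tf.random.categorical` expects unnormalized log-probabilities (logits), while the snippet above feeds it the output of `tf.nn.softmax`. Applying softmax twice flattens the distribution toward uniform, so the TF loop samples from a different distribution than the PyTorch loop does. A minimal NumPy sketch of the effect:

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over a 1-D array.
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([2.0, 0.0, -2.0])

# Distribution tf.random.categorical samples from when given logits directly:
p_correct = softmax(logits)           # ~ [0.867, 0.117, 0.016]

# Distribution it effectively samples from when given softmax probabilities,
# because it re-applies softmax internally to whatever it receives:
p_wrong = softmax(softmax(logits))    # ~ [0.526, 0.249, 0.225] — much flatter
```

Under this reading, passing `next_token_logits` to `tf.random.categorical` directly (without the softmax) would match the PyTorch sampling behavior.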
There are two problems:
But it works fine under the PyTorch framework; the code is as follows:
```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("vocab/", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
model = AutoModelForCausalLM.from_pretrained("model/model_tf/", from_tf=True)

text = "你好啊"
text_ids = tokenizer.encode(text, add_special_tokens=False)  # (None,)
input_ids = [tokenizer.cls_token_id]
input_ids.extend(text_ids)
input_ids.append(tokenizer.sep_token_id)
input_ids = torch.tensor(input_ids).int()
input_ids = input_ids.unsqueeze(0)  # (1, None)

max_len = 50
response = []
for _ in range(max_len):
    outputs = model(input_ids=input_ids)
    logits = outputs.logits               # (1, None, 13317)
    next_token_logits = logits[0, -1, :]  # (13317,)
    next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)  # (1,)
    if next_token == tokenizer.sep_token_id:
        break
    response.append(next_token.item())
    input_ids = torch.cat((input_ids, next_token.unsqueeze(0)), dim=1)  # (1, None)
```
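For contrast, `torch.multinomial` draws indices in proportion to non-negative weights, so applying `F.softmax` before sampling is the intended usage in the PyTorch loop. A small self-contained sketch of that sampling semantics (using NumPy in place of torch, purely for illustration):

```python
import numpy as np

# torch.multinomial(weights, num_samples) samples index i with probability
# weights[i] / weights.sum(). Simulated here with numpy's rng.choice:
rng = np.random.default_rng(0)
probs = np.array([0.7, 0.2, 0.1])
samples = rng.choice(len(probs), size=20000, p=probs)
freq = np.bincount(samples, minlength=len(probs)) / len(samples)
# freq should closely approximate probs
```

This is why the same softmax-then-sample pattern behaves correctly in PyTorch but not when transplanted verbatim onto `tf.random.categorical`, which expects logits rather than probabilities.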
Apart from the discrepancy in the logits the model produces, everything else looks fine to me. Any pointers would be appreciated!