cdj0311 opened this issue 2 years ago
Have you compared your results against huggingface's GPT?
Yes, I have compared them; the huggingface results are fine.
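For reference, this is roughly the kind of one-pass comparison meant here, as a minimal sketch. It assumes, as in the script further down, that `turbo_transformers.GPT2Model.from_torch` wraps the PyTorch model and that the first element of the turbo output lines up with HuggingFace's `last_hidden_state`; the exact shapes and return types may differ.

```python
# Minimal sketch of a single-forward output comparison between HuggingFace and
# turbo_transformers (assumes tt_model(...)[0] corresponds to last_hidden_state).
import torch
import transformers
import turbo_transformers

device = torch.device('cuda:0')
cfg = transformers.GPT2Config()
model = transformers.GPT2Model(cfg).to(device).eval()
tt_model = turbo_transformers.GPT2Model.from_torch(model, device)

input_ids = torch.tensor([[12166, 10699, 16752, 4454]], dtype=torch.long, device=device)
with torch.no_grad():
    hf_hidden = model(input_ids).last_hidden_state
tt_hidden = tt_model(input_ids)[0]

# On a single forward pass the two outputs should agree up to numerical tolerance.
print((hf_hidden - tt_hidden).abs().max())
```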
Could you post code that reproduces the error? I suspect it is a model-loading problem, because GPT2 has been used in WeChat's online services, so in principle it should work.
```python
import torch
import transformers
import turbo_transformers
import enum
import time
import numpy


class LoadType(enum.Enum):
    PYTORCH = "PYTORCH"
    PRETRAINED = "PRETRAINED"
    NPZ = "NPZ"


def test(loadtype: LoadType, use_cuda: bool):
    test_device = torch.device('cuda:0') if use_cuda else torch.device('cpu:0')
    cfg = transformers.GPT2Config()
    model = transformers.GPT2Model(cfg).to(test_device)
    model.eval()
    torch.set_grad_enabled(False)
    cfg = model.config

    input_ids1 = torch.tensor([[12166, 10699, 16752, 4454]], dtype=torch.long).to(test_device)
    input_ids2 = torch.tensor([[12166, 10699, 16752, 4454]], dtype=torch.long).to(test_device)
    start_time = time.time()
    past = None

    # PyTorch baseline: feed only the newly generated token together with past_key_values.
    for _ in range(32):
        t1 = time.time()
        torch_res = model.forward(input_ids1, past_key_values=past)
        past = torch_res.past_key_values
        gen_id = torch.argmax(torch_res.last_hidden_state)
        print(gen_id, time.time() - t1)
        input_ids1 = gen_id.unsqueeze(0)

    if loadtype is LoadType.PYTORCH:
        tt_model = turbo_transformers.GPT2Model.from_torch(model, test_device)
    else:
        raise ValueError("LoadType is not supported")

    # turbo_transformers: re-feed the whole (growing) sequence every step.
    for _ in range(32):
        t1 = time.time()
        print(input_ids2)
        res = tt_model(input_ids2)  # sequence_output, pooled_output
        gen_id = torch.argmax(res[0])
        past = res[1]  # returned cache (not fed back in this loop)
        print(gen_id, time.time() - t1)
        input_ids2 = torch.cat([input_ids2, gen_id.unsqueeze(0).unsqueeze(1)], dim=-1)


if __name__ == "__main__":
    test(LoadType.PYTORCH, use_cuda=True)
```
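One structural difference between the two loops above, noted for reference: the PyTorch loop feeds only the newly generated position together with `past_key_values`, while the turbo_transformers loop re-feeds the whole growing sequence and never passes `res[1]` back in. The sketch below is pure HuggingFace (no turbo_transformers involved; whether the turbo cache follows the same contract is an assumption to verify) and shows how those two feeding patterns are meant to relate:

```python
# Pure-HuggingFace sketch: feeding only the new position plus the returned cache
# should reproduce the last hidden state obtained by re-feeding the full sequence.
import torch
import transformers

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = transformers.GPT2Model(transformers.GPT2Config()).to(device).eval()

full_ids = torch.tensor([[12166, 10699, 16752, 4454, 477]], dtype=torch.long, device=device)

with torch.no_grad():
    # (a) re-feed the entire sequence in one shot
    full_out = model(full_ids).last_hidden_state

    # (b) feed the prefix once, then only the last token plus the cached key/values
    prefix = model(full_ids[:, :-1], use_cache=True)
    step = model(full_ids[:, -1:], past_key_values=prefix.past_key_values)

# The final position's hidden state should match between (a) and (b).
print(torch.allclose(full_out[:, -1], step.last_hidden_state[:, -1], atol=1e-4))
```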
Hi, I'm using gpt2_example.py to run GPT2 inference. The first generated token is fine, but after concatenating the generated token onto the preceding sequence and continuing inference, the results are wrong. For example, my input is input_ids = torch.tensor([[12166, 10699, 16752, 4454]], dtype=torch.long).to(test_device), and the inference code is:

```python
for _ in range(32):
    res = tt_model(input_ids)  # sequence_output, pooled_output
    gen_id = torch.argmax(res[0])
    input_ids = torch.cat([input_ids, gen_id.unsqueeze(0).unsqueeze(1)], dim=-1)
```
Generated result: tensor([[12166, 10699, 16752, 4454, 477, 477, 477, .....]], device='cuda:0'), where the first 477 is correct, but the later tokens look as if they were still produced from the first input. What could be going on here?
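One thing that may be worth double-checking on the decoding side (an observation, not a confirmed cause of the mismatch): `res[0]` here comes from a headless `GPT2Model`, so `torch.argmax(res[0])` without a `dim` argument returns a flat index into the (batch, seq, hidden) hidden-state tensor rather than a vocabulary id. For comparison, a plain HuggingFace greedy loop would look roughly like the sketch below; how gpt2_example.py and turbo_transformers expose logits may differ.

```python
# Plain HuggingFace greedy loop, for comparison only. GPT2LMHeadModel adds the
# LM head, and the next token id is the argmax over the last position's logits.
import torch
import transformers

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = transformers.GPT2LMHeadModel(transformers.GPT2Config()).to(device).eval()

input_ids = torch.tensor([[12166, 10699, 16752, 4454]], dtype=torch.long, device=device)
with torch.no_grad():
    for _ in range(32):
        logits = model(input_ids).logits                 # (batch, seq, vocab)
        gen_id = torch.argmax(logits[:, -1, :], dim=-1)  # next-token id, shape (batch,)
        input_ids = torch.cat([input_ids, gen_id.unsqueeze(1)], dim=-1)
print(input_ids)
```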