renmada / t5-pegasus-pytorch


Error when calling generate #9

Closed. shawroad closed this 3 years ago

shawroad commented 3 years ago

  File "train_T5.py", line 151, in
    input_ids=q_id, attention_mask=q_mask.type(torch.uint8))
  File "/home/jl-wzy/anaconda3/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 43, in decorate_no_grad
    return func(*args, **kwargs)
  File "/home/jl-wzy/anaconda3/lib/python3.6/site-packages/transformers/generation_utils.py", line 1050, in generate
    model_kwargs,
  File "/home/jl-wzy/anaconda3/lib/python3.6/site-packages/transformers/generation_utils.py", line 2228, in group_beam_search
    if beam_scorer.is_done:
  File "/home/jl-wzy/anaconda3/lib/python3.6/site-packages/transformers/generation_beam_search.py", line 197, in is_done
    return self._done.all()
RuntimeError: all only supports torch.uint8 dtype

renmada commented 3 years ago

Does your transformers version match?

shawroad commented 3 years ago

What transformers version are you using? Mine is '4.3.3'.
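Comparing environments is easier when both sides report the same numbers. The helper below is a minimal sketch that is not part of the original thread. The function names `installed_version` and `report_versions` are hypothetical. It imports each package and reads its `__version__` attribute, and it avoids `importlib.metadata` so that it also runs on the Python 3.6 shown in the traceback.

```python
import importlib


def installed_version(module_name):
    """Return module.__version__ if the module imports cleanly, else None.

    Hypothetical helper: catches Exception (not only ImportError) because a
    broken install, e.g. one with missing shared libraries, can raise other
    errors at import time.
    """
    try:
        module = importlib.import_module(module_name)
    except Exception:
        return None
    # Not every module defines __version__; report None in that case.
    return getattr(module, "__version__", None)


def report_versions(module_names=("torch", "transformers")):
    """Map each module name to its installed version string, or None."""
    return {name: installed_version(name) for name in module_names}
```

Running `print(report_versions())` on each machine shows the torch and transformers versions side by side. Both matter here, because the failing call `self._done.all()` is a torch tensor operation reached from inside transformers' `generate`.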

renmada commented 3 years ago

I haven't run into this problem before, so I can't help with it.