YunwenTechnology / Unilm

438 stars 87 forks source link

RuntimeError: gather(): Expected dtype int64 for index #30

Open potong opened 1 year ago

potong commented 1 year ago

尝试的一种解决办法,有效! image

在modeling_unilm.py文件中step_back_ptrs.append(back_ptrs)前面加入一行代码back_ptrs = back_ptrs.type(torch.int64)