Open SangriaXinji opened 4 years ago
这个错误的原因是因为embedding的时候,输入要求的tensor类型的long类型,代码中由于没有指定,所以需要修改。
修改如下:
在DataLoader类的 __next__
方法中,修改成如下代码:inst_data_tensor = Variable(torch.from_numpy(inst_data)).long()
将迭代生成的feature改成long tensor类型。
label不需要embedding,因此不需要修改。
我也是这个报错 请问您解决了吗