lipiji / SongNet

Code for ACL 2020 paper "Rigid Formats Controlled Text Generation":https://www.aclweb.org/anthology/2020.acl-main.68/
MIT License
230 stars 40 forks source link

执行./test 出现错误 "IndexError: The shape of the mask [1] at index 0 does not match " #12

Closed smartmark-pro closed 3 years ago

smartmark-pro commented 3 years ago

李老师你好, 您当前的代码, 我运行没有任何问题, 但是当我把数据迁移到自己搜集的数据时, 会出现错误.

具体报错如下

Traceback (most recent call last):
  File "test.py", line 359, in <module>
    res = top_k_inc(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s)
  File "test.py", line 61, in top_k_inc
    incremental_state)
  File "/content/SongNet/biglm.py", line 91, in work_incremental
    incremental_state=incremental_state)
  File "/content/SongNet/transformer.py", line 73, in work_incremental
    attn_mask=self_attn_mask, incremental_state=incremental_state)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/content/SongNet/transformer.py", line 156, in forward
    prev_key = prev_key[bidx]
IndexError: The shape of the mask [1] at index 0 does not match the shape of the indexed tensor [2, 12, 1, 64] at index 0

错误情形 train和eval, polish都没有问题, 但是运行test中执行到某条数据时, 就会出现这种错误. 也就是有的数据可以正常预测和打印, 有的不能.

临时的解决办法 我在test中增加try except 跳过执行出错的例子.

希望能解决bug 我读了transformer prev_key前后的代码, 没能理解错误. 如果您在调试中也遇到类似问题, 能给一些解决的提示么?