microsoft / ProphetNet

A research project for natural language generation, containing the official implementations by the MSRA NLC team.

One of the variables needed for gradient computation has been modified by an inplace operation #38

Open aleversn opened 3 years ago

aleversn commented 3 years ago

The README.md suggests using torch version 1.3.0, but that version does not seem to be listed among the previous PyTorch releases (link).

So I used the latest version of torch (1.7.1), and when I started training I got this RuntimeError. I found that the error is raised at prophetnet/ngram_multihead_attention.py line 255:

q *= self.scaling
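
For context, PyTorch tracks a version counter on every tensor: an in-place write bumps the counter, and if autograd had saved that tensor for the backward pass, backward() detects the mismatch and raises exactly this error. A minimal sketch of the failure mode (a hypothetical example, not ProphetNet code):

import torch

# (q ** 2) saves q for its backward pass, so mutating q in place
# afterwards invalidates the saved tensor and backward() fails.
x = torch.randn(3, requires_grad=True)
q = x * 2
loss = (q ** 2).sum()  # autograd saves q to compute d(q**2)/dq = 2*q
q *= 0.5               # in-place write bumps q's version counter
loss.backward()        # RuntimeError: one of the variables needed for
                       # gradient computation has been modified by an
                       # inplace operation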

It looks like this in-place operation is no longer allowed here, so I fixed the problem as follows:

q_ = q * self.scaling  # scale out of place instead of mutating q

if self.bias_k is not None:
    assert self.bias_v is not None
    k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
    v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])

# reshape the scaled copy unconditionally (this line belongs outside the
# if-block, otherwise q is never reshaped when bias_k is None)
q = q_.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
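
For reference, an equivalent smaller fix (a sketch, assuming the offending line really is the in-place update at line 255) is to make the scaling itself out-of-place and leave the rest of the function unchanged:

# Out-of-place multiply allocates a new tensor instead of mutating
# the one autograd saved for the backward pass.
q = q * self.scaling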