Tzeviya opened this issue 3 years ago
Could be! Thanks for your report, I'll check. In the meanwhile, you could run the code under `with torch.autograd.detect_anomaly():`; maybe that would indicate the problem.
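For reference, a minimal sketch of what that wrapping looks like (the toy computation here is made up, not the repo's code):

```python
import torch

# Hypothetical toy computation; any forward/backward pass can be wrapped.
x = torch.randn(4, requires_grad=True)

# Inside detect_anomaly, autograd records a traceback for every forward op,
# so errors raised during backward (NaNs, in-place modification of saved
# tensors) point at the forward operation that caused them, instead of only
# at the .backward() / autograd.grad call site.
with torch.autograd.detect_anomaly():
    loss = (x * 2).sum()
    loss.backward()

print(x.grad)  # gradient of sum(2*x) w.r.t. x is all twos
```

Note that anomaly detection slows training down, so it is meant for debugging runs only.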
Thanks for the quick reply! So far, all I get is:
```
Traceback (most recent call last):
  File "example.py", line 44, in <module>
    custom_ctc_grad, = torch.autograd.grad(custom_ctc.sum(), logits, retain_graph = True)
  File "/home/usr/me/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 204, in grad
    inputs, allow_unused)
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [256, 122]], which is output 0 of SliceBackward, is at version 130; expected version 129 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
```
Which I assume implies a problem with the slicing...
For now I just randomly added `.clone()` in various places, e.g. `log_probs[0, :, blank].clone()` and `log_probs_[t].clone()`, etc. I don't know if it's correct to do this, but the output seems fine:
```
Custom loss matches: True
Grad matches: True
CE grad matches: True
```
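The `.clone()` workaround does make sense for this class of error. A minimal sketch of why it helps (toy tensors and a hypothetical `slice_grad` helper, not the repo's actual loss code): the backward pass needs the saved value of a slice, but a slice is a view of its base tensor, so a later in-place write to the base bumps the shared version counter and invalidates the saved value; cloning decouples the slice from the base.

```python
import torch

def slice_grad(use_clone):
    """Toy repro: backward of (s * s) needs the saved value of s, but an
    in-place op on the base tensor h changes it after the fact."""
    x = torch.randn(3, 4, requires_grad=True)
    h = x * 2                                   # intermediate tensor
    s = h[0].clone() if use_clone else h[0]     # slice is a view unless cloned
    loss = (s * s).sum()                        # autograd saves s for backward
    h.add_(1)                                   # in-place write bumps the version of h (and its views)
    return torch.autograd.grad(loss, x)[0]

# Without .clone() this raises the "modified by an inplace operation"
# RuntimeError; with .clone() the saved slice is independent of h, so
# backward succeeds.
try:
    slice_grad(use_clone=False)
except RuntimeError as e:
    print("without clone:", type(e).__name__)
print("with clone, grad shape:", slice_grad(use_clone=True).shape)
```

Whether each `.clone()` you added is *correct* depends on whether the cloned value is the one the math actually requires at that point; here the clone only snapshots the slice before the in-place update, it does not change the formula.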
@Tzeviya have you tried to do double-backward here, or just run `example.py` as is? Simply running `example.py` works on recent PyTorch.
I met the same problem:

```
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [256, 122]], which is output 0 of AsStridedBackward0, is at version 130; expected version 129 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
```
Could you please post a complete repro? Are you just trying to run `python example.py`, or something else? What's your PyTorch version?
Hi,
When running your code I received the error above. Could it be that there is some problem with the slicing in the `ctc_loss` function?
Thanks :)