k2-fsa / k2

FSA/FST algorithms, differentiable, with PyTorch compatibility.
https://k2-fsa.github.io/k2
Apache License 2.0

Removing inplace grad multiplication to allow `retain_graph` in backprop #1106

Closed: Tomiinek closed this 1 year ago

Tomiinek commented 1 year ago

I noticed that it is not possible to use `loss.backward(retain_graph=True)` with any of the RNN-T losses (which is useful when training with multiple optimizers). It fails because of an in-place multiplication on the gradients, with the error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [2, 8, 17]] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
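For context, here is a minimal, self-contained sketch of the failure mode (a toy stand-in, not k2's actual implementation): the gradient is precomputed in `forward()`, saved for backward, and then scaled in place in `backward()`. The in-place `mul_` bumps the saved tensor's version counter, so backpropagating through the retained graph a second time raises exactly this error.

```python
import torch


class ToyTransducerLoss(torch.autograd.Function):
    """Toy stand-in for a loss whose gradient is precomputed in forward()
    and saved for backward(). Not k2's code, just the same pattern."""

    @staticmethod
    def forward(ctx, logits):
        # Pretend this is the gradient produced by the loss recursion.
        grad = torch.softmax(logits, dim=-1)
        ctx.save_for_backward(grad)
        return logits.sum()

    @staticmethod
    def backward(ctx, grad_output):
        (grad,) = ctx.saved_tensors
        # In-place scaling mutates a tensor that is saved in the graph,
        # bumping its version counter from 0 to 1.
        return grad.mul_(grad_output)


logits = torch.randn(2, 8, 17, requires_grad=True)
loss = ToyTransducerLoss.apply(logits)
loss.backward(retain_graph=True)  # first backward succeeds
loss.backward()  # RuntimeError: ... is at version 1; expected version 0 instead
```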

This PR fixes the issue by removing the in-place grad multiplication.
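The fix follows the pattern named in the PR title: do the multiplication out of place, so the saved tensor is never mutated. In the toy sketch above, `backward()` would become:

```python
    @staticmethod
    def backward(ctx, grad_output):
        (grad,) = ctx.saved_tensors
        # Out-of-place multiplication allocates a new tensor and leaves the
        # saved one at version 0, so the retained graph can be reused.
        return grad * grad_output
```

With this change, `loss.backward(retain_graph=True)` followed by a second backward through the same graph no longer trips the version check.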

csukuangfj commented 1 year ago

Thanks. Looks good to me.