pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Derivative for _ctc_loss_backward #70108

Open grazder opened 2 years ago

grazder commented 2 years ago

🚀 The feature, motivation and pitch

Discussion motivation: https://discuss.pytorch.org/t/higher-order-gradients-of-ctcloss/35019

Hi! I'm trying to do double backward through the CTC loss function, but I can't due to the error "RuntimeError: derivative for _ctc_loss_backward is not implemented".

I'm implementing the paper https://arxiv.org/abs/2102.08098 (GradInit) for a network trained with CTC loss, so I can't do it the natural way with the default torch.nn.functional.ctc_loss.

GradInit calculates the gradient norm and penalizes the network for high gradient norms, which requires differentiating through the first backward pass. The code fails at this point: https://github.com/zhuchen03/gradinit/blob/fdaaab15d796864719b71ae6b358a99aeab3ec17/gradinit_utils.py#L207
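
A minimal repro sketch of the failure (shapes and values are illustrative):

```python
import torch
import torch.nn.functional as F

T, N, C, S = 50, 4, 20, 10  # time steps, batch, classes, target length
logits = torch.randn(T, N, C, requires_grad=True)
log_probs = logits.log_softmax(-1)
targets = torch.randint(1, C, (N, S))
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), S, dtype=torch.long)

loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
# First backward with create_graph=True so the gradient is itself differentiable
(grad,) = torch.autograd.grad(loss, logits, create_graph=True)
# Second backward (e.g. through a gradient-norm penalty, as in GradInit)
grad.norm().backward()  # RuntimeError: derivative for _ctc_loss_backward is not implemented
```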

Alternatives

As alternatives, I'm considering different losses or a hand-written CTC loss implementation, but these are not optimal for our network and can lead to a suboptimal optimization process.

Additional context

No response

cc @ezyang @gchanan @zou3519 @bdhirsh @albanD @gqchen @pearu @nikitaved @soulitzer @Lezcano @Varal7

albanD commented 2 years ago

Hi,

Thank you for the feature request. We would definitely be happy to accept a PR adding this!

vadimkantorov commented 2 years ago

@albanD does functorch by chance have a finite-difference approximation? E.g. it may be nice to have standardized finite differences for second-order grads / double backward (maybe already implemented in gradgradcheck?)

albanD commented 2 years ago

gradcheck/gradgradcheck can be used to check this, yes. functorch is based on the formulas in PyTorch, so we cannot really use it for checking.
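
For reference, a minimal sketch of how gradgradcheck is used: it compares autograd's analytical second-order gradients against finite-difference estimates, and inputs must be in double precision.

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

# Both checks compare analytical derivatives against finite differences;
# double precision is required for the comparison to be meaningful.
x = torch.randn(5, dtype=torch.double, requires_grad=True)
f = lambda t: (t.sin() * t).sum()

print(gradcheck(f, (x,)))      # first-order gradients
print(gradgradcheck(f, (x,)))  # second-order / double backward
```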

vadimkantorov commented 2 years ago

> we cannot really use it for checking.

I was talking not about checking per se, but about surfacing the finite-differences code that exists in gradgradcheck as a public API.

If I understand correctly, finite differences could be used for some ops that lack a manual double backward.

albanD commented 2 years ago

> If I understand correctly, finite differences could be used for some ops that lack a manual double backward.

In theory yes, in practice not really. Finite differencing requires computations to run in double precision to be effective, and it has a lot of caveats w.r.t. re-running the user function many times (randomness, etc.).

vadimkantorov commented 2 years ago

I agree, but for some scenarios it may still be useful despite all the limitations, maybe even to validate other dependent code before writing a custom double backward. A sketch of the central-difference idea is below.
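
A minimal sketch of a finite-difference Hessian-vector product (fd_hvp is a hypothetical helper, not an existing PyTorch API; it assumes f is deterministic and runs in double precision):

```python
import torch

def fd_hvp(f, x, v, eps=1e-4):
    # Central-difference Hessian-vector product:
    #   H(x) @ v ≈ (∇f(x + eps*v) - ∇f(x - eps*v)) / (2*eps)
    # f must be deterministic; double precision keeps the estimate usable.
    def grad_at(y):
        y = y.detach().requires_grad_(True)
        (g,) = torch.autograd.grad(f(y), y)
        return g
    return (grad_at(x + eps * v) - grad_at(x - eps * v)) / (2 * eps)

x = torch.randn(5, dtype=torch.double)
v = torch.randn(5, dtype=torch.double)
hvp = fd_hvp(lambda t: (t ** 3).sum(), x, v)
print(torch.allclose(hvp, 6 * x * v, atol=1e-6))  # analytic Hv = 6*x*v
```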

ZR-HH commented 2 years ago

I also hit the same error: when I used ctc_loss in a meta-learning algorithm, the error appeared, but with the kl_div loss there was nothing wrong. Do you have any suggestions for my question? Thanks

albanD commented 2 years ago

Bumping priority due to activity.

Stanwang1210 commented 7 months ago

Has anyone solved this issue?

vadimkantorov commented 7 months ago

One can reformulate CTC as first computing the optimal alignment path and then calculating cross entropy, as in https://github.com/vadimkantorov/ctc

Then it should be possible to do double backward through that cross-entropy calculation.

But I'm not sure if this is the correct gradient for double backward in general.
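
A sketch of that reformulation, assuming the ctc_alignment_targets helper from the linked repo (the exact name and signature may differ):

```python
import torch
# Assumes ctc.py from https://github.com/vadimkantorov/ctc is on the path;
# the helper name/signature below follows that repo and may differ.
from ctc import ctc_alignment_targets

T, N, C, S = 50, 4, 20, 10
logits = torch.randn(T, N, C, requires_grad=True)
log_probs = logits.log_softmax(-1)
targets = torch.randint(1, C, (N, S))
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), S, dtype=torch.long)

# Alignment targets are computed without grad and held constant.
with torch.no_grad():
    ce_targets = ctc_alignment_targets(log_probs, targets,
                                       input_lengths, target_lengths)

# Cross entropy against the fixed alignment targets; this graph contains
# no _ctc_loss_backward, so double backward goes through.
loss = -(ce_targets * log_probs).sum()
(grad,) = torch.autograd.grad(loss, logits, create_graph=True)
grad.norm().backward()
```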

Stanwang1210 commented 7 months ago

@vadimkantorov Oh! That's amazing! I really appreciate that!

But is there any way we can easily integrate this code with the original CTC? I think maybe we can directly replace F.ctc_loss inside nn.CTCLoss, like you did in the example? Can you provide some suggestions on how to use it?

vadimkantorov commented 7 months ago

You can find in that example the cross-entropy expression using the calculated targets, and you can do double backward through it (keeping the targets constant). But again, I'm not sure if this approximation of the double gradient is accurate.

Stanwang1210 commented 7 months ago

@vadimkantorov Thanks for your help again!

I fixed the in-place operation in ctc_loss myself following the suggestion from @Tzeviya. However, I'm wondering whether you have solved the alignment error in your issues? If yes, please let me know.

vadimkantorov commented 1 month ago

@Stanwang1210 Yep, fixed the alignment issue in my repo. Regarding your in-place fixes, have you tried to do double backward? (The vanilla example.py version in my repo now runs fine.)