sh0416 opened this issue 1 year ago
I think it prevents the logsumexp from being too large. Is there a reference for doing this, or is it just a practical fix? I derived it to the following expression:

`x_i - logsumexp(x) * (1 - 0.0001 * logsumexp(x))`
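For concreteness, here is a minimal sketch of how I understand the term, assuming the loss is the standard cross entropy plus a `1e-4 * logsumexp(logits)**2` penalty (the function name, shapes, and `z_weight` argument are mine for illustration, not taken from the repo):

```python
import jax
import jax.numpy as jnp

def loss_with_z_penalty(logits, labels, z_weight=1e-4):
    # logits: [batch, vocab] raw scores x; labels: [batch] integer token ids
    log_z = jax.scipy.special.logsumexp(logits, axis=-1)   # logsumexp(x)
    log_probs = logits - log_z[..., None]                  # x_i - logsumexp(x)
    nll = -jnp.take_along_axis(log_probs, labels[..., None], axis=-1).squeeze(-1)
    # extra penalty pulls logsumexp(x) toward 0, so the logits cannot drift upward
    return jnp.mean(nll + z_weight * jnp.square(log_z))
```

Expanding the per-token objective in this sketch gives `x_i - logsumexp(x) - 1e-4 * logsumexp(x)**2`, i.e. the factored form above but with a `+` inside the parentheses; which sign actually applies depends on the convention in the linked line.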
In the mentioned paper (https://arxiv.org/abs/2206.13517), a regularization term is added to prevent divergence. Is that right? @enijkamp, could I get further details about this loss term?
I think it is intended for numerical stability, but I don't know how it works.
Could you explain it or provide a reference for that code?
https://github.com/salesforce/jaxformer/blob/9c41fd44a1b052301ad216e2c8fe20677c77adf7/jaxformer/models/decoder/inter/model.py#L71
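A quick way to see the effect numerically, again assuming the penalty is `1e-4 * logsumexp(x)**2`, is to inspect its gradient: it is proportional to `softmax(x)` scaled by `logsumexp(x)`, so gradient descent uniformly shrinks the logits whenever `logsumexp(x) > 0`:

```python
import jax
import jax.numpy as jnp

penalty = lambda x: 1e-4 * jax.scipy.special.logsumexp(x) ** 2
x = jnp.array([4.0, 1.0, -2.0])

# d/dx_j of 1e-4 * lse(x)^2 is 2e-4 * lse(x) * softmax(x)_j:
# every component shares the sign of lse(x), so gradient descent
# lowers all logits while lse(x) > 0 and raises them while lse(x) < 0.
print(jax.grad(penalty)(x))  # all positive here, since lse(x) > 0
print(2e-4 * jax.scipy.special.logsumexp(x) * jax.nn.softmax(x))  # same values
```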