salesforce / jaxformer

Minimal library to train LLMs on TPU in JAX with pjit().
BSD 3-Clause "New" or "Revised" License

Log Z term in loss #26

Open · sh0416 opened this issue 1 year ago

sh0416 commented 1 year ago

I think the log Z term is intended for numerical stability, but I don't understand how it works.

Could you explain it or provide a reference for that code?

https://github.com/salesforce/jaxformer/blob/9c41fd44a1b052301ad216e2c8fe20677c77adf7/jaxformer/models/decoder/inter/model.py#L71
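
For context, this looks like the auxiliary "z-loss" used in mesh-transformer-jax and PaLM, which penalizes `(log Z)^2` so that the softmax normalizer stays close to 1 and the logits don't drift. Below is a minimal sketch of that technique in JAX, assuming the PaLM-style convention; the function name and coefficient are illustrative, not the actual jaxformer code:

```python
import jax
import jax.numpy as jnp

def cross_entropy_with_z_loss(logits, targets, z_loss_coeff=1e-4):
    # log Z: the softmax normalizer, computed with a numerically stable logsumexp.
    log_z = jax.nn.logsumexp(logits, axis=-1)
    # Plain cross entropy: log Z - x_target.
    target_logits = jnp.take_along_axis(logits, targets[..., None], axis=-1)[..., 0]
    ce = log_z - target_logits
    # Auxiliary penalty: pulls log Z toward 0 so the normalizer stays near 1.
    z_loss = z_loss_coeff * jnp.square(log_z)
    return jnp.mean(ce + z_loss)
```

The penalty is negligible for well-behaved logits but grows quadratically if `log Z` drifts, which is why it is usually described as a stability term rather than a modeling change.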

sh0416 commented 1 year ago

I think it prevents the logsumexp from growing too large. Is there a reference for this technique, or is it just a practical fix? I reduced the loss to the following expression:

`x_i - logsumexp(x) * (1 - 0.0001 * logsumexp(x))`
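
Expanding that factored form gives the log-likelihood plus a quadratic term in the logsumexp; a tiny check with hypothetical numbers confirms the algebra:

```python
import jax.numpy as jnp

# Hypothetical values, just to verify the expansion.
x_i, log_z, c = 2.3, 7.1, 1e-4

factored = x_i - log_z * (1.0 - c * log_z)   # the factored form above
expanded = x_i - log_z + c * log_z ** 2      # log-likelihood + c * (log Z)^2
assert jnp.allclose(factored, expanded)
```

Note the sign in this expansion: the `(log Z)^2` term is added to the per-token objective, whereas PaLM adds it to the loss, so whether this matches the PaLM convention depends on the sign used in model.py and is worth double-checking.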

sh0416 commented 1 year ago

In the referenced paper (https://arxiv.org/abs/2206.13517), a regularization term is added to prevent divergence. Is that right? @enijkamp, could you share further details about the loss term?