meta-llama / llama-models

Utilities intended for use with Llama models.

Function does not implement RMSNorm #205

Closed · JW-swansea closed this 3 weeks ago

JW-swansea commented 3 weeks ago

Hi, I was looking through the code and noticed something strange.

This function is supposed to implement RMSNorm from Zhang, Biao, and Rico Sennrich, "Root mean square layer normalization," Advances in Neural Information Processing Systems 32 (2019).

But instead of dividing by the appropriate coefficient, it multiplies by it.

https://github.com/meta-llama/llama-models/blob/2fe1a1690162910660332e3294a552cf0ec7e754/models/llama3/reference_impl/model.py#L31-L42
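For readers without the permalink handy, the linked code is essentially the following (paraphrased here; see the link above for the exact source):

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Multiplies x by the reciprocal square root of its mean square.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
```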

If the sum of the squared entries of the vector is already n (i.e., the RMS is 1), this makes no difference, but for any other value it will make larger vectors larger and smaller vectors smaller, pushing them away from that value, which is the opposite of the intended functionality.

(previously marked as an issue here in the deprecated repository)

cglagovichTT commented 3 weeks ago

[image: the RMSNorm definition from Zhang & Sennrich (2019), $\bar{x}_i = \frac{x_i}{\mathrm{RMS}(\mathbf{x})}\, g_i$ where $\mathrm{RMS}(\mathbf{x}) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2}$]

RMSNorm should divide the input by the RMS. Since torch.rsqrt gives the reciprocal of the square root, multiplying by `torch.rsqrt(...)` of the mean square is the same as dividing by the RMS, so Meta's implementation of RMSNorm should be correct. https://pytorch.org/docs/stable/generated/torch.rsqrt.html
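A quick numerical check (a sketch, not from the thread) confirming that multiplying by torch.rsqrt of the mean square matches dividing by the RMS:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
eps = 1e-6

# Meta-style: multiply by the reciprocal square root of the mean square.
via_rsqrt = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

# Paper-style: divide by RMS(x) = sqrt(mean(x^2)).
via_division = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)

# The two agree to floating-point precision.
assert torch.allclose(via_rsqrt, via_division)
```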

JW-swansea commented 3 weeks ago

> Since torch.rsqrt gives the reciprocal of the square root

That's my error: rsqrt, not sqrt. Thanks!