Closed 0seba closed 5 months ago
I think you should be dividing by the scale in the following line
https://github.com/kyegomez/zeta/blob/7dbb6a62f83413977a922d5fc6dec1b11f734bc3/zeta/nn/modules/rms_norm.py#L35
This is the scale definition
https://github.com/kyegomez/zeta/blob/7dbb6a62f83413977a922d5fc6dec1b11f734bc3/zeta/nn/modules/rms_norm.py#L29C9-L29C31
self.scale = dim**-0.5
And RMSNorm formula
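For reference, here is the formula as I understand it from the RMSNorm paper (Zhang and Sennrich, 2019). The small epsilon is omitted, and the symbols are my own notation: $d$ stands for `dim`, $x$ for the feature vector, and $\gamma$ for the learned gain.

$$
\mathrm{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2}, \qquad y_i = \frac{x_i}{\mathrm{RMS}(x)}\,\gamma_i
$$

Since $\mathrm{RMS}(x) = d^{-1/2}\,\lVert x \rVert_2$, the normalized value is the L2-normalized vector divided by $d^{-1/2}$, which is why I think the scale should divide rather than multiply:

$$
\frac{x_i}{\mathrm{RMS}(x)} = \frac{x_i / \lVert x \rVert_2}{d^{-1/2}}
$$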
Edit:
Also, I think the normalization should be along dim -1, not -2
https://github.com/kyegomez/zeta/blob/7dbb6a62f83413977a922d5fc6dec1b11f734bc3/zeta/nn/modules/rms_norm.py#L34
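To illustrate both points, here is a minimal pure-Python sketch, not the library's code. It assumes the linked lines L2-normalize the input and then multiply by `self.scale`, and that the features sit on the last axis of the input. The helper names (`l2_normalize_rows`, `rms_norm_rows`, `allclose`) and the toy input are hypothetical, and the learned gain and epsilon are left out.

```python
import math


def l2_normalize_rows(x):
    # L2-normalize each row, i.e. along the last axis (eps omitted for brevity).
    out = []
    for row in x:
        norm = math.sqrt(sum(v * v for v in row))
        out.append([v / norm for v in row])
    return out


def rms_norm_rows(x):
    # Reference RMSNorm per row, without the learned gain: row / sqrt(mean(row^2)).
    out = []
    for row in x:
        rms = math.sqrt(sum(v * v for v in row) / len(row))
        out.append([v / rms for v in row])
    return out


def allclose(a, b):
    # Element-wise comparison of two nested lists of the same shape.
    return all(math.isclose(p, q) for ra, rb in zip(a, b) for p, q in zip(ra, rb))


x = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 8.0]]  # toy input, shape (seq=2, dim=4)
dim = len(x[0])
scale = dim ** -0.5  # same definition as self.scale in the issue

normed = l2_normalize_rows(x)
divided = [[v / scale for v in row] for row in normed]
multiplied = [[v * scale for v in row] for row in normed]
reference = rms_norm_rows(x)

print(allclose(divided, reference))     # dividing by scale reproduces the reference
print(allclose(multiplied, reference))  # multiplying does not
```

With these assumptions, the multiplied version comes out smaller than the reference by a factor of `dim`, while the divided version matches it. Whether -1 is the right axis depends on where the features live in the actual input layout, so treat that part as my reading of the code.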
Hello there, thank you for opening an Issue ! 🙏🏻 The team was notified and they will get back to you asap.
Stale issue message