google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models
https://ai.google.dev/gemma
Apache License 2.0

[Question] Embeddings normalization by sqrt(hidden_size) #29

Closed Andrei-Aksionov closed 4 months ago

Andrei-Aksionov commented 4 months ago

Hello there 👋

Thanks for the repo. I have one question, though: why do we need to scale (normalize) the token embeddings by sqrt(hidden_size)? https://github.com/google/gemma_pytorch/blob/01062c9ef4cf89ac0c985b25a734164ede017d0b/gemma/model.py#L431-L432

Unfortunately, I cannot find an answer anywhere.
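
For reference, here is a minimal sketch of the scaling in question (the sizes are example values and the variable names are assumed, not copied from model.py):

```python
import torch

hidden_size = 2048    # example value; actual Gemma configs define their own sizes
vocab_size = 256_000  # example value

# Token embeddings straight out of the embedding layer.
embedder = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size)
token_ids = torch.tensor([[1, 42, 7]])
hidden_states = embedder(token_ids)

# The scaling the question is about: multiply the embeddings by sqrt(hidden_size).
normalizer = hidden_states.new_tensor(hidden_size**0.5)
hidden_states = hidden_states * normalizer
```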

ghost commented 4 months ago

See Section 3.4 of https://arxiv.org/pdf/1706.03762.pdf

Andrei-Aksionov commented 4 months ago

Thanks @crolequi for the response. The paper only states that the embedding weights are multiplied by sqrt(d_model); it doesn't explain why.

(Screenshot of Section 3.4, "Embeddings and Softmax", from the paper.)

suryabhupa commented 4 months ago

That's a great question! It's been asked a few times, and there are some possible explanations (but no definitive reason): https://datascience.stackexchange.com/questions/87906/transformer-model-why-are-word-embeddings-scaled-before-adding-positional-encod.
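
One explanation from that thread can be illustrated numerically for the original Transformer: if the embedding weights are initialized with variance on the order of 1/d_model, each embedding vector has norm around 1, while the sinusoidal positional encodings have norm around sqrt(d_model/2), so multiplying by sqrt(d_model) keeps the token information from being drowned out when the two are added. A small sketch (the initialization is assumed for illustration and refers to the original Transformer setup, not to Gemma, which uses RoPE):

```python
import torch

hidden_size = 2048
seq_len = 16

# Assumed initialization for illustration: embedding weights ~ N(0, 1/hidden_size).
emb = torch.randn(1000, hidden_size) * hidden_size**-0.5

# Sinusoidal positional encodings as defined in "Attention Is All You Need".
pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
div = torch.exp(torch.arange(0, hidden_size, 2, dtype=torch.float32)
                * (-torch.log(torch.tensor(10000.0)) / hidden_size))
pe = torch.zeros(seq_len, hidden_size)
pe[:, 0::2] = torch.sin(pos * div)
pe[:, 1::2] = torch.cos(pos * div)

print(emb.norm(dim=-1).mean())                        # ~1
print((emb * hidden_size**0.5).norm(dim=-1).mean())   # ~sqrt(hidden_size) ≈ 45
print(pe.norm(dim=-1).mean())                         # ~sqrt(hidden_size / 2) ≈ 32
```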

Andrei-Aksionov commented 4 months ago

Thanks @suryabhupa for the link. It helped a lot.