VHellendoorn / ICLR20-Great

Data and Code for Reproducing "Global Relational Models of Source Code"
MIT License

A question about the relative scales of word embeddings and position encodings #10

Open jonathan-laurent opened 2 years ago

jonathan-laurent commented 2 years ago

In this implementation, the embedding weights are initialized with a standard deviation of dmodel**-0.5 (see code). This is consistent with many existing Transformer implementations, and it means that each embedding vector has a norm of about 1 (dmodel entries, each with variance 1/dmodel).

However, if I am correct, positional encodings have a greater norm, of the order of dmodel**0.5.
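A quick numerical check of these two norms. This is a sketch, not the repository's code: it assumes the standard sinusoidal encoding from the original Transformer paper and i.i.d. Gaussian initialization, and `sinusoidal_encoding`, `random_embedding` and `norm` are hypothetical helpers.

```python
import math
import random


def sinusoidal_encoding(pos, d_model):
    """Sinusoidal positional encoding for a single position (d_model must be even)."""
    pe = [0.0] * d_model
    for i in range(0, d_model, 2):
        angle = pos / (10000 ** (i / d_model))
        pe[i] = math.sin(angle)      # even dimensions use sine
        pe[i + 1] = math.cos(angle)  # odd dimensions use cosine
    return pe


def random_embedding(d_model, rng):
    """One embedding row drawn i.i.d. from a Gaussian with std d_model**-0.5."""
    std = d_model ** -0.5
    return [rng.gauss(0.0, std) for _ in range(d_model)]


def norm(v):
    return math.sqrt(sum(x * x for x in v))


rng = random.Random(0)
for d_model in (64, 256, 1024):
    emb_norm = norm(random_embedding(d_model, rng))
    # Each (sin, cos) pair contributes exactly 1 to the squared norm,
    # so |PE| = sqrt(d_model / 2) at every position.
    pe_norm = norm(sinusoidal_encoding(pos=7, d_model=d_model))
    print(f"d_model={d_model:5d}  |embedding|={emb_norm:.2f}  |PE|={pe_norm:.2f}")
```

The embedding norm stays near 1 for every size, while the positional encoding norm is sqrt(d_model / 2), i.e. about 5.66, 11.31 and 22.63 for the three sizes above.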

Thus, it surprises me that you add the embeddings to the positional encodings without rescaling the former. Indeed, many implementations multiply the embeddings by dmodel**0.5 before adding the positional encodings.
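A minimal sketch of the rescaling convention mentioned above, assuming sinusoidal positional encodings and embeddings initialized with std d_model**-0.5. The helper `embed_and_add` is hypothetical; it returns the summed vector together with the token-to-position norm ratio, so the effect of the sqrt(d_model) factor is visible directly.

```python
import math
import random


def sinusoidal_encoding(pos, d_model):
    """Sinusoidal positional encoding for a single position (d_model must be even)."""
    pe = [0.0] * d_model
    for i in range(0, d_model, 2):
        angle = pos / (10000 ** (i / d_model))
        pe[i], pe[i + 1] = math.sin(angle), math.cos(angle)
    return pe


def norm(v):
    return math.sqrt(sum(x * x for x in v))


def embed_and_add(token_vec, pe, scale_embeddings):
    """Add positions to a token embedding, optionally scaling the embedding by sqrt(d_model).

    Returns (combined vector, |scaled token embedding| / |positional encoding|).
    """
    d_model = len(token_vec)
    scale = math.sqrt(d_model) if scale_embeddings else 1.0
    scaled = [scale * t for t in token_vec]
    combined = [s + p for s, p in zip(scaled, pe)]
    return combined, norm(scaled) / norm(pe)


rng = random.Random(0)
d_model = 512
token_vec = [rng.gauss(0.0, d_model ** -0.5) for _ in range(d_model)]
pe = sinusoidal_encoding(pos=3, d_model=d_model)

for flag in (False, True):
    _, ratio = embed_and_add(token_vec, pe, scale_embeddings=flag)
    # Without scaling the ratio is roughly sqrt(2 / d_model); with it, roughly sqrt(2).
    print(f"scale_embeddings={flag}: token/PE norm ratio = {ratio:.3f}")
```

Without the factor, the token embedding is roughly sqrt(d_model / 2) times smaller than the positional encoding; with it, the two are of the same order regardless of d_model.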

Would you have any comment on this by any chance?

VHellendoorn commented 2 years ago

Thanks for pointing this out. You are right: in the current formulation, the positional encodings might "drown out" the token embeddings, increasingly so at larger hidden dimensions. I suppose the approach you reference would address this, though it is surprising to me that one would not instead scale the positional encodings by a factor of dmodel**-0.5, so that their sum with the embeddings keeps a norm below 2. Scaling the token embeddings back up seems to essentially cancel out the benefit of initializing them the way we typically do.
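The two options can be compared side by side in a short sketch, again assuming sinusoidal encodings and std d_model**-0.5 embeddings. The `combine` helper and its mode names are hypothetical. Scaling the positional encodings down by d_model**-0.5 gives them a norm of 1/sqrt(2), so the sum stays below 2, and the result is exactly the scaled-up version divided by sqrt(d_model).

```python
import math
import random


def sinusoidal_encoding(pos, d_model):
    """Sinusoidal positional encoding for a single position (d_model must be even)."""
    pe = [0.0] * d_model
    for i in range(0, d_model, 2):
        angle = pos / (10000 ** (i / d_model))
        pe[i], pe[i + 1] = math.sin(angle), math.cos(angle)
    return pe


def norm(v):
    return math.sqrt(sum(x * x for x in v))


def combine(token_vec, pe, mode):
    """Two hypothetical ways of balancing token embeddings against positional encodings."""
    d_model = len(token_vec)
    if mode == "scale_up_embeddings":      # multiply embeddings by sqrt(d_model)
        a, b = math.sqrt(d_model), 1.0
    elif mode == "scale_down_positions":   # multiply positional encodings by d_model**-0.5
        a, b = 1.0, d_model ** -0.5
    else:
        raise ValueError(f"unknown mode: {mode}")
    return [a * t + b * p for t, p in zip(token_vec, pe)]


rng = random.Random(0)
d_model = 512
token_vec = [rng.gauss(0.0, d_model ** -0.5) for _ in range(d_model)]
pe = sinusoidal_encoding(pos=3, d_model=d_model)

up = combine(token_vec, pe, "scale_up_embeddings")
down = combine(token_vec, pe, "scale_down_positions")
# |scaled PE| = sqrt(d_model / 2) / sqrt(d_model) = 1 / sqrt(2), so |down| <= |token_vec| + 0.71
print(f"|scale_up_embeddings|  = {norm(up):.2f}")
print(f"|scale_down_positions| = {norm(down):.2f}")
```

Both choices give the same relative balance between tokens and positions; the vectors point in the same direction and differ only by a global factor of sqrt(d_model), which is the "smaller signal" trade-off discussed in this thread.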

Do you have any pointers to work that studies the merits of the various options here? I could rerun a model or two to benchmark whether this change makes a significant difference, although this repository will soon be largely superseded by our more recent work, PLUR, which I can confirm has the same flaw (positional encodings are added with no apparent scaling here).

jonathan-laurent commented 2 years ago

I guess scaling down the positional encodings would work too. You would just get a smaller signal across the whole network.

Relatedly, I believe the reason many Transformer implementations initialize the embedding weights with a standard deviation of dmodel**-0.5 (instead of the default of 1 used by most DL frameworks) is that they share weights between the embedding layer and the final decoder layer. It is natural to initialize the last decoder layer with a standard deviation of dmodel**-0.5, because this corresponds to the standard Kaiming initialization (with fan-in dmodel).
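A sketch of why that choice is natural under weight tying, assuming a single matrix is used both as the embedding table and as the output projection. The helper names are hypothetical, and the final hidden state is assumed to have roughly unit variance per coordinate (for instance after a layer normalization). A std of 1/sqrt(fan_in) = d_model**-0.5 is what Kaiming initialization gives for a linear layer with gain 1; it keeps the logits at roughly unit variance, and as a side effect each embedding row has a norm of about 1.

```python
import math
import random


def init_tied_embedding(vocab_size, d_model, rng):
    """Shared embedding/output matrix with std d_model**-0.5 (1/sqrt(fan_in) for the output layer)."""
    std = d_model ** -0.5
    return [[rng.gauss(0.0, std) for _ in range(d_model)] for _ in range(vocab_size)]


def embed(E, token_id):
    """Input side: the embedding of a token is just a row of E."""
    return E[token_id]


def output_logits(E, h):
    """Output side: logits = E @ h, reusing the same matrix (weight tying)."""
    return [sum(w * x for w, x in zip(row, h)) for row in E]


def variance(xs):
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)


rng = random.Random(0)
vocab_size, d_model = 2000, 512
E = init_tied_embedding(vocab_size, d_model, rng)

# Assumption: unit variance per coordinate, so |h| is about sqrt(d_model).
h = [rng.gauss(0.0, 1.0) for _ in range(d_model)]
logits = output_logits(E, h)

print(f"embedding row norm: {math.sqrt(sum(x * x for x in embed(E, 3))):.2f}")  # about 1
print(f"logit variance:     {variance(logits):.2f}")                            # about 1
```

Under this reading, the small embedding norm is a by-product of initializing the output projection sensibly, which is why some implementations then multiply the embeddings by sqrt(d_model) on the input side.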

I have seen transformer implementations do every possible thing here (including scaling up the embeddings even when they are initialized with unit variance) and never found a good discussion of all possible alternatives. Please let me know if you manage to learn more on the topic.