jonathan-laurent opened this issue 2 years ago
Thanks for pointing this out. You are right: with larger hidden dimensions, the positional-encoding norm might "drown out" the token embeddings in the current formulation. I suppose the implementation you reference would address this, though it is surprising to me that we would not rather normalize the positional encodings by a factor of d_model**-0.5 so as to preserve a norm of less than 2. Scaling the token embeddings back up seems to essentially annul the benefit of initializing them the way we typically do.
Do you have any pointers to work that studies the merits of the various options here? I suppose I could rerun a model or two to benchmark whether this change makes a significant difference, although this repository will soon be largely superseded by our more recent work, PLUR -- which, I can confirm, has the same flaw (positional encodings are added with no apparent scaling here).
I guess scaling down the positional encodings would work too. You would just get a smaller signal across the whole network.
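For concreteness, here is a minimal sketch of the two options discussed above: scaling the token embeddings up by d_model**0.5 (as many implementations do) versus scaling the positional encodings down by d_model**-0.5. This is plain PyTorch, not code from this repository; the `sinusoidal_encoding` helper and the constants (d_model=512, etc.) are just illustrative choices of mine.

```python
import torch

def sinusoidal_encoding(seq_len, d_model):
    # Standard sinusoidal positional encoding (Vaswani et al., 2017).
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
    i = torch.arange(d_model // 2, dtype=torch.float32).unsqueeze(0)
    angles = pos / torch.pow(10000.0, 2 * i / d_model)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(angles)
    pe[:, 1::2] = torch.cos(angles)
    return pe

d_model, seq_len, vocab = 512, 128, 1000
torch.manual_seed(0)

# Token embeddings initialized with std d_model**-0.5 -> per-vector norm ~ 1.
emb_table = torch.randn(vocab, d_model) * d_model ** -0.5
tok = emb_table[torch.randint(0, vocab, (seq_len,))]
pe = sinusoidal_encoding(seq_len, d_model)  # per-position norm ~ sqrt(d_model / 2)

# Option A (seen in many implementations): scale the embeddings up.
x_a = tok * d_model ** 0.5 + pe
# Option B (suggested above): scale the positional encodings down instead,
# which keeps the norm of the sum below ~2.
x_b = tok + pe * d_model ** -0.5

for name, x in [("scale embeddings up", x_a), ("scale encodings down", x_b)]:
    print(f"{name}: mean norm = {x.norm(dim=-1).mean().item():.2f}")
```

Either way the relative magnitude of the two signals is restored; the difference is only the overall scale flowing into the first layer.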
Relatedly, I believe the reason many transformer implementations initialize the embedding weights with a standard deviation of d_model**-0.5 (instead of the default of 1 used by most DL frameworks) is that they share weights between the embedding layer and the final decoder layer. It is natural to initialize that last decoder layer with a standard deviation of d_model**-0.5 because this corresponds to standard Kaiming-style fan-in initialization.
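To make the weight-sharing point concrete, here is a small sketch (PyTorch, not code from this repository; the `TiedEmbedding` class and its method names are hypothetical) of an embedding table that is reused as the output projection and initialized with std d_model**-0.5:

```python
import torch
import torch.nn as nn

class TiedEmbedding(nn.Module):
    """Token embedding whose weight matrix is reused as the output projection."""

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        # std d_model**-0.5, i.e. fan-in scaling for the tied output layer
        # (instead of the framework default of std 1 for nn.Embedding).
        nn.init.normal_(self.embed.weight, mean=0.0, std=d_model ** -0.5)

    def encode(self, token_ids):
        # Embedding lookup; each row has norm roughly 1 at initialization.
        return self.embed(token_ids)

    def decode(self, hidden):
        # Reuse the same weights as the final decoder (logits) layer.
        return hidden @ self.embed.weight.t()

tied = TiedEmbedding(vocab_size=1000, d_model=512)
ids = torch.randint(0, 1000, (4, 16))
hidden = torch.randn(4, 16, 512)
print(tied.encode(ids).shape, tied.decode(hidden).shape)  # (4, 16, 512) (4, 16, 1000)
```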
I have seen transformer implementations do every possible thing here (including scaling up the embeddings even when they are initialized with unit variance) and never found a good discussion of all possible alternatives. Please let me know if you manage to learn more on the topic.
In this implementation, the embedding weights are initialized with a standard deviation of d_model**-0.5 (see code). This is consistent with many existing Transformer implementations, and it means that the embeddings will have a norm of about 1. However, if I am correct, the positional encodings have a greater norm, on the order of d_model**0.5.
Thus, it is surprising to me that you are adding the embeddings and the positional encodings without rescaling the former. Indeed, we can see many implementations multiplying the embeddings by d_model**0.5 before adding the positional encodings.
Would you have any comment on this, by any chance?
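For reference, both norms are easy to check numerically. Here is a quick sketch (PyTorch, with an illustrative d_model of 512; none of this is code from the repository):

```python
import torch

d_model, seq_len, vocab = 512, 128, 1000
torch.manual_seed(0)

# Embedding rows drawn with std d_model**-0.5: the expected squared norm is
# d_model * (1 / d_model) = 1, so each row has norm about 1.
emb = torch.randn(vocab, d_model) * d_model ** -0.5
print("embedding norm ~", emb.norm(dim=-1).mean().item())  # ~1.0

# Sinusoidal positional encodings: each of the d_model/2 (sin, cos) pairs
# contributes sin^2 + cos^2 = 1, so each position has norm sqrt(d_model / 2).
pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
i = torch.arange(d_model // 2, dtype=torch.float32).unsqueeze(0)
angles = pos / torch.pow(10000.0, 2 * i / d_model)
pe = torch.zeros(seq_len, d_model)
pe[:, 0::2], pe[:, 1::2] = torch.sin(angles), torch.cos(angles)
print("positional encoding norm ~", pe.norm(dim=-1).mean().item())  # ~16 = sqrt(256)
```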