ageron / handson-ml2

A series of Jupyter notebooks that walk you through the fundamentals of Machine Learning and Deep Learning in Python using Scikit-Learn, Keras and TensorFlow 2.
Apache License 2.0

[BUG] Possible Typo in chapter 16, Positional Embeddings #546

Closed kasrahabib closed 2 years ago

kasrahabib commented 2 years ago

Hi @ageron,

I was reading this part of the text from chapter 16, on positional embeddings. In the last sentence, shouldn't p = 25 be p = 22? The relation seems to be between positions 22 and 35.

“Moreover, the choice of oscillating functions (sine and cosine) makes it possible for the model to learn relative positions as well. For example, words located 38 words apart (e.g., at positions p = 22 and p = 60) always have the same positional embedding values in the embedding dimensions i = 100 and i = 101, as you can see in Figure 16-9. This explains why we need both the sine and the cosine for each frequency: if we only used the sine (the blue wave at i = 100), the model would not be able to distinguish positions p = 25 and p = 35 (marked by a cross)."

Fig 16-9: [screenshot of the figure from the book, showing the sine and cosine positional embedding curves]
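For reference, here is a quick NumPy check of the book's claim; max_dims = 512 and the `pos_embedding` helper are just assumptions for this check, not the book's code:

```python
import numpy as np

max_dims = 512  # assumed embedding size, just for this check

def pos_embedding(p, dim):
    # sine for even embedding dimensions, cosine for odd ones
    i = dim // 2
    angle = p / 10_000 ** (2 * i / max_dims)
    return np.sin(angle) if dim % 2 == 0 else np.cos(angle)

# Positions 38 words apart (p = 22 and p = 60) get almost the same values
# in embedding dimensions 100 and 101:
print(pos_embedding(22, 100), pos_embedding(60, 100))
print(pos_embedding(22, 101), pos_embedding(60, 101))

# The sine alone (dimension 100) cannot tell p = 22 from p = 35,
# but the cosine (dimension 101) can:
print(pos_embedding(22, 100), pos_embedding(35, 100))
print(pos_embedding(22, 101), pos_embedding(35, 101))
```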

Furthermore, I think I mixed up the relative-position part of the text (“it is possible for the model to learn relative positions as well”) with the “self-attention with relative position embeddings” paper. Here you are referring to the fact that dimensions 100 and 101 of the tokens are all shifted by the same amounts, namely dimensions 100 and 101 of the positional embedding. This moves those dimensions of the tokens closer together in space, and that somehow encodes the relation between the tokens; it is not the mechanism from the self-attention with relative position embeddings paper.

ageron commented 2 years ago

Hi @kasrahabib , Thanks for your feedback, that's much appreciated! 👍 You are right about the typo, it should have been p = 22 instead of p = 25, sorry about that. This was brought to my attention a while back so it's already fixed in the most recent releases. I'm not sure I understand your second point: is there an error in the text or is there something unclear?

kasrahabib commented 2 years ago

Hi @ageron,

Thanks for the quick reply. I did not know it had already been reported.

About the second point: no, there is nothing wrong with the text.

Could you please give some more details about this part of the text (“it is possible for the model to learn relative positions as well”)? How can having the same positional values in the same dimensions be interpreted as encoding relative positions?

Best, Kasra

ageron commented 2 years ago

Ah, got it. OK, so let's go back to the definition of the positional encoding P, for the word at position p and for the dimensions 2 * i and 2 * i + 1:

P(p, 2 * i) = sin(p / 10_000 ** (2 * i / max_dims))
P(p, 2 * i + 1) = cos(p / 10_000 ** (2 * i / max_dims))

Let's define f(i) = 1 / 10_000 ** (2 * i / max_dims). The equations simplify to:

P(p, 2 * i) = sin(p * f(i))
P(p, 2 * i + 1) = cos(p * f(i))
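Just to make this concrete, here is a minimal NumPy sketch of these two formulas (the max_dims value is only an assumption for illustration):

```python
import numpy as np

max_dims = 512  # assumed embedding size, just for illustration

def f(i):
    return 1 / 10_000 ** (2 * i / max_dims)

def P(p, dim):
    # dim = 2 * i (sine) or dim = 2 * i + 1 (cosine)
    i = dim // 2
    return np.sin(p * f(i)) if dim % 2 == 0 else np.cos(p * f(i))

print(P(22, 100), P(22, 101))  # positional encoding of position 22 in dims 100 and 101
```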

Now suppose two words are often n words apart in the training set. That's their most frequent relative position. For example, they may often be 7 words apart.

Let's consider the value of j such that f(j) is as close as possible to 2π / n.

A bit of algebra leads to j ≈ max_dims * log(n / 2π) / log(10_000) / 2. For example, if n is 7 and max_dims is 512, then j ≈ 3.
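You can check that numerically (again taking max_dims = 512, which is just an illustration value):

```python
import numpy as np

max_dims = 512  # assumed value, just for illustration
n = 7
j = max_dims * np.log(n / (2 * np.pi)) / np.log(10_000) / 2
print(j)  # ≈ 3.0, i.e. dimensions 2 * 3 = 6 and 2 * 3 + 1 = 7
```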

Now we can show that the positional encoding for both words are approximately equal in dimensions 2 * j and 2 * j + 1. Indeed:

P(p + n, 2 * j) = sin((p + n) * f(j)) = sin(p * f(j) + n * f(j)) ≈ sin(p * f(j) + n * 2π / n) = sin(p * f(j) + 2π) = sin(p * f(j)) = P(p, 2 * j)

and

P(p + n, 2 * j + 1) = cos((p + n) * f(j)) = cos(p * f(j) + n * f(j)) ≈ cos(p * f(j) + n * 2π / n) = cos(p * f(j) + 2π) = cos(p * f(j)) = P(p, 2 * j + 1)

So it's easy for the model to detect that two words are 7 words apart: they have roughly equal position encodings in dimensions 6 and 7 (that's 2 * 3 and 2 * 3 + 1).
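Here is a small, self-contained numerical check of that near-equality (max_dims = 512 is again just an assumption, and the positions p are arbitrary):

```python
import numpy as np

max_dims = 512  # assumed embedding size
n, j = 7, 3     # words 7 apart; dimensions 2 * j = 6 and 2 * j + 1 = 7

def f(i):
    return 1 / 10_000 ** (2 * i / max_dims)

def P(p, dim):
    i = dim // 2
    return np.sin(p * f(i)) if dim % 2 == 0 else np.cos(p * f(i))

for p in (5, 20, 50):
    print(P(p, 2 * j), P(p + n, 2 * j))          # sine components, roughly equal
    print(P(p, 2 * j + 1), P(p + n, 2 * j + 1))  # cosine components, roughly equal
```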

Perhaps the model detects relative positions differently, but at least this proof shows that in principle the information about relative positions is fairly easily accessible to the model. There are several caveats, like the fact that I assumed that there exists a value of j such that f(j) is a good enough approximation of 2π / n, and I also didn't explain how the model can distinguish words that are 7 words apart, versus 14 words apart, or 21 words apart, since they all have the same positional encoding in the 6th and 7th dimensions (but the other dimensions differ).

But hopefully this convinces you that the positional encodings allow the model to detect relative positions between words easily.

kasrahabib commented 2 years ago

Hi,

Thank you very much for the detailed explanation. The mathematical explanation is very convincing; this helped me a lot. I am still thinking about how the model might interpret these relations.

One way I can imagine the model detecting the relative positions 😬 is that it might treat dimension j as encoding a relation when f(j) is as close as possible to 2π / n:

When two words (e.g., the ones that are 7 words apart) have roughly equal positional encodings in dimensions 6 and 7 (that's 2 × 3 and 2 × 3 + 1), adding these positional embeddings to the words' embeddings moves the two words' 6th and 7th dimensions closer to each other, or in the same direction (I am imagining it visualized in a d-dimensional vector space). The model could interpret this as some sort of relation between them, since they end up closer to each other in dimensions 6 and 7.

Best 😊, Kasra

ageron commented 2 years ago

I'm glad this helped! I like your point about adding the positional embeddings; indeed, I think this makes sense, and it may well be what the models learn. But to be honest, although neural nets perform "simple" matrix multiplications and "simple" non-linearities, they do so in so many steps and in so many dimensions that it's really hard to know for sure what patterns they really detected in the data. This is a whole field of research: interpretability.

kasrahabib commented 2 years ago

Hi,

Yes, I know it is hard to interpret. And what I said is not necessarily how the model actually interprets it; it was just a way to convince myself. As everything is clear to me now, I will close this issue.