lucidrains / vit-pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

Question about attention's qkv matrix #237

Open JearSYY opened 2 years ago

JearSYY commented 2 years ago

Hello, thanks for this great repo!

I am confused about a detail in vit.py. In the attention section, when computing the q, k, v matrices, you project x from (b, n, d) to (b, n, 3d) and then split it into 3 parts using chunk(), like:

qkv = self.to_qkv(x).chunk(3, dim = -1)

However, I think that in this way each of the q, k, and v matrices only contains part of the information of the original matrix x, which is not the exact meaning of the transformer paper. In the original paper, q, k, and v each contain all the information of the input matrix, and the dot products are then taken to compute the attention. Please check xD.

PS: I am a beginner in this topic, so if I have any misunderstanding, please point it out, and sorry for any possible inconvenience.

YGwhere commented 11 months ago

You may want to take a look at the linear layer in self.to_qkv(x). The output size of that linear layer is set to inner_dim * 3: the author merged the weights of q, k, and v into a single linear layer, so chunk(3, dim=-1) simply splits the output back into the three of them.
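
For example, here is a minimal sketch of what that looks like (the sizes dim = 64 and inner_dim = 96 below are made-up values for illustration, not the repo's defaults):

```python
# Minimal sketch of the fused qkv projection followed by chunk().
import torch
import torch.nn as nn

b, n, dim, inner_dim = 2, 5, 64, 96          # assumed sizes, just for the demo

x = torch.randn(b, n, dim)                   # input of shape (b, n, d)
to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

qkv = to_qkv(x)                              # shape (b, n, 3 * inner_dim)
q, k, v = qkv.chunk(3, dim=-1)               # three tensors, each (b, n, inner_dim)

print(q.shape, k.shape, v.shape)             # torch.Size([2, 5, 96]) for each
```

Each of q, k, and v is computed from the full input x, just through a different slice of the fused weight matrix, so no part of x is discarded.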

PS: I am also a beginner, so I don’t know if this is the right way to understand it. 🤓

chengengliu commented 9 months ago

It is just a shortcut to avoid repeating code like self.to_q = nn.Linear(dim, inner_dim, bias = False) three times (q, k, and v each need their own linear mapping). Of course there are other implementations written in that style, but the idea of building Q, K, V is the same.
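
Here is a quick sketch of that equivalence (again with assumed sizes, not the repo's defaults): the fused nn.Linear(dim, inner_dim * 3) behaves exactly like three separate nn.Linear(dim, inner_dim) layers whose weights are the three row-blocks of the fused weight matrix.

```python
# Sketch: a fused qkv projection vs. three separate q/k/v projections.
import torch
import torch.nn as nn

dim, inner_dim = 64, 96                      # assumed sizes for the demo
x = torch.randn(2, 5, dim)

to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
q, k, v = to_qkv(x).chunk(3, dim=-1)

# Build three separate projections from row-slices of the fused weight.
to_q = nn.Linear(dim, inner_dim, bias=False)
to_k = nn.Linear(dim, inner_dim, bias=False)
to_v = nn.Linear(dim, inner_dim, bias=False)
with torch.no_grad():
    to_q.weight.copy_(to_qkv.weight[:inner_dim])
    to_k.weight.copy_(to_qkv.weight[inner_dim:2 * inner_dim])
    to_v.weight.copy_(to_qkv.weight[2 * inner_dim:])

# The outputs match, so the fused layer loses no information.
assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)
```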