lukemelas / PyTorch-Pretrained-ViT

Vision Transformer (ViT) in PyTorch

Multi head uses just one set of Q, K, V? #18

Open · daviduarte opened this issue 3 years ago

daviduarte commented 3 years ago

In transformer.py, in the class MultiHeadedSelfAttention(), we have these declarations:

  self.proj_q = nn.Linear(dim, dim)
  self.proj_k = nn.Linear(dim, dim)
  self.proj_v = nn.Linear(dim, dim)

but weren't Q, K, and V supposed to be independent trainable matrices per head? E.g., if num_heads = 12, wasn't it supposed to be something like:

heads = nn.ModuleList()  # renamed from `set`, which shadows a Python built-in; ModuleList registers the parameters
for i in range(12):
    # one independent Q, K, V projection per head
    heads.append(nn.ModuleList([nn.Linear(dim, dim), nn.Linear(dim, dim), nn.Linear(dim, dim)]))
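
For context, a single nn.Linear(dim, dim) can still give every head its own independent weights: multi-head implementations typically reshape the projected output into num_heads slices of size dim // num_heads, and each slice depends only on its own block of rows of the weight matrix. A minimal sketch of that convention (shapes and names here are illustrative, not copied from transformer.py):

import torch
import torch.nn as nn

dim, num_heads = 768, 12
head_dim = dim // num_heads  # 64

proj_q = nn.Linear(dim, dim)   # one (dim x dim) matrix shared by all heads
x = torch.randn(2, 197, dim)   # (batch, tokens, dim); 197 tokens is just an example

q = proj_q(x)                                             # (2, 197, 768)
q = q.view(2, 197, num_heads, head_dim).transpose(1, 2)   # (2, 12, 197, 64)
# Head i's slice q[:, i] depends only on rows i*head_dim:(i+1)*head_dim of
# proj_q.weight, so each head effectively trains its own sub-matrix.

Under that convention, the single Linear is equivalent to 12 independent dim -> head_dim projections concatenated, rather than 12 full dim -> dim matrices as in the loop above.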

Regards!