codertimo / BERT-pytorch

Google AI 2018 BERT pytorch implementation

self.d_k = d_model // h gives 64 dimensions? #60

Open BerenLuthien opened 5 years ago

BerenLuthien commented 5 years ago

https://github.com/codertimo/BERT-pytorch/blob/d10dc4f9d5a6f2ca74380f62039526eb7277c671/bert_pytorch/model/attention/multi_head.py#L15

It looks like self.d_k = d_model // h divides the embedding size 768 by the number of heads 12, giving 64:

        self.d_k = d_model // h # 64
        self.h = h # 12 heads
        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linear_layers, (query, key, value))]

Why convert the 768-dimensional q, k, v into 64-dimensional embeddings?
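
For context, a minimal shape check of what the view/transpose in the snippet above produces (a sketch, assuming d_model=768, h=12 and a single nn.Linear(d_model, d_model) projection as in the code). The linear layer keeps all 768 dimensions; the reshape only splits them across 12 heads of 64 each:

import torch
import torch.nn as nn

batch_size, seq_len, d_model, h = 2, 10, 768, 12
d_k = d_model // h  # 64

x = torch.randn(batch_size, seq_len, d_model)
linear = nn.Linear(d_model, d_model)  # projects 768 -> 768, not 768 -> 64

projected = linear(x)                                              # (2, 10, 768)
per_head = projected.view(batch_size, -1, h, d_k).transpose(1, 2)  # (2, 12, 10, 64)
print(projected.shape, per_head.shape)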

Reference: http://nlp.seas.harvard.edu/2018/04/03/attention.html. I added some comments on the shapes:

class MultiHeadedAttention(nn.Module): # d_model=512, h=8
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h # 512//8=64
        self.h = h # 8
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k.
        #    Each linear maps d_model -> d_model; the view/transpose then splits
        #    the result into h heads: (nbatches, h, seq_len, d_k) = (nbatches, 8, seq_len, 64)
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch. 
        x, self.attn = attention(query, key, value, mask=mask, 
                                 dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear. 
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k) # (nbatches, -1, 512)
        return self.linears[-1](x)
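
For completeness: clones(module, N) in the tutorial just deep-copies a module N times into an nn.ModuleList, and the attention helper called in step 2 is standard scaled dot-product attention. A minimal sketch of the latter (assumed to match the tutorial's version), operating on the (nbatches, h, seq_len, d_k) tensors produced above:

import math
import torch
import torch.nn.functional as F

def attention(query, key, value, mask=None, dropout=None):
    "Scaled dot-product attention over (nbatches, h, seq_len, d_k) tensors."
    d_k = query.size(-1)
    # scores: (nbatches, h, seq_len, seq_len)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    # weighted sum of values: (nbatches, h, seq_len, d_k), plus the attention weights
    return torch.matmul(p_attn, value), p_attn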
Vesauza commented 4 years ago

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O

"In this work we employ h = 8 parallel attention layers, or heads. For each of these we use d_k = d_v = d_model / h = 64. Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality."

---- from Attention Is All You Need
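
A quick parameter count illustrates that last sentence: projecting into h heads of d_k = d_model / h costs the same as one full-width projection, since h * d_model * d_k = d_model * d_model. A small sketch with the BERT-base numbers from this issue (d_model=768, h=12):

d_model, h = 768, 12
d_k = d_model // h  # 64

multi_head = h * d_model * d_k    # 12 heads, each projecting 768 -> 64
single_head = d_model * d_model   # one head projecting 768 -> 768
print(multi_head, single_head, multi_head == single_head)  # 589824 589824 True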