codertimo / BERT-pytorch

Google AI 2018 BERT pytorch implementation
Apache License 2.0

Attention maybe changed #46

Closed: soobinseo closed this issue 5 years ago

soobinseo commented 5 years ago

Hi, thanks for your great work. I think the attention mechanism in your code may have been changed: the attention matrix should have shape (batch, timestep, timestep), but according to your code, the self-attention tensor that is returned has shape (batch, timestep, hidden_size). Below is the code I wrote to fix this. Please review it; I would appreciate your comments. Thank you.

```python
import math

import torch as t
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm  # or the LayerNorm module defined in this repository


class Attention(nn.Module):
    def __init__(self, num_hidden, h=8):
        super(Attention, self).__init__()

        self.num_hidden_per_attn = num_hidden // h
        self.h = h

        self.key = nn.Linear(num_hidden, num_hidden)
        self.value = nn.Linear(num_hidden, num_hidden)
        self.query = nn.Linear(num_hidden, num_hidden)

        self.layer_norm_1 = LayerNorm(num_hidden)
        self.layer_norm_2 = LayerNorm(num_hidden)
        self.out_linear = nn.Linear(num_hidden, num_hidden)

        self.dropout = nn.Dropout(p=0.1)

    def forward(self, input_):
        batch_size = input_.size(0)

        key = F.relu(self.key(input_))
        value = F.relu(self.value(input_))
        query = F.relu(self.query(input_))

        # split the hidden dimension into h heads: (batch, timestep, h, num_hidden_per_attn)
        key, value, query = list(map(lambda x: x.view(batch_size, -1, self.h, self.num_hidden_per_attn), (key, value, query)))
        params = [(key[:, :, i, :], value[:, :, i, :], query[:, :, i, :]) for i in range(self.h)]

        # scaled dot-product attention per head; each attention matrix is (batch, timestep, timestep)
        _attn = list(map(self._multihead, params))
        attn = list(map(lambda x: x[0], _attn))
        probs = list(map(lambda x: x[1], _attn))
        result = t.cat(attn, -1)

        result = self.dropout(result)
        result = result.view(batch_size, -1, self.h * self.num_hidden_per_attn)

        # residual connection
        result = self.layer_norm_1(F.relu(input_ + result))

        out = self.out_linear(result)
        out = self.layer_norm_2(F.relu(result + out))

        # return the projected output together with the per-head attention matrices
        return out, probs

    def _multihead(self, params):
        key, value, query = params[0], params[1], params[2]

        # attention scores scaled by sqrt(d_k): (batch, timestep, timestep)
        attn = t.bmm(query, key.transpose(1, 2)) / math.sqrt(key.shape[-1])

        attn = F.softmax(attn, dim=-1)
        result = t.bmm(attn, value)

        return result, attn
```
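
For reference, a minimal usage sketch of the class above (the sizes here are made up purely for illustration and are not taken from the repository); it shows that each returned attention matrix has the expected (batch, timestep, timestep) shape:

```python
import torch as t

# hypothetical sizes chosen only for this example
batch, timestep, num_hidden = 2, 10, 512

attention = Attention(num_hidden, h=8)
x = t.randn(batch, timestep, num_hidden)

out, probs = attention(x)

print(out.shape)       # torch.Size([2, 10, 512])  -> (batch, timestep, hidden_size)
print(probs[0].shape)  # torch.Size([2, 10, 10])   -> (batch, timestep, timestep)
```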