allenai / allennlp

An open-source NLP research library, built on PyTorch.
http://www.allennlp.org
Apache License 2.0

Question about additional linear layer in BertPooler #3559

Closed cemilcengiz closed 4 years ago

cemilcengiz commented 4 years ago

Hi, while using AllenNLP's BertForClassification model (backed by BERT-base) on the MNLI dataset, I noticed that it gets slightly better accuracy on the development set than the results published in the official paper. While investigating possible reasons, I found that BertForClassification uses the BertPooler class to pool the BERT encoder output before the final layer, i.e. the classifier. Interestingly, BertPooler's forward() method contains a linear layer of its own, so we effectively pass the [CLS] token through a two-layer MLP (linear + tanh, then the classifier) instead of the single linear projection described in the paper. I wonder why BertPooler is implemented this way. If my understanding is correct, our results are not directly comparable with the official ones. Please correct me if I am wrong. Here is the BertPooler class:

from torch import nn

class BertPooler(nn.Module):
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
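
For concreteness, here is a minimal sketch of the effective computation on the [CLS] token under this setup (the values and variable names below are illustrative, not taken from AllenNLP's actual code):

import torch
from torch import nn

hidden_size, num_labels = 768, 3  # BERT-base hidden size, MNLI label count

# The pooler's dense + tanh, followed by the classification layer that
# BertForClassification adds on top.
pooler_dense = nn.Linear(hidden_size, hidden_size)
pooler_activation = nn.Tanh()
classifier = nn.Linear(hidden_size, num_labels)

# hidden_states: (batch, seq_len, hidden_size) output of the BERT encoder
hidden_states = torch.randn(2, 128, hidden_size)
cls_vector = hidden_states[:, 0]  # hidden state of the first ([CLS]) token
logits = classifier(pooler_activation(pooler_dense(cls_vector)))
# i.e. a two-layer MLP on [CLS] rather than a single linear projection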

Also, you can see the relevant part of the paper (Section 4.1, GLUE) in the attached screenshot.

sai-prasanna commented 4 years ago

Yeah, it was an oversight that we didn't mention it in the paper (we'll mention it in the updated version), but we do have an extra projection layer, for both the classifier and the LM, before the output is fed into the classification layer.

However, these layers are both pre-trained with the rest of the network and are included in the pre-trained checkpoint. So the part about "the only new parameters added during fine-tuning" is correct, it's just not correct to say "output of the Transformer", it's really "output of the Transformer fed through one additional non-linear transformation".

The tanh() was added early on to try to make the representation more interpretable, but it probably doesn't matter either way.

https://github.com/google-research/bert/issues/43
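
To make that distinction concrete, here is a minimal sketch of which parameters come from the checkpoint and which are new. It assumes the current Hugging Face transformers API, so module names may differ from the BERT implementation AllenNLP wrapped at the time:

from torch import nn
from transformers import BertModel

# The pooler (dense + tanh) is part of the pre-trained checkpoint, so its
# weights are restored together with the encoder when loading BERT.
bert = BertModel.from_pretrained("bert-base-uncased")
print(bert.pooler.dense)  # a 768 -> 768 nn.Linear loaded from the checkpoint

# Only this classification head is initialized from scratch for fine-tuning.
classifier = nn.Linear(bert.config.hidden_size, 3)  # 3 MNLI labels

Fine-tuning therefore only introduces the classifier's weight matrix and bias; the non-linear transformation it sits on was already learned during pre-training.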

schmmd commented 4 years ago

Closing due to @sai-prasanna's excellent answer.