facebookresearch / XLM

PyTorch original implementation of Cross-lingual Language Model Pretraining.

What is the best hyperparameter setting for PKM? #171

Open ParadoxZW opened 5 years ago

ParadoxZW commented 5 years ago

It seems that the default settings in your implementation are not the best ones. For example, mem_query_batchnorm is False by default, while according to your paper using batchnorm is better. Or did further experiments show that PKM without batchnorm can be better, so you changed the default value of mem_query_batchnorm? If not, I would like to know which other optimal settings differ from the defaults and lead to better performance.

glample commented 5 years ago

Hi,

Yes, batchnorm is best. This parameter is the only one among the default hyper-parameters that is not set to the best value. The reason is that batchnorm is not always possible: it requires batches with no padding, which is the case for BERT or language modeling training, but not for MT (since sentences are padded in MT). As a result, it defaults to False. If you set it to True you will have the good hyper-parameter configuration.

This assert checks that there is no issue: https://github.com/facebookresearch/XLM/blob/master/src/trainer.py#L290

To use batches that contain padding together with the memory and batchnorm, we would need to make a few modifications to the code to ensure that hidden states corresponding to padding are not included in the mean/variance computation of the batchnorm layer.
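
Roughly, the idea would be to compute the batch statistics only over non-padded positions, something along these lines (just a sketch with made-up names, not code from the repo):

import torch

def masked_mean_var(h, pad_mask):
    """
    h:        (bs, slen, dim) hidden states
    pad_mask: (bs, slen) boolean, True where the token is padding
    """
    keep = (~pad_mask).unsqueeze(-1).to(h.dtype)        # (bs, slen, 1), 1 for real tokens
    n = keep.sum()                                      # number of non-padded tokens
    mean = (h * keep).sum(dim=(0, 1)) / n               # per-feature mean over real tokens
    var = ((h - mean) ** 2 * keep).sum(dim=(0, 1)) / n  # per-feature variance over real tokens
    return mean, var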

ParadoxZW commented 5 years ago

This is my implementation of batchnorm with padding:

import torch
from torch import nn


class BatchNorm(nn.Module):
    """Batchnorm that ignores padded positions when computing batch statistics."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.dim = dim
        self.momentum = 0.1
        # affine parameters (gamma / beta)
        self.a_2 = nn.Parameter(torch.ones(dim))
        self.b_2 = nn.Parameter(torch.zeros(dim))
        # running statistics, updated during training
        self.register_buffer('running_mean', torch.zeros(dim))
        self.register_buffer('running_var', torch.ones(dim))

    def forward(self, x, x_mask):
        '''
          x:      (batchsize, n_tokens, feature_dim)
          x_mask: (batchsize, 1, 1, n_tokens), 1 at padded positions, 0 elsewhere
        '''
        # number of non-padded tokens in the batch
        self.sum_mask = (1 - x_mask).sum()

        # per-feature mean / variance computed over non-padded tokens only
        mean = self._mean(x, x_mask)
        dif = x - mean
        var = self._mean(dif * dif, x_mask)
        if self.training:
            self.running_mean = (1 - self.momentum) * self.running_mean \
                                + self.momentum * mean.detach()
            self.running_var = (1 - self.momentum) * self.running_var \
                                + self.momentum * var.detach()
        # normalize with the running statistics (also during training)
        std = torch.sqrt(self.running_var)

        return self.a_2 * (x - self.running_mean) / (std + self.eps) + self.b_2

    def _mean(self, x, x_mask):
        # zero out padded positions, then average over the real tokens only
        x_ = x.masked_fill(x_mask.squeeze(1).squeeze(1).unsqueeze(2).bool(), 0)
        summ = x_.view(-1, self.dim).sum(0)
        return summ / self.sum_mask
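
For reference, this is how I call it (toy shapes only; in the mask, 1 marks a padded position):

bn = BatchNorm(dim=512)
x = torch.randn(8, 20, 512)        # (batchsize, n_tokens, feature_dim)
x_mask = torch.zeros(8, 1, 1, 20)  # (batchsize, 1, 1, n_tokens)
x_mask[:, :, :, 15:] = 1           # last 5 positions of every sentence are padding
out = bn(x, x_mask)                # same shape as x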

But according to my experiment results, it seems that this implementation doesn't work. Is there something wrong?

ParadoxZW commented 5 years ago

I mean, there are no errors, the code runs. But the KL became larger than when directly using traditional batchnorm, so I think maybe I have misunderstood something.

glample commented 5 years ago

Did you compare this implementation with the regular batchnorm in a setting where there is no padding, like MLM or CLM training?

ParadoxZW commented 5 years ago

The comparison is between experiments that both involve tasks with padding (I also changed the scripts to only calculate the KL divergence for features that do not correspond to padding; a sketch of how I compute these metrics is below). And the KL does decrease. I made the mistake because I had only observed the KL and usage in early epochs, which are really bad (usage=0.6, KL=4), but they became better as training went on. Eventually, usage was about 0.99 but KL was about 1.5. This KL is better than when using traditional batchnorm for the same task, but still not ideal compared with the KL reported in your PKM paper. So if my batchnorm implementation is correct, maybe the problem is in my training schedule (then I will do more experiments to find a better one). Another question: did you observe the dynamics of KL and usage during training? How does KL change across epochs? Is it very large at early epochs and then gradually decreasing to a better value?
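
For the record, I compute usage and KL roughly like this, from the per-value access counts accumulated over an epoch (padded positions excluded). This is my own sketch, not necessarily exactly what the repo computes:

import torch

def memory_eval_metrics(access_counts):
    """
    access_counts: (mem_size,) tensor, how many times each memory value
    was selected over an epoch (padded positions excluded).
    """
    counts = access_counts.float()
    usage = (counts > 0).float().mean().item()   # fraction of values ever used
    p = counts / counts.sum()                    # empirical access distribution
    # KL(p || uniform) = log(N) + sum_i p_i * log(p_i)
    kl = (torch.log(torch.tensor(float(len(counts)))) +
          (p * torch.log(p + 1e-12)).sum()).item()
    return usage, kl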

glample commented 5 years ago

Yes, this is usually what happens. The initial KL is high, but it decreases slowly over time as the model learns to make better usage of the memory.

ParadoxZW commented 5 years ago

And does the usage decrease at the beginning, then increase?