Open ParadoxZW opened 5 years ago
Hi,

It seems that the default settings in your implementation are not the best ones. For example, `mem_query_batchnorm` is False by default. According to your paper, using batchnorm is better. Or did further experiments show that PKM without batchnorm could be better, so you changed the default value of `mem_query_batchnorm`? If not, I would like to know which other optimal settings differ from the defaults and lead to better performance.
Yes, batchnorm is best. This parameter is the only one in the default hyper-parameters which is not the best choice. The reason is that batchnorm is not always possible: it requires batches with no padding, which is only possible for BERT or language modeling training, but not for MT (since we have padded sentences in MT). As a result, it defaults to False. If you set it to True you will have the right hyper-parameter configuration.
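To give a rough idea of the problem (a toy illustration, not code from the repo): padded positions would otherwise be counted in the batch statistics and shift the mean/variance.

```python
import torch

# Two sentences padded to length 4 (zeros = padding).
x = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                  [5.0, 6.0, 0.0, 0.0]])
pad_mask = torch.tensor([[0, 0, 0, 0],
                         [0, 0, 1, 1]], dtype=torch.bool)

naive_mean = x.mean()            # 2.625, padding zeros included
true_mean = x[~pad_mask].mean()  # 3.5, real tokens only
```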
This assert checks that there is no such issue: https://github.com/facebookresearch/XLM/blob/master/src/trainer.py#L290
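For illustration only (this is not the code behind that link), the constraint amounts to every sentence in a batch having the same length, i.e. no padding. A hypothetical sketch, assuming a `lengths` tensor of per-sentence lengths:

```python
import torch

# Hypothetical check: the batch contains no padding iff all lengths are equal.
lengths = torch.tensor([12, 12, 12, 12])
assert lengths.min().item() == lengths.max().item(), \
    "memory batchnorm requires batches without padding"
```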
To use batches that contain padding together with the batchnorm memory, we would need a few modifications in the code to make sure that hidden states that correspond to padding are not included in the mean/variance computation of the batchnorm layer.
This is my implementation of batchnorm with padding:
```python
import torch
import torch.nn as nn


class BatchNorm(nn.Module):
    """Batchnorm over token features that ignores padded positions."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.dim = dim
        self.momentum = 0.1
        self.a_2 = nn.Parameter(torch.ones(dim))   # scale (gamma)
        self.b_2 = nn.Parameter(torch.zeros(dim))  # shift (beta)
        self.register_buffer('running_mean', torch.zeros(dim))
        self.register_buffer('running_var', torch.ones(dim))

    def forward(self, x, x_mask):
        '''
        x: (batchsize, n_tokens, feature_dim)
        x_mask: (batchsize, 1, 1, n_tokens), True at padded positions
        '''
        # number of non-padded tokens, reused by _mean below
        self.sum_mask = (~x_mask).sum()
        mean = self._mean(x, x_mask)
        dif = x - mean
        sq = torch.mul(dif, dif)
        var = self._mean(sq, x_mask)
        if self.training:
            # update running statistics from the masked batch statistics
            self.running_mean = (1 - self.momentum) * self.running_mean \
                + self.momentum * mean.detach()
            self.running_var = (1 - self.momentum) * self.running_var \
                + self.momentum * var.detach()
        std = torch.sqrt(self.running_var)
        return self.a_2 * (x - self.running_mean) / (std + self.eps) + self.b_2

    def _mean(self, x, x_mask):
        # zero out padded positions, then average over all non-padded tokens
        x_ = x.masked_fill(x_mask.squeeze(1).squeeze(1).unsqueeze(2), 0)
        summ = x_.view(-1, self.dim).sum(0)
        return summ / self.sum_mask
```
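For concreteness, a toy call with the shapes from the docstring (illustrative only):

```python
bn = BatchNorm(dim=8)
x = torch.randn(2, 5, 8)                            # (batchsize, n_tokens, feature_dim)
x_mask = torch.zeros(2, 1, 1, 5, dtype=torch.bool)
x_mask[1, :, :, 3:] = True                          # last two tokens of sentence 2 are padding
out = bn(x, x_mask)                                 # same shape as x
```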
But it seems that this implementation didn't work, according to my experiment results. Is there something wrong?
I mean, no errors, the code can run. But the KL became larger than when directly using traditional batchnorm, so I think maybe I have misunderstood something.
Did you make the comparison between this implementation and the regular batchnorm in a setting where there is no padding? Like MLM or CLM training?
The comparison is between experiments on tasks that both have padding (I also changed the scripts to only calculate the KL divergence for features that do not correspond to padding). And the KL does decrease. I was mistaken earlier because I had only looked at the KL and usage in the early epochs, which were really bad (usage=0.6, KL=4). But they improved as training went on. Eventually, usage was about 0.99 and KL was about 1.5. This KL is better than when using traditional batchnorm on the same task, but still not ideal compared with the KL reported in your PKM paper. So if my batchnorm implementation is correct, maybe the problem is my training schedule (in which case I will run more experiments to find a better one).

Another question: did you observe the dynamics of KL and usage during training? How does KL change across epochs? Is it very large in the early epochs and does it gradually decrease to a better value?
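Roughly, the masked statistics I compute look like this (a sketch with placeholder names, ignoring the key weights, and the exact KL convention in the paper may differ):

```python
import torch

def memory_stats(indices, pad_mask, n_keys):
    # indices:  (n_tokens, knn) memory slots selected for each token
    # pad_mask: (n_tokens,)     True at padded positions
    # n_keys:   total number of memory values
    selected = indices[~pad_mask].reshape(-1)
    counts = torch.bincount(selected, minlength=n_keys).float()
    access = counts / counts.sum()                  # empirical access distribution
    usage = (counts > 0).float().mean().item()      # fraction of values ever accessed
    nonzero = access[access > 0]
    kl = (nonzero * (nonzero * n_keys).log()).sum().item()  # KL(access || uniform)
    return usage, kl
```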
Yes, this is usually what happens. The initial KL is high, but it slowly decreases over time as the model learns to make better use of the memory.
And the usage decreases at the beginning, then increases?