larspars closed this issue 7 years ago.
Yes, it's quite easy to introduce a bug in the masking process, and it would lead to misleadingly high results. Some possibilities (from my experience writing a Transformer implementation) include applying softmax over the wrong dimension and using the wrong kind of triangular matrix. Here's an implementation of optionally masked attention that I know to be correct and that works for both training and inference (but depends on broadcasting/PyTorch master):
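```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Softmax over the last (key) dimension. Older F.softmax calls without `dim`
# normalized over dim 0 for 3D inputs, which is why this wrapper used to transpose.
def softmax(x):
    return F.softmax(x, dim=-1)

# torch.matmul can't do (4, 3, 2) @ (4, 2) -> (4, 3), so handle mismatched ranks here
def matmul(x, y):
    if x.dim() == y.dim():
        return x @ y
    if x.dim() == y.dim() - 1:
        # (B, d) @ (B, d, n) -> (B, n), e.g. a single decoding step
        return (x.unsqueeze(-2) @ y).squeeze(-2)
    # (B, m, d) @ (B, d) -> (B, m)
    return (x @ y.unsqueeze(-1)).squeeze(-1)

class Attention(nn.Module):
    def __init__(self, d_key, drop_ratio, causal):
        super().__init__()
        self.scale = math.sqrt(d_key)
        self.dropout = nn.Dropout(drop_ratio)
        self.causal = causal

    def forward(self, query, key, value):
        # query: (B, T, d_key) during training, or (B, d_key) for one inference step
        dot_products = matmul(query, key.transpose(1, 2))
        if query.dim() == 3 and self.causal:
            # True strictly above the diagonal: query i must not attend to key j > i.
            # Assumes self-attention, i.e. query and key have the same length T.
            T = key.size(1)
            future = torch.ones(T, T, dtype=torch.bool, device=key.device).triu(1)
            dot_products = dot_products.masked_fill(future.unsqueeze(0), float("-inf"))
        return matmul(self.dropout(softmax(dot_products / self.scale)), value)
```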
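Both failure modes mentioned above can be checked directly on random scores. This is a minimal sketch, assuming single-head scores of shape (batch, T, T) with rows as queries and columns as keys; `causal_attention_weights` is a hypothetical helper, not part of the repository. With a `triu(1)` mask and softmax over the key dimension, each row sums to 1, nothing above the diagonal gets weight, and changing a later query's scores leaves earlier rows untouched. Normalizing over the query dimension instead makes the first position's weights depend on a later position.

```python
import torch
import torch.nn.functional as F

def causal_attention_weights(scores, softmax_dim=-1):
    """Hypothetical helper: mask future keys, then normalize.

    scores: (batch, T, T), rows = queries, columns = keys.
    """
    T = scores.size(-1)
    future = torch.ones(T, T, dtype=torch.bool).triu(1)  # True where key j > query i
    masked = scores.masked_fill(future, float("-inf"))
    return F.softmax(masked, dim=softmax_dim)

torch.manual_seed(0)
scores = torch.randn(2, 5, 5)
perturbed = scores.clone()
perturbed[:, -1, :] += 10.0  # change only the LAST query's scores

# Correct: softmax over keys (last dim)
good = causal_attention_weights(scores)
good2 = causal_attention_weights(perturbed)
print(torch.allclose(good.sum(-1), torch.ones(2, 5)))  # each row sums to 1
print(bool((good.triu(1) == 0).all()))                # no weight on future keys
print(torch.allclose(good[:, 0], good2[:, 0]))        # first query unaffected

# Bug: softmax over queries (dim=1) lets a later position change earlier weights
bad = causal_attention_weights(scores, softmax_dim=1)
bad2 = causal_attention_weights(perturbed, softmax_dim=1)
print(torch.allclose(bad[:, 0], bad2[:, 0]))          # first query's weights changed
```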
I think the reason you're getting such high accuracy is that the evaluation function uses teacher forcing.
https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/train.py#L92
Teacher forcing is when you supply the ground-truth output for timestep t-1 as the input at timestep t. That shouldn't have this effect on accuracy.
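For concreteness, here is a minimal pure-Python sketch of teacher-forced evaluation: the decoder input is the reference shifted right by one position, and each prediction is scored against the reference token at the same position. The `BOS` id, the function names and the toy `predict_next` rule are all hypothetical stand-ins for a real model.

```python
BOS = 0  # hypothetical beginning-of-sentence id

def shift_right(target):
    """Decoder input under teacher forcing: [BOS, y_1, ..., y_{n-1}]."""
    return [BOS] + target[:-1]

def teacher_forced_accuracy(predict_next, target):
    """Fraction of positions t whose prediction, given the GOLD prefix, equals y_t."""
    dec_input = shift_right(target)
    correct = 0
    for t in range(len(target)):
        # the model always sees the ground-truth prefix, never its own outputs
        pred = predict_next(dec_input[: t + 1])
        correct += int(pred == target[t])
    return correct / len(target)

def predict_next(prefix):
    """Toy stand-in for a model (hypothetical): predicts previous token + 1."""
    return prefix[-1] + 1

print(shift_right([5, 6, 7, 8]))                            # [0, 5, 6, 7]
print(teacher_forced_accuracy(predict_next, [5, 6, 7, 8]))  # 0.75
```

The toy rule misses only the first token; every later prediction is rescued by the gold previous token.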
I know that; I'm saying that the accuracy metric is calculated using teacher forcing. That is, the ground truth is supplied on the evaluation set but not during translation, which could be why the metrics are so good but the results are so bad.
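During translation there is no reference prefix: the decoder consumes its own previous predictions, so an early mistake is never corrected. A minimal sketch of free-running greedy decoding, assuming a hypothetical toy predictor and `BOS` id; for simplicity it decodes a fixed number of steps instead of stopping at an end-of-sentence token.

```python
BOS = 0  # hypothetical beginning-of-sentence id

def predict_next(prefix):
    """Toy stand-in for a model (hypothetical): predicts previous token + 1."""
    return prefix[-1] + 1

def greedy_decode(predict_next, length):
    """Free-running decoding: every step is conditioned on the model's OWN outputs."""
    output = [BOS]
    for _ in range(length):
        output.append(predict_next(output))
    return output[1:]  # drop BOS

def free_running_accuracy(predict_next, target):
    """Per-position accuracy of the greedy hypothesis against the reference."""
    hyp = greedy_decode(predict_next, len(target))
    return sum(h == y for h, y in zip(hyp, target)) / len(target)

print(greedy_decode(predict_next, 4))                     # [1, 2, 3, 4]
print(free_running_accuracy(predict_next, [5, 6, 7, 8]))  # 0.0
```

Because the first token is wrong and is fed back in, every later position is wrong too; a teacher-forced metric would hide this.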
98% accuracy on a dev set is still way too high for teacher forcing.
Gotcha @kendricktan, I see what you mean then. But consider a skilled human translator: they won't get 98% per-word accuracy, even if they are told the previous translated word in the sentence. There are too many ways to express the same thing.
This paper reports a 4.78 perplexity on the multi30k dataset with an RNN seq2seq model: https://www.degruyter.com/downloadpdf/j/pralin.2017.108.issue-1/pralin-2017-0020/pralin-2017-0020.pdf
My crippled 1-layer model quickly gets a perplexity of 1.1. That's too good to be true :)
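To put those two numbers in perspective, assuming both are per-token perplexities: perplexity is the exponential of the mean per-token cross-entropy, so its reciprocal is the geometric-mean probability the model assigns to the reference token. A perplexity of 1.1 means roughly 0.91 on every token, versus roughly 0.21 for 4.78. The function names below are illustrative only.

```python
import math

def perplexity(token_probs):
    """exp of the mean negative log-likelihood (nats) of the reference tokens."""
    mean_nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(mean_nll)

def mean_correct_prob(ppl):
    """Geometric-mean probability of the reference token implied by a perplexity.

    ppl = exp(mean NLL), so exp(-mean NLL) = 1 / ppl.
    """
    return 1.0 / ppl

print(perplexity([0.5] * 10))  # about 2.0: always 50% sure of the right token
for ppl in (4.78, 1.1):
    print(f"ppl={ppl}: mean NLL = {math.log(ppl):.3f} nats, "
          f"geometric-mean p(correct) = {mean_correct_prob(ppl):.3f}")
```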
@larspars fair point. :+1:
@jekbradbury what version of pytorch have you tested that Attention module on?
It was written against PyTorch master as of three weeks ago, so it should work on 0.2.0 (to be released this week).
Hi all, sorry for the absence so far. I have been busy recently.
Thanks for all the discussion. I think there are indeed some bugs in my code; the accuracy is way too high. As you point out, it is very likely that the current mask somehow leaks future information, but I am not sure why and am trying to find out. Any further insights, observations, or discussion are very welcome. Thanks!
Yu-Hsiang
I've opened a pull request: I think your LayerNormalization implementation is wrong. It normalizes over the wrong dimension (time), which leaks future information: https://github.com/jadore801120/attention-is-all-you-need-pytorch/pull/17
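Why normalizing over time leaks: the mean and standard deviation used at position t are then computed from all positions, including later ones, so the normalized value at t changes when a future token changes. A minimal sketch with plain tensor ops, assuming inputs of shape (batch, time, d_model); the two functions are hypothetical illustrations, not the repository's `LayerNormalization` module. (In current PyTorch, `torch.nn.LayerNorm(d_model)` normalizes over the last dimension.)

```python
import torch

def layer_norm_features(x, eps=1e-6):
    """Correct: statistics over the feature dimension, one set per position."""
    mu = x.mean(dim=-1, keepdim=True)
    sigma = x.std(dim=-1, keepdim=True)
    return (x - mu) / (sigma + eps)

def layer_norm_time(x, eps=1e-6):
    """Buggy: statistics over the time dimension (dim=1), mixing all positions."""
    mu = x.mean(dim=1, keepdim=True)
    sigma = x.std(dim=1, keepdim=True)
    return (x - mu) / (sigma + eps)

torch.manual_seed(0)
x = torch.randn(2, 6, 8)   # (batch, time, d_model)
x_future = x.clone()
x_future[:, -1] += 5.0     # modify only the last time step

for norm in (layer_norm_features, layer_norm_time):
    same_prefix = torch.allclose(norm(x)[:, :-1], norm(x_future)[:, :-1])
    print(norm.__name__, "earlier positions unchanged:", same_prefix)
# layer_norm_features earlier positions unchanged: True
# layer_norm_time earlier positions unchanged: False
```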
@ZiJianZhao I think that might be it. I'm getting much more realistic accuracy and perplexity after your patch.
@ZiJianZhao, I think you are right. I have just merged it! Thanks for pointing it out!
Yu-Hsiang
I get 98% accuracy after 10 epochs on the multi30k validation set using this 1-layer model:
This is a very small model (note -d_inner_hid 1), which should not get good results at all (98% accuracy is way too high in any case). Generating translations with translate.py produces nonsense. This makes me suspect that there is a problem with the masking code that allows the model to 'cheat' by looking at the target sequence.
I haven't been able to figure out where the problem is, but something seems wrong.
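One way to narrow down this kind of bug is a causality test: run the decoder on two target sequences that differ only from position k onward and check that the outputs at positions before k are identical. The sketch assumes a callable `decoder(tgt)` mapping a (batch, T) LongTensor of target ids to (batch, T, vocab) logits; `check_no_future_leak` and both toy decoders are hypothetical stand-ins for the real model, which would also take the encoder output.

```python
import torch

def check_no_future_leak(decoder, tgt, k, vocab_size, atol=1e-6):
    """True if outputs before position k do not depend on tokens at positions >= k."""
    with torch.no_grad():
        out_a = decoder(tgt)
        tgt_b = tgt.clone()
        tgt_b[:, k:] = (tgt_b[:, k:] + 1) % vocab_size  # guaranteed-different future tokens
        out_b = decoder(tgt_b)
    return torch.allclose(out_a[:, :k], out_b[:, :k], atol=atol)

torch.manual_seed(0)
vocab_size, d = 11, 4
emb = torch.nn.Embedding(vocab_size, d)
proj = torch.nn.Linear(d, vocab_size)

def causal_decoder(tgt):
    # position t sees the running mean of embeddings 0..t only
    e = emb(tgt)                                                # (B, T, d)
    counts = torch.arange(1, tgt.size(1) + 1).view(1, -1, 1)
    return proj(e.cumsum(dim=1) / counts)

def leaky_decoder(tgt):
    # every position sees the mean over the WHOLE sequence, including the future
    e = emb(tgt)
    return proj(e.mean(dim=1, keepdim=True).expand_as(e))

tgt = torch.randint(0, vocab_size, (2, 7))
print(check_no_future_leak(causal_decoder, tgt, k=3, vocab_size=vocab_size))  # True
print(check_no_future_leak(leaky_decoder, tgt, k=3, vocab_size=vocab_size))   # False
```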