jadore801120 / attention-is-all-you-need-pytorch

A PyTorch implementation of the Transformer model in "Attention is All You Need".
MIT License

Masking bug? #19

Closed larspars closed 7 years ago

larspars commented 7 years ago

I get 98% accuracy after 10 epochs on the multi30k validation set using this 1-layer model:

python train.py -data data/multi30k.atok.low.pt -save_model trained -save_mode best -proj_share_weight -dropout 0.0 -n_layers 1 -n_warmup_steps 40 -epoch 50 -d_inner_hid 1 -d_model 128 -d_word_vec 128 -n_head 4

This is a very small model (note -d_inner_hid 1), which should not get good results at all (98% accuracy is far too high in any case). Generating translations with translate.py produces nonsense. This makes me suspect there is a problem with the masking code that allows the model to 'cheat' by looking at the target sequence.

I haven't been able to figure out where the problem is, but something seems wrong.

jekbradbury commented 7 years ago

Yes, it's quite easy to introduce a bug in the masking process, and any such bug would lead to misleadingly high results. Some possibilities (from my experience writing a Transformer implementation) include taking the softmax over the wrong dimension and using the wrong kind of triangular matrix. Here's an implementation of optionally masked attention that I know to be correct and that works for both training and inference (but it depends on broadcasting, i.e. PyTorch master):

import math

import torch.nn as nn
import torch.nn.functional as F

# INF is not defined in the snippet; any sufficiently large constant works.
INF = 1e9

# F.softmax has strange default behavior, normalizing over dim 0 for 3D inputs
def softmax(x):
    if x.dim() == 3:
        return F.softmax(x.transpose(0, 2)).transpose(0, 2)
    return F.softmax(x)

# torch.matmul can't do (4, 3, 2) @ (4, 2) -> (4, 3)
def matmul(x, y):
    if x.dim() == y.dim():
        return x @ y
    if x.dim() == y.dim() - 1:
        return (x.unsqueeze(-2) @ y).squeeze(-2)
    return (x @ y.unsqueeze(-2)).squeeze(-2)

class Attention(nn.Module):

    def __init__(self, d_key, drop_ratio, causal):
        super().__init__()
        self.scale = math.sqrt(d_key)
        self.dropout = nn.Dropout(drop_ratio)
        self.causal = causal

    def forward(self, query, key, value):
        dot_products = matmul(query, key.transpose(1, 2))
        if query.dim() == 3 and self.causal:
            tri = key.data.new(key.size(1), key.size(1)).fill_(1).triu(1) * INF
            dot_products.data.sub_(tri.unsqueeze(0))
        return matmul(self.dropout(softmax(dot_products / self.scale)), value)
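A minimal usage sketch of the module above (the shapes and values here are illustrative assumptions, with self-attention inputs shaped (batch, seq_len, d_key)):

# Hypothetical usage sketch: causal self-attention with the Attention module above.
import torch

attn = Attention(d_key=64, drop_ratio=0.1, causal=True)
x = torch.randn(8, 10, 64)   # (batch, seq_len, d_key)
out = attn(x, x, x)          # query = key = value for self-attention
assert out.shape == (8, 10, 64)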

kendricktan commented 7 years ago

I think the reason you're getting such high accuracy is that the evaluation function uses teacher forcing.

https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/train.py#L92

larspars commented 7 years ago

Teacher forcing is when you supply the ground-truth output for timestep t-1 as input at timestep t. That shouldn't have this effect on accuracy.
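For concreteness, a minimal sketch of a teacher-forced training step (the names model, src, and gold are illustrative, not taken from this repository):

import torch.nn.functional as F

def teacher_forced_loss(model, src, gold):
    # Teacher forcing: at every step the decoder is fed the true previous
    # token from `gold` rather than its own prediction.
    decoder_input = gold[:, :-1]        # ground-truth tokens, shifted right
    target = gold[:, 1:]                # next-token targets
    logits = model(src, decoder_input)  # (batch, tgt_len - 1, vocab)
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           target.reshape(-1))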

kendricktan commented 7 years ago

I know that; I'm saying that the accuracy metric is calculated using teacher forcing. That is, the ground truth is supplied on the evaluation set but not during translation, which could be one reason why the metrics look so good while the translations are so bad.
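As an illustration, a sketch of how such a per-token accuracy could be computed under teacher forcing (the function name and pad_idx are hypothetical, not the repository's actual code):

import torch

def n_correct_tokens(logits, target, pad_idx=0):
    # Compare the argmax prediction at each step against the gold next token,
    # ignoring padding positions.
    pred = logits.argmax(dim=-1)     # (batch, tgt_len)
    non_pad = target.ne(pad_idx)
    n_correct = (pred.eq(target) & non_pad).sum().item()
    return n_correct, non_pad.sum().item()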

jekbradbury commented 7 years ago

98% accuracy on a dev set is still way too high for teacher forcing.

larspars commented 7 years ago

Gotcha @kendricktan, I see what you mean then. But consider a skilled human translator: they won't get 98% accuracy per word, even if they are told the previous translated word in the sentence. There are too many ways to express the same thing.

This paper reports a 4.78 perplexity on the multi30k dataset with a RNN seq2seq model: https://www.degruyter.com/downloadpdf/j/pralin.2017.108.issue-1/pralin-2017-0020/pralin-2017-0020.pdf

My crippled 1-layer model quickly gets a perplexity of 1.1. That's too good to be true :)
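(Perplexity is the exponential of the average per-token cross-entropy, so a quick back-of-the-envelope check shows why 1.1 is implausible:)

import math

ppl = 1.1
avg_cross_entropy = math.log(ppl)              # ~0.095 nats per token
avg_token_prob = math.exp(-avg_cross_entropy)  # ~0.91
# On average the model would be assigning ~91% probability to the correct
# next word, far beyond what open-ended translation allows.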

kendricktan commented 7 years ago

@larspars fair point. :+1:

@jekbradbury what version of pytorch have you tested that Attention module on?

jekbradbury commented 7 years ago

It was written against PyTorch master as of three weeks ago, so it should work on 0.2.0 (to be released this week).

jadore801120 commented 7 years ago

Hi all, sorry for the absence; I have been busy recently.

Thanks for all the discussion. I think there are indeed some bugs in my code; the accuracy is way too high. As you point out, it is very likely that the current mask somehow leaks future information, but I am not yet sure why and am trying to find out. Any further insights, observations, or discussion are very welcome. Thanks!

Yu-Hsiang

ZiJianZhao commented 7 years ago

As noted in my pull request, I think your LayerNormalization module implementation is wrong: it normalizes over the wrong dimension (time), which leaks future information: https://github.com/jadore801120/attention-is-all-you-need-pytorch/pull/17
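To make the leak concrete: layer normalization should compute its mean and variance over the feature (d_model) dimension of each position separately; taking them over the time dimension mixes statistics across positions, including future ones. A minimal sketch (not the repository's actual module, learnable gain/bias omitted), assuming inputs shaped (batch, seq_len, d_model):

import torch

def layer_norm_per_position(x, eps=1e-6):
    # Correct: statistics over the feature dimension only, so no information
    # crosses time steps.
    mean = x.mean(dim=-1, keepdim=True)
    std = x.std(dim=-1, keepdim=True)
    return (x - mean) / (std + eps)

def layer_norm_over_time(x, eps=1e-6):
    # Buggy variant: statistics over the time dimension make every position's
    # output depend on every other position, leaking the future.
    mean = x.mean(dim=1, keepdim=True)
    std = x.std(dim=1, keepdim=True)
    return (x - mean) / (std + eps)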

larspars commented 7 years ago

@ZiJianZhao I think that might be it; I'm getting much more realistic accuracy and perplexity after your patch.

jadore801120 commented 7 years ago

@ZiJianZhao, I think you are right. I have just merged it! Thanks for pointing it out!

Yu-Hsiang