huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Corrupted Relative Attention in T5 Decoder #10484

Closed · Slash0BZ closed this issue 3 years ago

Slash0BZ commented 3 years ago

Environment info

- Platform: macOS / Ubuntu 14
- transformers==2.11.0
- torch==1.4.0 (GPU)
- Python 3.6

I know this is an old version, but it supports important experiments in a paper under review, so I would appreciate help figuring out what is wrong. I checked the commit log and I don't think any later commit resolves this.

Who can help

@patrickvonplaten (contacted through Slack), @patil-suraj (mentioned below). Please let me know if there is anything else I can provide! Thank you!

Information

I created an artificial binary classification dataset where the input sequences are near-randomly generated tokens from the T5 vocabulary. The output sequences are balanced between "answer: correct" and "answer: restaurant" (two binary tag words selected at random). A data sample can be found here, in the format (input_seq \t output_seq). The custom data reader parses this data with T5Tokenizer and is_pretokenized=True (see here).

During training I feed the T5ForConditionalGeneration model (v2.11.0) input_ids, lm_labels, and their corresponding attention masks. The model should not be able to learn anything because the sequences are near-random, but in reality it converges to zero loss, which suggests that the decoder's lm_logits actually attend to future inputs (after shift_right()) and therefore see the label. During evaluation, where I hide the binary tag, the model always predicts the positive class.
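Roughly, the training call looks like this (a simplified sketch of what my training code does under the 2.11 API, not the exact code from the repo; the strings here are placeholders):

from transformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained("t5-large")
tokenizer = T5Tokenizer.from_pretrained("t5-large")

# a near-random input sequence and its binary-tag target (placeholder strings)
input_ids = tokenizer.encode("some near random tokens from the T5 vocab", return_tensors="pt")
lm_labels = tokenizer.encode("answer: correct", return_tensors="pt")

# in v2.11, passing lm_labels makes the model build decoder_input_ids internally via shift_right()
# and return the LM loss as the first output
loss = model(input_ids=input_ids, lm_labels=lm_labels)[0]
loss.backward()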

To reproduce

Steps to reproduce the behavior:

  1. Use the code in this repo: https://github.com/Slash0BZ/t5-investigation
  2. Run with the sample data. I have tried both the pre-trained T5-large and a randomly initialized T5-large (written like this).

I am not sure whether the training data size affects the result; I trained on a set of 5M examples. I am happy to provide the full data and a trained model if actual experiments are needed.

Expected behavior

During training, the loss converges to near zero and the lm_logits reflect predictions identical to the output sequence. However, during evaluation, where the data reader hides the binary tag in the output sequence (achieved by providing only "answer:" in decoder_input_ids), the prediction is uniform.

I also tried to change the decoder_input_ids. When it is [0, 1525, 10, 2024], the prediction at position 2 is 2024. When it is [0, 1525, 10, 2062], the prediction at position 2 is 2062.

Notes: 1525->"answer", 10->":", 2024->"correct", 2062->"restaurant"
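(For anyone reproducing this, a quick sanity check of these id-to-token mappings, assuming the standard t5-large tokenizer; the printed pieces will carry SentencePiece markers like "▁":)

from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained("t5-large")
print(tok.convert_ids_to_tokens([1525, 10, 2024, 2062]))  # should correspond to "answer", ":", "correct", "restaurant" per the note above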

Slash0BZ commented 3 years ago

Uploaded full dataset and trained model: https://drive.google.com/drive/u/1/folders/1A7PIG1E98uuGUi8mDA2m_6T_oQp8XDhF

You can reproduce the issue by simply evaluating the test set with the trained model and observing the behavior with the aforementioned sets of decoder input ids. I suspect the same issue occurs during the training process (which is what makes it converge to zero). I don't think I am doing anything wrong in the code, but please let me know. Thanks!

patrickvonplaten commented 3 years ago

Hey @Slash0BZ,

Hmm, this might actually be very difficult to debug since 2.11 is quite outdated by now :-/.

2 things:

1) I'm very confident that the causal mask is always enabled in the decoder, so tokens have no access to future tokens -> they should not be able to learn to "cheat". See this line (in the 2.11 version): https://github.com/huggingface/transformers/blob/b42586ea560a20dcadb78472a6b4596f579e9043/src/transformers/modeling_t5.py#L707 If you follow the function definition, you can see that a causal mask is generated whenever the model is a decoder (i.e. self.config.is_decoder is True) - see: https://github.com/huggingface/transformers/blob/b42586ea560a20dcadb78472a6b4596f579e9043/src/transformers/modeling_utils.py#L192 (a simplified sketch of this mask construction follows right after point 2 below).

2) There was a bug in the relative positional encoding that was fixed in this PR: https://github.com/huggingface/transformers/pull/8518 . In this PR I also made sure that the original T5 and our T5 implementation give the exact same results.
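For intuition on point 1), here is a rough, self-contained sketch of how the padding mask and the causal mask combine into the additive mask the decoder adds to its attention scores (simplified; not the exact library code):

import torch

seq_len = 4
attention_mask = torch.ones(1, seq_len)                     # [batch, seq], all tokens real (no padding)
causal = torch.tril(torch.ones(seq_len, seq_len))           # lower triangle: no access to future positions
extended = attention_mask[:, None, None, :] * causal[None, None, :, :]  # -> [batch, 1, seq, seq]
additive = (1.0 - extended) * -10000.0                      # 0.0 where attention is allowed, -10000.0 where blocked
print(additive[0, 0, 2])  # row for query position 2: tensor([0., 0., 0., -10000.])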

Slash0BZ commented 3 years ago

Hi @patrickvonplaten, thank you for the quick response! Sorry about the version issue; 2.11.0 was the latest version when I conducted all the experiments for a paper under review.

I understand how the causal mask is created, and I can confirm it is working, but it cannot explain what I see. Below is what I did (recap: 2024 and 2062 are the two vocab ids I used for the binary tag, and 1525 and 10 represent "answer:").

With decoder_input_ids = [0, 1525, 10, 2062], I printed input_ids inside the decoder (T5Stack): it has size [16, 31] and content [0, 1525, 10, 2062, 0, ... 0]. The extended_attention_mask has size [16, 1, 31, 31], and its content at position 2 is [0, 0, 0, -10000, -10000, ... -10000]. Is everything here behaving as expected (i.e., should the first few mask values be 0)? Under this setup, the trained model's prediction for an instance at position 2 is 2062.

However, if I change the decoder_input_ids to [0, 1525, 10, 2024] (the other binary vocab id), the same model's prediction on the same instance at position 2 becomes 2024, showing that it sees the input at position 3, or at least that the prediction changes when the position-3 input changes.

Below is how I got the prediction at position 2, using the lm_logits directly from the forward() function in a T5ForConditionalGeneration. Please let me know if you spot any issues with it.

            # outputs[0] is lm_logits here (no lm_labels passed), shape [batch, seq_len, vocab_size]
            outputs = model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                decoder_input_ids=inputs['decoder_input_ids'],
                # lm_labels=inputs['lm_labels'],
                decoder_attention_mask=inputs['decoder_attention_mask'],
                use_cache=False,
            )[0].cpu().numpy()
            ids = []
            for output in outputs:
                arr = []
                binary_tags = [2024, 2062]
                for val in binary_tags:
                    arr.append(output[2][val])  # logit at position 2 for each candidate tag id
                argmax_idx = int(np.argmax(np.array(arr)))
                ids.append(binary_tags[argmax_idx])  # predicted tag id at position 2

Thanks again for your help. I understand how difficult it is to look at previous versions, but I need to figure out if all experiments need to be re-done.

patrickvonplaten commented 3 years ago

Hmm, the extended_attention_mask looks correct to me. Position 2 is allowed to attend to itself and to position 0 & 1.

Also, I ran the following code snippet both on current master and on 2.11 and it passes -> showing that the attention mask works correctly:

from transformers import T5ForConditionalGeneration
import torch

model = T5ForConditionalGeneration.from_pretrained('t5-small')

input_ids = torch.tensor([list(range(30))], dtype=torch.long)
decoder_input_ids = torch.ones((1, 4), dtype=torch.long)

# take output at position 2
logits_at_2 = model(input_ids, decoder_input_ids=decoder_input_ids)[0][:, 2]

decoder_input_ids[:, 3] = 10

# take output at position 2 having changed the decoder_input_ids
logits_at_2_same = model(input_ids, decoder_input_ids=decoder_input_ids)[0][:, 2]

assert abs(logits_at_2.sum().item() - logits_at_2_same.sum().item()) < 1e-3, "Error"

Slash0BZ commented 3 years ago

Thanks, @patrickvonplaten. Following your snippet, here is how you can reproduce my issue (please give it a try; it has been bugging me for weeks):

from transformers import T5ForConditionalGeneration
import torch
import numpy as np

model = T5ForConditionalGeneration.from_pretrained("trained_model")

input_ids = torch.tensor([list(range(30))], dtype=torch.long)
decoder_input_ids = torch.tensor([[0, 1525, 10, 2024]])

logits_at_2 = model(input_ids, decoder_input_ids=decoder_input_ids)[0][:, 2]
print(np.argmax(logits_at_2[0].detach().cpu().numpy()))

decoder_input_ids = torch.tensor([[0, 1525, 10, 2062]])

logits_at_2_same = model(input_ids, decoder_input_ids=decoder_input_ids)[0][:, 2]
print(np.argmax(logits_at_2_same[0].detach().cpu().numpy()))

assert abs(logits_at_2.sum().item() - logits_at_2_same.sum().item()) < 1e-3, "Error"

Where trained_model can be downloaded here: https://drive.google.com/drive/u/1/folders/1A7PIG1E98uuGUi8mDA2m_6T_oQp8XDhF It has the same config as a T5-large, just different learned weights.

Under my local env (2.11.0), it prints 2024 and 2062, and triggers the assertion error.

Given this, it seems that something got corrupted during the training process, and the learned weights somehow allow the model to look at future inputs. Do you have any suggestions?

Slash0BZ commented 3 years ago

@patrickvonplaten I just tried the latest Huggingface version with the snippet above using my trained model, and it also triggers the assertion error. Seems like something interesting is going on with certain model weights.

patrickvonplaten commented 3 years ago

Interesting - so our causal mask doesn't actually make "attending-to-next-tokens" strictly impossible -> it just adds a very large negative number (-10000) to the attention scores before the softmax, so that after the softmax those positions should get essentially zero weight. Maybe your model has incredibly high activations that can somehow overturn the -10000. Could you maybe try the following: replace the -10000 masking value with -float("inf") (or a much more negative number) and check whether the predictions at position 2 still change?
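For intuition, here is a toy example of how a sufficiently large raw attention score can overwhelm a -10000 additive mask (the numbers are made up purely for illustration):

import torch

# raw pre-softmax attention scores for one query: [allowed past position, masked future position]
scores = torch.tensor([3.0, 20000.0])          # hypothetically huge activation at the future position
additive_mask = torch.tensor([0.0, -10000.0])  # the causal mask only subtracts 10000
print(torch.softmax(scores + additive_mask, dim=-1))
# -> the masked future position still receives essentially all of the attention weight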

Slash0BZ commented 3 years ago

Thanks @patrickvonplaten. I tried what you suggested: -float("inf") doesn't work because it makes the logits NaN, so I tried -999999 instead, and the predictions are now valid. Now that we know what the issue is, here are some of my concerns: mainly, whether this affects my already-trained models and results, and whether those experiments need to be re-run.

Please let me know. Thanks again for your help!
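(Side note on the NaNs: one plausible explanation, which I have not verified, is that rows where every key position is masked, e.g. padded decoder positions, become all -inf, and the softmax of an all -inf row is NaN:)

import torch

fully_masked_row = torch.full((4,), float("-inf"))  # e.g. a padded query position with every key masked out
print(torch.softmax(fully_masked_row, dim=-1))      # tensor([nan, nan, nan, nan]) -> NaNs then propagate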

patrickvonplaten commented 3 years ago

Thanks for trying it out! Hmm, I've never heard of such an issue before, so I assume that it will only affect the "extreme" experiments. But T5 tends to produce very extreme values, which is also why we have (so far) only managed to run T5 partly in fp16 mode.

We usually use -10000 as the masking value because it keeps the model fp16-compatible... I'm not really sure what to do here -> we could change the masking value to -inf in general if such errors occur more often. Also pinging @LysandreJik @sgugger @patil-suraj here: have you heard of a case before where the model learned to cheat the -10000 masking value?
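For context, the fp16 concern is purely about representable range; a quick illustration (standard PyTorch dtype behavior):

import torch

print(torch.finfo(torch.float16).max)               # 65504.0 -> the largest finite fp16 magnitude
print(torch.tensor(-10000.0, dtype=torch.float16))  # representable, so the -10000 mask survives fp16
print(torch.tensor(-1e9, dtype=torch.float16))      # overflows to -inf in half precision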

dirkgr commented 3 years ago

We use https://github.com/allenai/allennlp/blob/f091cb9cd92e767f55659b2b59f0ffb75bc613be/allennlp/nn/util.py#L239, which ultimately boils down to using this value: torch.finfo(tensor.dtype).min.
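For reference, torch.finfo(dtype).min is simply the most negative finite value of that dtype, so the mask automatically scales with the precision in use:

import torch

print(torch.finfo(torch.float32).min)  # -3.4028234663852886e+38
print(torch.finfo(torch.float16).min)  # -65504.0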

LysandreJik commented 3 years ago

@patrickvonplaten, yes, the model can absolutely cheat the -10000 value. We've seen that in the past in cases where the output values are passed through an argmax while the probability distribution is very uniform.

We've kept -10000 to stay as close as possible to the original BERT implementation, and we recommend using as few padding tokens as possible so that it doesn't have an effect (while keeping in mind that the -10000 should keep the masked values very, very small and should have minimal impact).

@dirkgr's solution is definitely more robust, and I don't think switching the -10000 to a lower value would change anyone's workflow, so I wouldn't be opposed to switching.

patil-suraj commented 3 years ago

@patrickvonplaten I never ran into this issue in my T5 experiments, but it does seem possible that -10000 can cause problems, because while investigating the fp16 issue we saw that T5 produces large activation values.

And I agree with @dirkgr's solution.

dorost1234 commented 3 years ago

Hi @patil-suraj @patrickvonplaten @sgugger, I am experiencing similar issues with mT5, and I always get NaN in fp16 mode. You mentioned that you partly made T5 work with fp16; do you mind telling me how you managed it? I am having a really hard time with the mT5 model + fp16. Thanks a lot, all!

github-actions[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

patrickvonplaten commented 3 years ago

Putting this on my ToDo-List as it seems to be quite important actually

github-actions[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.