allenai / longformer

Longformer: The Long-Document Transformer
https://arxiv.org/abs/2004.05150
Apache License 2.0

I am not able to set the global attention mask, although I have given two sep tokens between question and context #97

Open rudraksh97 opened 4 years ago

rudraksh97 commented 4 years ago

outputs = model(*inputs)
  File "C:\Users\Rudraksh\Anaconda3\envs\bert_pre\lib\site-packages\torch\nn\modules\module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "d:\transformers\src\transformers\modeling_longformer.py", line 1345, in forward
    global_attention_mask = _compute_global_attention_mask(input_ids, self.config.sep_token_id)
  File "d:\transformers\src\transformers\modeling_longformer.py", line 84, in _compute_global_attention_mask
    question_end_index = _get_question_end_index(input_ids, sep_token_id)
  File "d:\transformers\src\transformers\modeling_longformer.py", line 73, in _get_question_end_index
    ), f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this error."
AssertionError: There should be exactly three separator tokens: 2 in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this error.

rudraksh97 commented 4 years ago

one of the input id ([ 0, 10534, 42299, 398, 4, 134, 46004, 406, 46419, 42360, 246, 4, 134, 4, 176, 4, 134, 368, 36600, 2, 2, 2, 6968, 605, 17184, 560, 39567, 10534, 42299, 398, 4, 134, 46004, 406, 6, 4297, 16625, 42360, 21747, 26692, 3654, 29933, 627, 44240, 21747, 4, 757, 134, 4, 398, 34045, 196, 627, 28481, 154, 35, 44223, 35, 627, 39161, 1452, 119, 42360, 4929, 40545, 28023, 31002, 21747, 354, 246, 4, 134, 4, 288, 4, 246, 6, 4297, 46419, 21747, 246, 4, 134, 4, 176, 4, 134, 368, 36600, 4, 2, 2, 2, 1452, 119, 41536, 267, 7706, 1990, 11131, 14668, 21747, 134, 4, 134, 4, 246, 4, 288, 12, 33557, 26947, 2, 2, 2, 41536, 868, 42018, 873, 47796, 6968, 7424, 41536, 267, 7706, 1990, 11131, 14668, 7761, 627, 10212, 3427, 40933, 1580, 10534, 11953, 1409, 10928, 627, 6930, 18731, 20094, 7755, 282, 30806, 6025, 1322, 39355, 179, 9226, 43017, 4, 41536, 42739, 267, 7706, 13445, 1990, 11131, 14668, 3809, 1033, 31586, 627, 12592, 28023, 1990, 26264, 35296, 27454, 15313, 1258, 1640, 366, 45071, 43, 28746, 18, 12592, 38928, 14868, 1990, 462, 18957, 23687, 463, 7443, 42502, 24894, 8475, 28023, 131, 8529, 42360, 4929, 463, 37285, 28023, 4, 267, 7706, 1990, 11131, 14668, 5087, 33227, 38557, 30764, 14668, 3866, 225, 24198, 16918, 281, 26601, 131, 15526, 38575, 154, 131, 463, 405, 11131, 14668, 4, 267, 7706, 1990, 11131, 14668, 7333, 102, 30695, 1116, 24894, 8475, 28023, 35, 42360, 4929, 28023, 39355, 1409, 1452, 119, 8716, 42360, 4929, 40545, 28023, 31002, 463, 37285, 28023, 4, 1990, 4321, 31480, 9006, 267, 7706, 1990, 11131, 14668, 131, 7048, 627, 267, 7706, 1990, 11131, 14668, 43017, 1258, 10975, 8166, 640, 1401, 4, 1452, 119, 4, 175, 73, 22930, 73, 44228, 15497, 73, 29, 1090, 330, 16312, 1215, 134, 4, 134, 4, 246, 4, 288, 73, 175, 4, 1452, 119, 4, 46651, 4, 37447, 73, 46651, 1215, 636, 12, 8361, 8596, 4, 6660, 8174, 21747, 134, 4, 134, 4, 246, 4, 288, 354, 102, 16320, 13043, 23053, 1116, 267, 7706, 1990, 11131, 14668, 21747, 134, 4, 134, 11070, 5632, 14377, 5000, 246, 4, 41536, 154, 267, 7706, 1990, 11131, 14668, 6930, 18731, 39472, 6968, 7424, 41536, 39472, 1409, 10928, 627, 1452, 119, 41536, 19709, 3340, 2716, 368, 8166, 4, 1594, 642, 31497, 11156, 3876, 41536, 19709, 131, 13437, 405, 4255, 1075, 23050, 2802, 10687, 627, 175, 8293, 42883, 1116, 37782, 20094, 6025, 405, 31931, 293, 4, 39080, 35, 1594, 6968, 3698, 41536, 19709, 560, 41536, 627, 41038, 131, 13040, 6968, 7424, 19726, 3698, 405, 560, 16435, 15664, 627, 10800, 4189, 1116, 1250, 627, 11828, 10887, 42018, 6025, 6968, 3955, 24919, 179, 627, 39035, 4, 134, 4, 23033, 6968, 13124, 627, 39567, 1258, 131, 15954, 627, 267, 7706, 1990, 11131, 14668, 21756, 17467, 10975, 8166, 640, 1401, 4, 1452, 119, 4, 175, 73, 22930, 73, 21061, 4, 605, 7485, 116, 1343, 5214, 267, 7706, 9426, 1225, 541, 5982, 29369, 742, 1990, 19593, 10536, 40512, 368, 20079, 39567, 1258, 179, 25384, 2485, 4, 176, 4, 417, 5906, 13523, 627, 8738, 1116, 29912, 12, 14175, 39472, 560, 41536, 2])

rudraksh97 commented 4 years ago

Ideally it should look like [0] [question token ids] [2,2] [answer token ids] [2]
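For reference, encoding the question and context as a sentence pair with the Longformer tokenizer produces exactly that layout, i.e. three sep tokens (id 2) per sample, which is what the assert checks for. A minimal sketch (the question and context strings are made up for illustration):

from transformers import LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')

# pair encoding yields <s> question </s></s> context </s>
question = "What is the model called?"
context = "The model described in the paper is called Longformer."
enc = tokenizer(question, context, return_tensors='pt')

print(enc['input_ids'])
print((enc['input_ids'] == tokenizer.sep_token_id).sum())  # expected: 3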

ibeltagy commented 4 years ago

as the error indicates, you need to pass the global attention pattern to the forward function, something like:

result = self.forward(input_ids, attention_mask, global_attention_mask)
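A minimal sketch of what that could look like for QA with the Hugging Face implementation, where global attention is a separate 0/1 mask; input_ids, attention_mask, model, and tokenizer are assumed to exist already, and global attention is placed on the question tokens:

import torch

# 0 = local attention, 1 = global attention
global_attention_mask = torch.zeros_like(input_ids)

# global attention on every token up to and including the first </s>,
# i.e. the question part of <s> question </s></s> context </s>
for i in range(input_ids.shape[0]):
    sep_positions = (input_ids[i] == tokenizer.sep_token_id).nonzero()
    if sep_positions.numel() > 0:
        global_attention_mask[i, : sep_positions[0].item() + 1] = 1

outputs = model(input_ids=input_ids,
                attention_mask=attention_mask,
                global_attention_mask=global_attention_mask)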

mihaidobri commented 4 years ago

I have a question related to this topic. For a fine-tuning task on text classification, do we need to set global_attention_mask?

(e.g., when using it from the Hugging Face library)

model = LongformerForSequenceClassification.from_pretrained('allenai/longformer-base-4096',
                                                            num_labels=n_labels)  # the number of output labels

I was reading the documentation at https://huggingface.co/transformers/model_doc/longformer.html#transformers.LongformerForSequenceClassification and the example code there does not mention anything specific about this.

Looking at the example provided in https://github.com/allenai/longformer, I see:

attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention
attention_mask[:, [1, 4, 21,]] =  2  # Set global attention based on the task. For example,
                                     # classification: the <s> token
                                     # QA: question tokens 

So in my train function, should I assume that I have to set the value 2 (since I am fine-tuning for text classification)?

outputs = model(input_ids=batch['input_ids'].type(torch.long).to(device),
                attention_mask=2,
                labels=batch['labels'].type(torch.long).to(device))

Is this correct?

ibeltagy commented 4 years ago

I agree, the documentation is not clear and needs to be updated. To answer your question: in most cases you don't need to specify global attention yourself for LongformerForSequenceClassification, because LongformerForSequenceClassification.forward automatically sets it for you on the CLS token (check here). If you want a different pattern, you can specify it manually in the input; but if your pattern is the same (global attention on the CLS token), you can rely on LongformerForSequenceClassification.forward to do the right thing.
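A hedged sketch of both options with the Hugging Face implementation (note that here global attention is passed as a separate 0/1 global_attention_mask, not by writing the value 2 into attention_mask, which is the convention of the standalone allenai/longformer code); num_labels and the input text are illustrative:

import torch
from transformers import LongformerTokenizer, LongformerForSequenceClassification

tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
model = LongformerForSequenceClassification.from_pretrained('allenai/longformer-base-4096',
                                                            num_labels=2)

enc = tokenizer("a long document ...", return_tensors='pt',
                truncation=True, max_length=4096)
labels = torch.tensor([1])

# option 1: rely on forward() to put global attention on the CLS token
outputs = model(**enc, labels=labels)

# option 2: pass a custom pattern explicitly (1 = global, 0 = local)
global_attention_mask = torch.zeros_like(enc['input_ids'])
global_attention_mask[:, 0] = 1  # here: still just the CLS token
outputs = model(**enc, global_attention_mask=global_attention_mask, labels=labels)

So in the train function quoted above, attention_mask should stay the usual 0/1 padding mask returned by the tokenizer rather than the value 2.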

jaideep11061982 commented 1 year ago

@ibeltagy I face this issue too; any suggestions? I am using the Hugging Face Trainer.

first_sentence = ['<s>' + example['context'] + tokenizer.bos_token + example['prompt'] + '</s>'] * 5
# first_sentence = ['<s>' + example['context'] + ' ' + '</s>' + ' ' + '</s>' + ' ' + example['prompt'] + ' ' + '</s>'] * 5
# print('prompt', example['prompt'])
# second_sentences = [tokenizer.bos_token + example['prompt'] + tokenizer.eos_token + tokenizer.bos_token +
#                     example[option] + tokenizer.eos_token for option in 'ABCDE']
second_sentences = ['</s>' + example[option] + '</s>' for option in 'ABCDE']

tokenized_example = tokenizer(first_sentence, second_sentences,
                              truncation='only_first',
                              max_length=max_length, add_special_tokens=False)
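Since add_special_tokens=False is used above, the only sep tokens in each encoded choice are the literal '</s>' strings, so it may be worth counting them per sample right after tokenization; the assert that raises this error expects exactly three per sample. A small check, reusing the names from the snippet:

sep_id = tokenizer.sep_token_id  # 2 for allenai/longformer-base-4096
for ids in tokenized_example['input_ids']:
    print(sum(1 for t in ids if t == sep_id))  # expected: 3 per encoded choice

Note also that truncation='only_first' drops tokens from the end of the first sequence (with the default right-side truncation), so for long contexts the trailing '</s>' of first_sentence can be cut off, leaving only two sep tokens in that sample.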

afshinrahimi commented 3 months ago

I have the same problem when I use Longformer for QA. There is an assert in modeling_longformer.py that makes sure the number of sep_token_id occurrences equals 3 x batch size. I have made sure that this is the case by inspecting the tokenized samples.

def _get_question_end_index(input_ids, sep_token_id):
    """
    Computes the index of the first occurrence of `sep_token_id`.
    """

    sep_token_indices = (input_ids == sep_token_id).nonzero()
    batch_size = input_ids.shape[0]

    assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
    assert sep_token_indices.shape[0] == 3 * batch_size, (
        f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You"
        " might also consider to set `global_attention_mask` manually in the forward function to avoid this error."
    )
    return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]

I found that if I set CUDA_VISIBLE_DEVICES to a single GPU, the assert is satisfied; with multiple GPUs it throws an error and sep_token_indices is an empty tensor: `tensor([], device='cuda:5', size=(0, 2), dtype=torch.int64)`.

I am not sure yet what the reason is; this doesn't happen when I run inference on one or two individual instances.
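One thing that might help narrow this down is printing what each replica actually receives right before the forward call; a small diagnostic sketch (batch and its key names are assumptions about the input pipeline):

sep_counts = (batch['input_ids'] == 2).sum(dim=1)  # 2 == sep_token_id for longformer-base-4096
print(batch['input_ids'].shape, batch['input_ids'].device, sep_counts)
# every entry of sep_counts should be 3; an empty or uneven split here points
# at how the batch is scattered across GPUs rather than at the tokenization

And, as the assertion message itself suggests, computing global_attention_mask yourself and passing it to forward (as in the sketch earlier in the thread) bypasses _get_question_end_index entirely.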