
Longformer finetuning on TPUs IndexError: tuple index out of range #6693

Closed wassimseif closed 3 years ago

wassimseif commented 4 years ago


Who can help

Longformer/Reformer: @patrickvonplaten

Information

Model I am using (Bert, XLNet ...): Longformer (allenai/longformer-large-4096)


To reproduce

Steps to reproduce the behavior. My model:

import torch
import torch.nn as nn
import transformers

# `config` is the user's external settings object (dropout, class count).
class LongFormerBaseUncased(nn.Module):
    def __init__(self):
        super(LongFormerBaseUncased, self).__init__()
        self.bert = transformers.LongformerModel.from_pretrained(
            "allenai/longformer-large-4096",
            gradient_checkpointing=True,
        )
        self.bert_drop = nn.Dropout(config.dropout)
        # 1024 is the hidden size of longformer-large-4096.
        self.out = nn.Linear(1024, config.output_num_classes)

    def forward(self, ids, mask):
        # transformers 3.x returns a tuple; the second element is the pooled output.
        _, o2 = self.bert(ids, attention_mask=mask)
        bo = self.bert_drop(o2)
        output = self.out(bo)
        return output
# The tokenizer checkpoint differs from the model checkpoint (base vs. large),
# but both Longformer checkpoints share the same RoBERTa vocabulary.
tokenizer = transformers.LongformerTokenizer.from_pretrained(
    "allenai/longformer-base-4096"
)
text = "Very long text"
tokenized = tokenizer.tokenize(text)
inputs = tokenizer.encode_plus(
    tokenized,
    is_pretokenized=True,
    max_length=4096,
    pad_to_max_length=True,
    truncation=True,
    return_tensors="pt",  # return tensors so .to(device) below works
)
ids = inputs["input_ids"]
mask = inputs["attention_mask"]

# `device` is the XLA device and `targets` the labels from the data loader.
ids = ids.to(device, dtype=torch.long)
mask = mask.to(device, dtype=torch.long)
targets = targets.to(device, dtype=torch.float)

# This throws the error
outputs = model(ids=ids, mask=mask)
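
Before moving to the TPU, a quick CPU smoke test of the same forward pass (a hypothetical sanity check, not part of the original report; it assumes `config` is defined as above) confirms that the model code and tensor shapes line up:

cpu_model = LongFormerBaseUncased()
dummy_ids = torch.zeros((1, 4096), dtype=torch.long)
dummy_mask = torch.ones((1, 4096), dtype=torch.long)
logits = cpu_model(ids=dummy_ids, mask=dummy_mask)
print(logits.shape)  # expected: (1, config.output_num_classes)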

Error

Exception in device=TPU:0: tuple index out of range
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 228, in _start_fn
    fn(gindex, *args)
  File "<ipython-input-14-9a008098ce7f>", line 3, in _mp_fn
    a = run()
  File "<ipython-input-12-9c37f47d0144>", line 156, in run
    train_fn(train_data_loader, model, optimizer, device, scheduler)
  File "<ipython-input-12-9c37f47d0144>", line 26, in train_fn
    outputs = model(ids=ids, mask=mask)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 577, in __call__
    result = self.forward(*input, **kwargs)
  File "<ipython-input-9-b68f74a484cf>", line 12, in forward
    _, o2 = self.bert(ids, attention_mask = mask)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 577, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_longformer.py", line 1004, in forward
    output_hidden_states=output_hidden_states,
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 577, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_longformer.py", line 692, in forward
    create_custom_forward(layer_module), hidden_states, attention_mask,
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/checkpoint.py", line 163, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/checkpoint.py", line 74, in forward
    outputs = run_function(*args)
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_longformer.py", line 687, in custom_forward
    return module(*inputs, output_attentions)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 577, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/transformers/modeling_longformer.py", line 658, in forward
    self_attn_outputs = self.attention(hidden_states, attention_mask, output_attentions=output_attentions,)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 577, in __call__
    result = self.forward(*input, **kwargs)
IndexError: tuple index out of range
An exception has occurred, use %tb to see the full traceback.
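
The traceback runs through torch.utils.checkpoint, so one isolation step (a suggestion, not something tried in this thread) is to reload the backbone with gradient checkpointing left at its default (off) and re-run the same batch; if the error goes away, the failure is in the checkpointed path rather than in Longformer's attention itself:

model.bert = transformers.LongformerModel.from_pretrained(
    "allenai/longformer-large-4096"  # no gradient_checkpointing=True
).to(device)
outputs = model(ids=ids, mask=mask)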
patrickvonplaten commented 4 years ago

Hey @wassimseif, sadly neither Longformer nor Reformer works on PyTorch/XLA. There is just too much dynamic tensor reshaping happening. I think @ibeltagy made Longformer work on PyTorch/XLA when respecting certain limitations (only local attention).
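
For reference, a minimal sketch of that "local attention only" restriction, assuming a recent transformers version where LongformerModel accepts a global_attention_mask argument (the fixed-shape padding and the all-zeros mask here are illustrative, not code from this thread):

import torch
import transformers

tokenizer = transformers.LongformerTokenizer.from_pretrained(
    "allenai/longformer-base-4096"
)
model = transformers.LongformerModel.from_pretrained(
    "allenai/longformer-base-4096"
)

enc = tokenizer(
    "Very long text",
    max_length=4096,
    padding="max_length",  # fixed shapes help XLA avoid recompilation
    truncation=True,
    return_tensors="pt",
)
# An all-zeros global attention mask keeps every token on local attention.
global_attention_mask = torch.zeros_like(enc["input_ids"])
outputs = model(
    enc["input_ids"],
    attention_mask=enc["attention_mask"],
    global_attention_mask=global_attention_mask,
)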

wassimseif commented 4 years ago

Hey @patrickvonplaten, understood. Is there a wiki that specifies which models work on XLA and which don't?

ibeltagy commented 4 years ago

@wassimseif, running Longformer on pytorch-xla is tracked in this issue: https://github.com/allenai/longformer/issues/101. I am aiming to make that code available soon, probably this week.

stale[bot] commented 4 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

mabdullah1994 commented 3 years ago

Hi. Got the same error. Any update on this issue? Thanks!