bhavsarpratik opened this issue 4 years ago
I experienced the same error. Switching to BERT produces an IndexError:
  File "Initial_test_k_lm.py", line 198, in
    optimizer_type="adamw") #adamw /lamb
  File "/home/kino/.local/lib/python3.6/site-packages/fast_bert/learner_lm.py", line 143, in fit
    outputs = self.model(inputs, masked_lm_labels=labels)
  File "/home/kino/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/kino/.local/lib/python3.6/site-packages/transformers/modeling_bert.py", line 1003, in forward
    masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
  File "/home/kino/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/kino/.local/lib/python3.6/site-packages/torch/nn/modules/loss.py", line 916, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "/home/kino/.local/lib/python3.6/site-packages/torch/nn/functional.py", line 2021, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/home/kino/.local/lib/python3.6/site-packages/torch/nn/functional.py", line 1838, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
IndexError: Target -1 is out of bounds.
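The last frames of this traceback are a cross-entropy loss call. As a rough, framework-free sketch of the check that fails, assuming the loss uses PyTorch's default ignore_index of -100 for CrossEntropyLoss (the function name cross_entropy_with_ignore is hypothetical, not a PyTorch or fast-bert API): labels equal to ignore_index are skipped, and every other label must be a valid class index, so a leftover -1 label raises the same IndexError.

```python
import math
from typing import Sequence

# Assumed to match the default ignore_index of torch.nn.CrossEntropyLoss.
IGNORE_INDEX = -100


def cross_entropy_with_ignore(
    logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    ignore_index: int = IGNORE_INDEX,
) -> float:
    """Mean cross-entropy over positions whose target != ignore_index.

    Hypothetical pure-Python stand-in for a mean-reduced cross-entropy;
    only meant to show where "Target -1 is out of bounds." comes from.
    """
    total = 0.0
    count = 0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # ignored positions add nothing to the loss or the count
        if not 0 <= target < len(row):
            # Any non-ignored target must be a valid class index;
            # a -1 label that is not the ignore_index fails here.
            raise IndexError(f"Target {target} is out of bounds.")
        # Numerically stable log-softmax: log(sum(exp(x))) - x[target]
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[target]
        count += 1
    # Mean over zero non-ignored positions is 0/0, so return NaN.
    return total / count if count else float("nan")


if __name__ == "__main__":
    logits = [[2.0, 0.5, 0.1], [0.3, 1.2, 0.4]]
    print(cross_entropy_with_ignore(logits, [0, -100]))  # second row skipped
    try:
        cross_entropy_with_ignore(logits, [0, -1])
    except IndexError as err:
        print("IndexError:", err)
```

With -100 the masked-out position is simply skipped; with -1 it is treated as a real class index and rejected.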
See: "Version 2.4.1 breaks run_lm_finetuning.py, version 2.3.0 runs fine"
For us in fast-bert, changing this line in the mask_tokens function in data_lm.py:
    labels[~masked_indices] = -1
to:
    labels[~masked_indices] = -100
restores training locally on both CUDA and CPU. See: "Breaking Changes: Ignored indices in PyTorch loss computing" (@LysandreJik).
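The labeling step of that fix can be sketched in plain Python. This is a simplified, hypothetical mask_labels function, not fast-bert's actual mask_tokens: it replaces every selected token with the mask id and omits the random-token/keep-original split and special-token handling that BERT-style masking usually adds. The point it illustrates is that unselected positions get -100, the value the loss ignores by default.

```python
import random
from typing import List, Optional, Tuple

# Must match the loss's ignore_index; -1 no longer works after the change.
IGNORE_INDEX = -100


def mask_labels(
    token_ids: List[int],
    mask_token_id: int,
    mlm_probability: float = 0.15,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """Hypothetical, simplified labeling step of an MLM mask_tokens.

    Each position is selected with probability mlm_probability. Selected
    positions become mask_token_id in the inputs and keep their original
    id as the label; all other labels are IGNORE_INDEX so the loss skips them.
    """
    rng = rng or random.Random()
    inputs = list(token_ids)
    labels: List[int] = []
    for i, tok in enumerate(token_ids):
        if rng.random() < mlm_probability:
            inputs[i] = mask_token_id  # hide the token from the model
            labels.append(tok)         # and ask the model to predict it
        else:
            # equivalent of labels[~masked_indices] = -100
            labels.append(IGNORE_INDEX)
    return inputs, labels


if __name__ == "__main__":
    print(mask_labels([101, 2023, 2003, 1037, 3231, 102], mask_token_id=103,
                      rng=random.Random(0)))
```

Passing a seeded random.Random makes the masking reproducible, which helps when comparing runs before and after the -1 to -100 change.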
Please fix this, please...
Same issue. Can anyone help?
I get this error with BertLMLearner but not with BertLearner. I spent a lot of time debugging and also tried different CUDA versions, but couldn't get it to work. I don't get this error when I use run_language_modeling.py from transformers.
RuntimeError Traceback (most recent call last)