kmkurn / pytorch-crf

(Linear-chain) Conditional random field in PyTorch.
https://pytorch-crf.readthedocs.io
MIT License

IndexError: index -100 is out of bounds for dimension 0 with size 9 #106

Closed: venti07 closed this issue 1 year ago.

venti07 commented 1 year ago

Hello, I am trying to use a BERTCRF model. Unfortunately, the following error message appears: IndexError: index -100 is out of bounds for dimension 0 with size 9

I am using the token classification notebook from the Transformers Notebooks as a base and would like to use a BERTCRF model instead of AutoModelForTokenClassification. https://huggingface.co/docs/transformers/notebooks

I have set up a notebook and inserted the appropriate BERTCRF models: https://github.com/venti07/share/blob/main/classification_bertcrf.ipynb

Maybe someone can quickly find the error. I would appreciate it very much. Thanks in advance!

TidorP commented 1 year ago

During training, your labels seem to still contain the padding index (-100), which the torch CRF layer does not accept. You need to remove it.
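To see why -100 breaks the CRF: the layer looks up its parameters using tag values as indices, so a -100 label is treated as an index into a 9-element tensor. The sketch below reproduces the same kind of IndexError with plain PyTorch. `start_transitions` here is a hypothetical stand-in for a per-tag CRF parameter of shape `(num_tags,)`; as far as I can tell pytorch-crf indexes a parameter of that name in a similar way, but this is not the library's code.

```python
import torch

num_tags = 9  # matches "size 9" in the reported error
# Hypothetical stand-in for a per-tag CRF parameter of shape (num_tags,).
start_transitions = torch.zeros(num_tags)

# (batch, seq_len) labels as produced by the notebook: -100 at [CLS], [SEP], padding.
labels = torch.tensor([[-100, 3, 5, -100],
                       [-100, 1, -100, -100]])

caught = None
try:
    # Use each sequence's first tag as an index, as a CRF does when scoring.
    start_transitions[labels[:, 0]]
except IndexError as err:
    caught = str(err)  # same kind of error as reported in this issue

# With a valid tag index (0..8) in every position, the same lookup succeeds.
ok_scores = start_transitions[labels.masked_fill(labels == -100, 0)[:, 0]]
```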

kmkurn commented 1 year ago

@TidorP is right. Please set those positions to a valid tag index (0 to 8) before passing the tags through the CRF layer. You can restore them afterwards.
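A minimal sketch of the set-and-restore approach, assuming batch-first integer labels that use -100 as the ignore value. The helper names `fill_ignored` and `restore_ignored` are hypothetical. Keep in mind that, without a mask, the CRF will score the filled positions as if they really carried tag 0; passing a mask, as suggested further down in this thread, avoids that.

```python
import torch

IGNORE_INDEX = -100  # label value the notebook uses for positions the loss should skip

def fill_ignored(labels, fill_value=0):
    """Return (filled_labels, ignored), where ignored marks the original -100 positions."""
    ignored = labels == IGNORE_INDEX
    return labels.masked_fill(ignored, fill_value), ignored

def restore_ignored(tags, ignored):
    """Put -100 back at the positions recorded by fill_ignored."""
    return tags.masked_fill(ignored, IGNORE_INDEX)

labels = torch.tensor([[-100, 2, 8, -100],
                       [-100, 1, -100, -100]])
filled, ignored = fill_ignored(labels)
# `filled` now only holds valid tag indices 0..8 and can be passed to the CRF.
restored = restore_ignored(filled, ignored)
```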

siddharthtumre commented 1 year ago

Instead of removing them, I tried passing a mask to the CRF. The problem is that the CRF requires the first column of the mask to be 1 (on), but the first position is the [CLS] token, which is labeled -100.

How can I get around this?

kmkurn commented 1 year ago

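@siddharthtumre Just remove the [CLS] position before feeding the tensors into the CRF layer. Something like

scores = scores[:, 1:]
tags = tags[:, 1:]
mask = mask[:, 1:]  # slice the mask too if you pass one

should work (assuming the first dimension is the batch size, i.e. the CRF was created with batch_first=True).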
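A small sketch of the slicing above, assuming batch-first tensors and notebook-style labels where [CLS] is -100. It checks the condition that trips the CRF (the mask's first column must be all True) before and after dropping the first column. The tensors are made-up example data.

```python
import torch

batch, seq_len, num_tags = 2, 6, 9
emissions = torch.randn(batch, seq_len, num_tags)  # e.g. BERT logits, batch-first
# Made-up labels: [CLS] w1 w2 w3 [SEP] pad / [CLS] w1 w2 [SEP] pad pad
labels = torch.tensor([[-100, 4, 0, 7, -100, -100],
                       [-100, 2, 2, -100, -100, -100]])
mask = labels != -100

# The CRF requires the first timestep of the mask to be on; [CLS] makes it off.
first_on_before = bool(mask[:, 0].all())

# Drop the [CLS] column from everything that goes into the CRF.
emissions, labels, mask = emissions[:, 1:], labels[:, 1:], mask[:, 1:]
first_on_after = bool(mask[:, 0].all())

# Remaining -100 positions (here [SEP] and padding) still need a valid index.
tags = labels.masked_fill(~mask, 0)
```

With a layer built as `CRF(num_tags, batch_first=True)` (imported via `from torchcrf import CRF`), these tensors would then go in as `-crf(emissions, tags, mask=mask)`.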

atul47B commented 1 year ago

I am facing the same error; my labels tensor has shape [512, 4]. How can I remove the -100 values from every sample in the batch?

kmkurn commented 1 year ago

@atul47B You can use something like

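is_pad = tags == -100  # positions the CRF should ignore
tags.masked_fill_(is_pad, 0)  # replace -100 with any valid tag index (in place)
loss = -crf(emissions, tags, mask=~is_pad)  # mask is False at the former -100 positions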

The CRF forward computation ignores positions where the mask is False, regardless of their tag values.
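That works for padding at the end of each sequence. As far as I can tell from reading the pytorch-crf source, though, the score computation finds each sequence's last tag from the mask's sum, so it assumes the unmasked positions form a contiguous prefix. If your labels also put -100 on inner subword pieces, the mask gets holes. One workaround (my own sketch, not part of the library) is to move the labeled positions to the front before calling the CRF, which also takes care of the leading [CLS]. `compact_for_crf`, `scatter_decoded`, and `IGNORE_INDEX` are hypothetical names; every sequence must have at least one labeled token.

```python
import torch

IGNORE_INDEX = -100  # label value for positions the CRF should skip

def compact_for_crf(emissions, labels):
    """Move labeled positions to the front of each sequence (batch-first tensors).

    Returns emissions, tags, and a mask whose True entries form a contiguous
    prefix, plus the boolean `keep` mask needed to undo the move.
    """
    keep = labels != IGNORE_INDEX  # (batch, seq_len)
    # Stable sort on "is ignored" puts kept positions first, in their original order.
    order = torch.sort((~keep).long(), dim=1, stable=True).indices
    mask = keep.gather(1, order)
    tags = labels.gather(1, order).masked_fill(~mask, 0)  # any valid tag index works
    index = order.unsqueeze(-1).expand(-1, -1, emissions.size(-1))
    return emissions.gather(1, index), tags, mask, keep

def scatter_decoded(decoded, keep):
    """Turn per-sequence tag lists back into a tensor padded with -100."""
    out = torch.full(keep.shape, IGNORE_INDEX, dtype=torch.long)
    for i, seq in enumerate(decoded):
        # Kept positions stayed in order, so the i-th tag list fills them left to right.
        out[i][keep[i]] = torch.tensor(seq, dtype=torch.long)
    return out

emissions = torch.randn(2, 6, 9)
labels = torch.tensor([[-100, 4, -100, 7, -100, -100],   # [CLS] w1 ##piece w2 [SEP] pad
                       [-100, 2, 2, -100, -100, -100]])  # [CLS] w1 w2 [SEP] pad pad
em, tags, mask, keep = compact_for_crf(emissions, labels)
# With crf = CRF(9, batch_first=True): loss = -crf(em, tags, mask=mask)
decoded = [[4, 7], [2, 2]]  # stand-in for crf.decode(em, mask=mask)
preds = scatter_decoded(decoded, keep)
```

`crf.decode` returns one list of tags per sequence, covering only the unmasked positions, so `scatter_decoded` puts the predictions back in the same layout as the original labels (with -100 elsewhere), which is convenient for the notebook's metric code.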

kmkurn commented 1 year ago

Closing because the issue is resolved.