thunlp / OpenPrompt

An Open-Source Framework for Prompt-Learning.
https://thunlp.github.io/OpenPrompt/
Apache License 2.0
4.38k stars 455 forks

No loss ids in truncated sequence (GPT2 model) #279

Open ngavcc opened 1 year ago

ngavcc commented 1 year ago

I am using a GPT2 model with MixedTemplate for a text classification task. During training, an exception occurs at a tensor reshape operation. After some inspection, I realized that when the input text is too long, all values in loss_ids for that instance are zero, and the model fails to train (the exception is raised at the reshape when the model extracts the output at the "mask" position).

Here is an example batch (from PromptDataLoader): with batch size = 2048, loss_ids has shape torch.Size([2048, 128]), but summing its elements (loss_ids.sum()) gives only tensor(2047). That means one data instance has no loss_ids.
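To make the symptom easier to reproduce, here is a minimal sketch in plain Python that mimics the situation with toy data. It does not use OpenPrompt's actual code. The helper names (`truncate_tail`, `rows_without_loss_position`), the `<mask>` token string, and the assumption that truncation cuts from the tail are all hypothetical, chosen only to illustrate how a mask placed at the end of a long input could be dropped.

```python
from typing import List, Sequence, Tuple


def rows_without_loss_position(loss_ids: Sequence[Sequence[int]]) -> List[int]:
    """Return the indices of rows whose loss_ids contain no 1."""
    return [i for i, row in enumerate(loss_ids) if sum(row) == 0]


def truncate_tail(tokens: List[str], loss_ids: List[int],
                  max_len: int) -> Tuple[List[str], List[int]]:
    """Hypothetical tail truncation: keep only the first max_len positions."""
    return tokens[:max_len], loss_ids[:max_len]


max_len = 6

# Each toy example is (tokens, loss_ids); the mask sits at the end of the template.
short = (["it", "was", "great", "<mask>"], [0, 0, 0, 1])
long = (["a"] * 8 + ["<mask>"], [0] * 8 + [1])

batch_loss_ids = []
for tokens, ids in (short, long):
    _, ids = truncate_tail(tokens, ids, max_len)
    ids = ids + [0] * (max_len - len(ids))  # pad to a fixed length
    batch_loss_ids.append(ids)

# The long example lost its mask position, so its row sums to zero.
bad_rows = rows_without_loss_position(batch_loss_ids)
print(bad_rows)  # [1]
```

In this toy batch the total of all loss_ids is one less than the batch size, which mirrors the 2047 versus 2048 observation above. With a real batch, the same check could be run on `loss_ids.tolist()` to find which instance lost its mask position.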