Closed by keenjo 1 year ago
In the original masking functions, the labels appear to be cloned when they are created from the input ids, as in line 668 of pre-train/common/utils.py:
-> labels = batch["input_ids"].clone()
However, when the labels come from the AMR graphs (which I believe are always stored in batch['labels']), they are only assigned to a new name rather than copied, as in line 674 of pre-train/common/utils.py:
-> labels = batch['labels']
Cloning the AMR graph labels as well seems to have solved the issue. Wherever I saw the assignment from line 674 of pre-train/common/utils.py, I changed it to:
-> labels = batch['labels'].clone()
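To illustrate why the clone might matter, here is a minimal sketch. It assumes the masking step writes into the labels tensor in place, which this thread does not show. The helper names, the pad id of 0, and the ignore index of -100 are hypothetical and are not taken from pre-train/common/utils.py.

```python
import torch


def mask_labels_aliased(batch, pad_id=0):
    # Hypothetical helper: plain assignment only creates a second name
    # for the same tensor, so the in-place write below also changes
    # batch["labels"].
    labels = batch["labels"]
    labels[labels == pad_id] = -100
    return labels


def mask_labels_cloned(batch, pad_id=0):
    # Hypothetical helper: clone() copies the data, so the in-place
    # write leaves batch["labels"] untouched.
    labels = batch["labels"].clone()
    labels[labels == pad_id] = -100
    return labels


batch_a = {"labels": torch.tensor([[5, 6, 0, 0]])}
batch_b = {"labels": torch.tensor([[5, 6, 0, 0]])}

aliased = mask_labels_aliased(batch_a)
cloned = mask_labels_cloned(batch_b)

# Check whether the original batch tensors were modified.
aliased_batch_changed = bool((batch_a["labels"] == -100).any())
cloned_batch_changed = bool((batch_b["labels"] == -100).any())

print("aliased version changed the batch:", aliased_batch_changed)
print("cloned version changed the batch:", cloned_batch_changed)
```

If several tasks read batch['labels'] from the same batch, an in-place change made for one task would be visible to the next, which would be consistent with the problem only appearing when more than one task is defined.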
Hello,
I'm trying to reproduce your pretraining script, run_multitask_unified_pretraining.py, and I have been struggling with the following error:
I've installed all of the packages into my own conda environment using your requirements.yml file.
The error seems to be tied to the labels and to the fact that we are doing multitask training here, since the script runs without any issues if I define only one task. I have tried many different options to solve the issue, but nothing seems to be working. Do you have any ideas?