Open liusenling opened 5 months ago
It seams to have something to do with the dataset from ner_dev_loader. Or the model. Can you shar with me what have you changed?
In terms of the code, I only added a length limit to the load_ner_dataset section
def load_ner_dataset(path_to_txt, path_to_images, max_length=510, load_image: bool = True) -> CustomDataset:
tokens = []
image_id = None
datas = []
with open(str(path_to_txt), encoding='utf-8') as txt_file:
for line in txt_file:
line = line.rstrip()
if line.startswith(IMGID_PREFIX):
image_id = line[len(IMGID_PREFIX):]
elif line != '':
text, label = line.split(' ')
if text == '' or text.isspace() or text in SPECIAL_TOKENS or text.startswith(
URL_PREFIX):
text = UNKNOWN_TOKEN
if len(tokens) <= max_length:
tokens.append(Token(text, constants.LABEL_TO_ID[label]))
else:
datas.append(Data(Sentence(tokens),
ImageData(f'{image_id}.jpg')))
tokens = [] #
datas.append(Data(Sentence(tokens), ImageData(f'{image_id}.jpg')))
return CustomDataset(datas, path_to_images,
load_image)
Replace the class "SentenceDataPoint" to and try again:
class Sentence(DataPoint):
def __init__(self, tokens: List[Token] = None, text: str = None):
super().__init__()
self.tokens: List[Token] = tokens
self.text = text
def __len__(self):
return len(self.tokens)
def __getitem__(self, index: int):
return self.tokens[index]
def __iter__(self):
return iter(self.tokens)
def __str__(self):
return self.text if self.text else ''.join([token.text for token in self.tokens])
Hello, have you ever encountered ValueError: max() arg is an empty sequence when you run it? Traceback (most recent call last): File "train1.py", line 70, in
dev_f1, dev_report = evaluate(model, ner_dev_loader)
File "/root/autodl-fs/RpBERT-GCN-NER-main/utils.py", line 56, in evaluate
report = classification_report(true_labels, pred_labels, digits=4, mode='strict', scheme=IOB2)
File "/root/miniconda3/lib/python3.8/site-packages/seqeval/metrics/sequence_labeling.py", line 670, in classification_report
return cr(y_true, y_pred,
File "/root/miniconda3/lib/python3.8/site-packages/seqeval/metrics/v1.py", line 390, in classification_report
name_width = max(map(len, target_names))
ValueError: max() arg is an empty sequence.