Closed hanyc0914 closed 3 years ago
比赛结束后测试集的标签发放给我们了,然后我们写了一段代码计算评测指标(P/R/F1)。具体的代码可以参考:
import re
def get_entities(lines):
    """Extract (LABEL, start, end) triples from lines annotated as {{label:word}}.

    Offsets are cumulative across all lines and are counted on the text with
    the {{...}} markup removed (newline characters are not counted).
    Annotations with an empty word produce no entity but still consume no
    plain-text characters.
    """
    annotation = re.compile(r"{{(.*?)::?(.*?)}}")
    spans = []
    offset = 0  # running position in the concatenated, de-annotated text
    for text in lines:
        prev_end = 0
        for match in annotation.finditer(text):
            tag = match.group(1).upper()
            token = match.group(2)
            # Advance over the plain text that precedes this annotation.
            offset += match.start() - prev_end
            prev_end = match.end()
            if token:
                spans.append((tag, offset, offset + len(token)))
                offset += len(token)
        # Account for the trailing plain text after the last annotation.
        offset += len(text) - prev_end
    return spans
def getlines(path):
    """Read *path* and return its non-empty lines (newlines stripped).

    The file is decoded as UTF-8 explicitly: the data files contain Chinese
    text, and relying on the platform default encoding (e.g. GBK/cp936 on a
    Chinese Windows locale) could raise UnicodeDecodeError or mis-decode.
    """
    with open(path, encoding="utf-8") as f:
        # Keep only truthy (non-empty) lines, as the original filter() did.
        return [line for line in f.read().split("\n") if line]
def _prf(nb_correct, nb_pred, nb_true):
    """Return (precision, recall, F1); each term is 0 when its denominator is 0."""
    p = nb_correct / nb_pred if nb_pred > 0 else 0
    r = nb_correct / nb_true if nb_true > 0 else 0
    f1 = 2 * p * r / (p + r) if p + r > 0 else 0
    return p, r, f1


def _group_by_label(entities):
    """Group (label, start, end) triples into a dict: label -> set of (start, end)."""
    grouped = {}
    for label, start, end in entities:
        grouped.setdefault(label, set()).add((start, end))
    return grouped


def main(ground_truth_path='zs100w_0921_wyq_up.txt', pred_path='result.txt'):
    """Score predicted entity annotations against the ground-truth file.

    Both files must contain the same text line-for-line, differing only in
    their {{label:word}} annotations. Prints micro-averaged P/R/F1 over all
    entities, then per-label P/R/F1 dicts. Paths are parameters (with the
    original defaults) so the scorer can be reused on other files.
    """
    ground_truth_lines = getlines(ground_truth_path)
    pred_lines = getlines(pred_path)

    # Sanity check: with the annotations stripped, the underlying text must
    # match row by row — otherwise cumulative character offsets from
    # get_entities() are not comparable between the two files.
    for i, (gtl, pl) in enumerate(zip(ground_truth_lines, pred_lines)):
        gtl_no_label = re.sub(r"{{(.*?)::?(.*?)}}", r'\2', gtl)
        pl_no_label = re.sub(r"{{(.*?)::?(.*?)}}", r'\2', pl)
        assert gtl_no_label == pl_no_label, f"Different data in row {i}: \n{gtl} \n{pl}"

    true_entities = set(get_entities(ground_truth_lines))
    pred_entities = set(get_entities(pred_lines))

    # Micro-averaged scores over all entities regardless of label.
    p, r, score = _prf(len(true_entities & pred_entities),
                       len(pred_entities), len(true_entities))
    print('P', p)
    print('R', r)
    print('F1', score)

    # Per-label breakdown.
    true_by_label = _group_by_label(true_entities)
    pred_by_label = _group_by_label(pred_entities)

    p_dict, r_dict, score_dict = {}, {}, {}
    for label, true_spans in true_by_label.items():
        # BUG FIX: the original indexed pred_entities_dict[label] directly and
        # raised KeyError whenever a ground-truth label had zero predictions;
        # an empty set yields P = R = F1 = 0 for that label instead.
        pred_spans = pred_by_label.get(label, set())
        p_dict[label], r_dict[label], score_dict[label] = _prf(
            len(true_spans & pred_spans), len(pred_spans), len(true_spans))

    print('P', p_dict)
    print('R', r_dict)
    print('F1', score_dict)


if __name__ == '__main__':
    main()
好的,太感谢您了!还有一个问题,词表大小是 23292,但是网络最后一层输出的维度是 768,这样计算交叉熵损失函数会报错说label 范围超出实际维度,例如 logits.shape = [32,204,768], labels.shape = [32,204],那么 labels 中实际元素值肯定会比 768 大的,请问这个问题怎么解决呢?网络最后一层输出为什么没有设置成 23292 呢?
我上传的模型是 transformers.RobertaForMaskedLM。如果使用 transformers.RobertaModel,将会丢弃 lm_head 层,所以直接输出 hidden size(768)。你可以改用 transformers.RobertaForMaskedLM,它最后会把输出映射到词表维度(23292)。
具体参考huggingface的文档
https://huggingface.co/transformers/model_doc/roberta.html#transformers.RobertaForMaskedLM
表格中 recall, F1 score 等指标是如何计算得到的呢?