Ethan-yt / guwenbert

GuwenBERT: A Pre-trained Language Model for Classical Chinese (Literary Chinese)
Apache License 2.0

How were the evaluation metrics computed? #9

Closed hanyc0914 closed 3 years ago

hanyc0914 commented 3 years ago

How were the recall, F1 score, and other metrics in the table computed?

Ethan-yt commented 3 years ago

After the competition ended, the test-set labels were released to us, and we wrote a script to compute the metrics ourselves. The code is as follows:

import re

def get_entities(lines):
    """Extract (label, start, end) triples from annotated lines.

    Entities are marked inline as {{label::word}}; offsets index into the
    label-stripped text, accumulated across all lines.
    """
    result = []
    cur = 0  # running offset in the label-stripped text
    for line in lines:
        entities = []
        last_end = 0
        for m in re.finditer(r"{{(.*?)::?(.*?)}}", line):
            label = m.group(1).upper()
            word = m.group(2)
            cur += m.start() - last_end  # advance over unannotated text
            last_end = m.end()
            if word:
                entities.append((label, cur, cur + len(word)))
                cur += len(word)
        cur += len(line) - last_end  # advance over the tail of the line
        result.extend(entities)
    return result

def getlines(path):
    # Read the file and drop empty lines.
    with open(path, encoding="utf-8") as f:
        return [line for line in f.read().split("\n") if line]

def main():
    # Paths to the released ground-truth annotations and our predictions.
    ground_truth_path = 'zs100w_0921_wyq_up.txt'
    pred_path = 'result.txt'

    ground_truth_lines = getlines(ground_truth_path)
    pred_lines = getlines(pred_path)

    # Sanity check: after stripping the {{label::word}} markup, the ground
    # truth and the predictions must contain exactly the same text.
    for i, (gtl, pl) in enumerate(zip(ground_truth_lines, pred_lines)):
        gtl_no_label = re.sub(r"{{(.*?)::?(.*?)}}", r'\2', gtl)
        pl_no_label = re.sub(r"{{(.*?)::?(.*?)}}", r'\2', pl)
        assert gtl_no_label == pl_no_label, f"Different data in row {i}: \n{gtl} \n{pl}"

    # Micro-averaged metrics: an entity counts as correct only if its
    # label, start, and end all match exactly.
    true_entities = set(get_entities(ground_truth_lines))
    pred_entities = set(get_entities(pred_lines))

    nb_correct = len(true_entities & pred_entities)
    nb_pred = len(pred_entities)
    nb_true = len(true_entities)
    p = nb_correct / nb_pred if nb_pred > 0 else 0
    r = nb_correct / nb_true if nb_true > 0 else 0
    score = 2 * p * r / (p + r) if p + r > 0 else 0
    print('P', p)
    print('R', r)
    print('F1', score)

    # Per-label metrics: group entities by label, then compute P/R/F1
    # for each label separately.
    true_entities_dict = {}
    for t, start, end in true_entities:
        true_entities_dict.setdefault(t, set()).add((start, end))

    pred_entities_dict = {}
    for t, start, end in pred_entities:
        pred_entities_dict.setdefault(t, set()).add((start, end))

    # Use .get(..., set()) so a label missing from the predictions yields
    # zero counts instead of a KeyError.
    nb_correct_dict = {k: len(v & pred_entities_dict.get(k, set())) for k, v in true_entities_dict.items()}
    nb_pred_dict = {k: len(pred_entities_dict.get(k, set())) for k in true_entities_dict}
    nb_true_dict = {k: len(v) for k, v in true_entities_dict.items()}

    p_dict = {k: nb_correct_dict[k] / nb_pred_dict[k] if nb_pred_dict[k] > 0 else 0 for k in true_entities_dict}
    r_dict = {k: nb_correct_dict[k] / nb_true_dict[k] if nb_true_dict[k] > 0 else 0 for k in true_entities_dict}
    score_dict = {k: 2 * p_dict[k] * r_dict[k] / (p_dict[k] + r_dict[k]) if p_dict[k] + r_dict[k] > 0 else 0
                  for k in true_entities_dict}
    print('P', p_dict)
    print('R', r_dict)
    print('F1', score_dict)

if __name__ == '__main__':
    main()
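
For reference, a tiny hypothetical example of the annotation format the script expects (the {{label::word}} markup is inferred from the regex above; the sample sentence and labels below are made up, not taken from the actual dataset):

sample = ["{{per::曹操}}率兵攻{{loc::徐州}}"]
print(get_entities(sample))
# [('PER', 0, 2), ('LOC', 5, 7)]
# Offsets index into the label-stripped text "曹操率兵攻徐州".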
hanyc0914 commented 3 years ago

Got it, thank you so much! One more question: the vocabulary size is 23292, but the output dimension of the network's last layer is 768, so computing the cross-entropy loss raises an error that the labels are out of range. For example, with logits.shape = [32, 204, 768] and labels.shape = [32, 204], the values in labels will certainly exceed 768. How can this be resolved? Why isn't the last layer's output dimension set to 23292?

Ethan-yt commented 3 years ago

The model I uploaded is a transformers.RobertaForMaskedLM. If you load it with transformers.RobertaModel, the lm_head layer is discarded, so the model outputs the hidden size directly. Use transformers.RobertaForMaskedLM instead; its final layer projects back to the vocabulary dimension.

For details, see the Hugging Face documentation:

https://huggingface.co/transformers/model_doc/roberta.html#transformers.RobertaForMaskedLM
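
A minimal sketch of the difference (the hub model id "ethanyt/guwenbert-base" and the example sentence are assumptions; the 768 and 23292 dimensions are from the discussion above):

import torch
from transformers import AutoTokenizer, RobertaModel, RobertaForMaskedLM

model_id = "ethanyt/guwenbert-base"  # assumed Hugging Face hub id
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("晋太元中,武陵人捕鱼为业", return_tensors="pt")

# RobertaModel drops the lm_head, so the output is the hidden size.
encoder = RobertaModel.from_pretrained(model_id)
hidden = encoder(**inputs).last_hidden_state   # [1, seq_len, 768]

# RobertaForMaskedLM keeps the lm_head, which projects to the vocabulary,
# so cross-entropy against token-id labels works directly.
mlm = RobertaForMaskedLM.from_pretrained(model_id)
logits = mlm(**inputs).logits                  # [1, seq_len, 23292]

loss = torch.nn.functional.cross_entropy(
    logits.view(-1, logits.size(-1)), inputs["input_ids"].view(-1))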