xhw205 / GPLinker_torch

CMeIE/CBLUE/CHIP/实体关系抽取/SPO抽取
207 stars 14 forks source link

有好心老哥提供一下评估函数吗 #6

Open beerpig opened 2 years ago

beerpig commented 2 years ago

有好心老哥提供一下评估函数吗?不胜感激🙏

lili-li-cpu commented 2 years ago

您好呀,这个的评估函数您实现了吗

xxmNIe commented 1 year ago

同求

258508 commented 1 year ago

同求,不胜感激

qjk7 commented 11 months ago

根据预测函数写的:

def evaluate(model, threshold=-0.5):
    model.eval()
    predict_num, gold_num, correct_num = 0, 0, 0

def to_tuple(data):
    """Turn gold SPO records into a tuple of hashable triples.

    Each record in ``data`` is read with ``.get`` and becomes
    ``(subject, relation, object)``, where ``relation`` is
    ``subject_type + '_' + predicate + '_' + object_type['@value']`` and
    ``object`` is ``object['@value']``. The result preserves input order.
    """
    return tuple(
        (
            spo.get('subject'),
            spo.get('subject_type') + '_' + spo.get('predicate') + '_' + spo.get('object_type').get('@value'),
            spo.get('object').get('@value'),
        )
        for spo in data
    )

with torch.no_grad():
    with open(args_path["val_file"]) as f:
        line_list = [json.loads(text.rstrip()) for text in f.readlines()]
        for line in tqdm(line_list):
            text = line['text']
            token2char_span_mapping = tokenizer(text, return_offsets_mapping=True, max_length=256)["offset_mapping"]
            new_span, entities = [], []
            for i in token2char_span_mapping:
                if i[0] == i[1]:
                    new_span.append([])
                else:
                    if i[0] + 1 == i[1]:
                        new_span.append([i[0]])
                    else:
                        new_span.append([i[0], i[-1] - 1])
            threshold = 0.0
            encoder_txt = tokenizer.encode_plus(text, max_length=256)
            input_ids = torch.tensor(encoder_txt["input_ids"]).long().unsqueeze(0).to(device)
            token_type_ids = torch.tensor(encoder_txt["token_type_ids"]).unsqueeze(0).to(device)
            attention_mask = torch.tensor(encoder_txt["attention_mask"]).unsqueeze(0).to(device)
            scores = net(input_ids, attention_mask, token_type_ids)
            outputs = [o[0].data.cpu().numpy() for o in scores]
            subjects, objects = set(), set()
            outputs[0][:, [0, -1]] -= np.inf
            outputs[0][:, :, [0, -1]] -= np.inf
            for l, h, t in zip(*np.where(outputs[0] > 0)):
                if l == 0:
                    subjects.add((h, t))
                else:
                    objects.add((h, t))
            spoes = set()
            for sh, st in subjects:
                for oh, ot in objects:
                    p1s = np.where(outputs[1][:, sh, oh] > threshold)[0]
                    p2s = np.where(outputs[2][:, st, ot] > threshold)[0]
                    ps = set(p1s) & set(p2s)
                    for p in ps:
                        spoes.add((
                            text[new_span[sh][0]:new_span[st][-1] + 1], id2schema[p],
                            text[new_span[oh][0]:new_span[ot][-1] + 1]
                        ))
            triple = line["spo_list"]
            triple = set(to_tuple(triple))
            pred = set(spoes)
            correct_num += len(triple & pred)
            predict_num += len(pred)
            gold_num += len(triple)

        recall = correct_num / (gold_num + 1e-10)
        precision = correct_num / (predict_num + 1e-10)
        f1_score = 2 * recall * precision / (recall + precision + 1e-10)
        print("precision: {:5.4f} recall: {:5.4f} f1_score: {:5.4f}".format(precision, recall, f1_score))
        return precision, recall, f1_score