XinZhao0211 / OpenSetRE

This is the code for ACL2023 "Open Set Relation Extraction via Unknown-Aware Training".

Training set file name in train_base.sh not updated #1

Open CMakey opened 1 year ago

CMakey commented 1 year ago

In train_base.sh, the line --train_file data/${dataset}/train.json \ looks like it was never changed to --train_file data/${dataset}/train_dp.json \

CMakey commented 1 year ago

Hello, when computing AUROC, it seems that an error is raised because the relations in dev_ood_relation and id_relation do not overlap. Is that right?

While the model is running, the AUROC metric is computed with this call:

    dev_auroc, dev_fpr95, thresh = validate(net, dev_dataloader, args.confidence_type, num_classes)

The validate function is:

    def validate(net, dataloader, mode, num_classes):
        net.eval()
        gold, predict, confidence_msp, confidence_energy = [], [], [], []

        with torch.no_grad():
            for unique_id, input_ids, input_mask, target in dataloader:
                input_ids, input_mask, target = input_ids.cuda(), input_mask.cuda(), target.cuda()

                # forward
                output = net(input_ids, input_mask)

                target = target.cpu().tolist()
                pred = output.data.max(1)[1].cpu().tolist()
                msp = F.softmax(output, dim=-1).max(1)[0].cpu().tolist()
                energy = torch.logsumexp(output, 1).cpu().tolist()

                gold.extend(target)
                predict.extend(pred)
                confidence_msp.extend(msp)
                confidence_energy.extend(energy)

        # acc = accuracy_score(gold, predict)
        if mode == 'msp':
            confidence_score = confidence_msp
        elif mode == 'energy':
            confidence_score = confidence_energy
        else:
            raise Exception('Not supported mode...')

        # AUROC & FPR95
        pos_ = []
        neg_ = []

        # print(len(gold))
        # print(len(confidence_score))
        # invalid_labels = [label for label in gold if label >= num_classes]
        # if invalid_labels:
        #     print("Invalid labels found:", invalid_labels)
        #     # handle out-of-range labels
        # invalid_types = [type(score) for score in confidence_score if not isinstance(score, (float, int))]
        # if invalid_types:
        #     print("Invalid data types found in confidence_score:", invalid_types)
        #     # handle incorrect data types

        for g, c in zip(gold, confidence_score):
            if g < num_classes:
                pos_.append(c)
            else:
                neg_.append(c)

        auroc, fpr95, thresh = get_measures(pos_, neg_)

        # if mode == 'msp':
        #     sorted_confidence = sorted(confidence_msp)
        # elif mode == 'energy':
        #     sorted_confidence = sorted(confidence_energy)
        # else:
        #     raise Exception('Not supported mode...')
        # thresh_idx = int(len(sorted_confidence) * 0.05)
        # thresh = sorted_confidence[thresh_idx]

        print('[DEV] AUROC: {0:.4f} | FPR95: {1:.4f}| Thresh: {2:.4f}'.format(auroc, fpr95, thresh))
        net.train()
        return auroc, fpr95, thresh

With the validation set, every label in gold is greater than num_classes, so no sample satisfies g < num_classes. The pos_ list therefore ends up empty, and computing AUROC raises an error. How should this be handled…
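To make the failure mode concrete, here is a minimal sketch of the known/unknown split followed by a guarded AUROC computation. It is not the repository's code: `split_by_known` and `safe_auroc` are hypothetical helper names, and because the internals of `get_measures` are not shown in this thread, the sketch only assumes that it needs both score lists to be non-empty. AUROC is computed here as the fraction of (known, unknown) pairs in which the known sample has the higher confidence, with ties counted as one half.

```python
from typing import List, Optional, Sequence, Tuple


def split_by_known(
    gold: Sequence[int], confidence: Sequence[float], num_classes: int
) -> Tuple[List[float], List[float]]:
    """Mirror the loop in validate(): scores of samples with label < num_classes
    go to pos, all other scores go to neg."""
    pos: List[float] = []
    neg: List[float] = []
    for g, c in zip(gold, confidence):
        if g < num_classes:
            pos.append(c)
        else:
            neg.append(c)
    return pos, neg


def safe_auroc(pos: Sequence[float], neg: Sequence[float]) -> Optional[float]:
    """Pairwise AUROC (hypothetical stand-in, not get_measures).

    Returns None instead of raising when either list is empty, because
    AUROC is undefined unless both groups are present.
    """
    if not pos or not neg:
        return None
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    # O(len(pos) * len(neg)); fine for a sketch, slow for large dev sets.
    return wins / (len(pos) * len(neg))


if __name__ == "__main__":
    # Situation described in the issue: every dev label is >= num_classes.
    pos, neg = split_by_known([7, 8, 9], [0.9, 0.4, 0.6], num_classes=5)
    if safe_auroc(pos, neg) is None:
        print("AUROC undefined: dev set has %d known and %d unknown samples"
              % (len(pos), len(neg)))
```

A guard like this only turns the crash into an explicit message. Whether the dev set is supposed to mix known-relation and unknown-relation samples, which is what AUROC requires, remains a question for the authors.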