CMakey opened 1 year ago
Hi, when computing AUROC, does having no overlap between the relations in dev_ood_relation and id_relation cause an error?
During training, the model computes the AUROC metric with this line: `dev_auroc, dev_fpr95, thresh = validate(net, dev_dataloader, args.confidence_type, num_classes)`. The `validate` function is:

```python
import torch
import torch.nn.functional as F

# get_measures comes from the project code (not shown in this issue)


def validate(net, dataloader, mode, num_classes):
    net.eval()
    gold, predict, confidence_msp, confidence_energy = [], [], [], []
    with torch.no_grad():
        for unique_id, input_ids, input_mask, target in dataloader:
            input_ids, input_mask, target = input_ids.cuda(), input_mask.cuda(), target.cuda()
            # forward
            output = net(input_ids, input_mask)
            target = target.cpu().tolist()
            pred = output.data.max(1)[1].cpu().tolist()
            msp = F.softmax(output, dim=-1).max(1)[0].cpu().tolist()
            energy = torch.logsumexp(output, 1).cpu().tolist()
            gold.extend(target)
            predict.extend(pred)
            confidence_msp.extend(msp)
            confidence_energy.extend(energy)
    # acc = accuracy_score(gold, predict)
    if mode == 'msp':
        confidence_score = confidence_msp
    elif mode == 'energy':
        confidence_score = confidence_energy
    else:
        raise Exception('Not supported mode...')
    # AUROC & FPR95
    pos_ = []
    neg_ = []
    # print(len(gold))
    # print(len(confidence_score))
    # invalid_labels = [label for label in gold if label >= num_classes]
    # if invalid_labels:
    #     print("Invalid labels found:", invalid_labels)
    #     # handle out-of-range labels
    # invalid_types = [type(score) for score in confidence_score if not isinstance(score, (float, int))]
    # if invalid_types:
    #     print("Invalid data types found in confidence_score:", invalid_types)
    #     # handle incorrect data types
    for g, c in zip(gold, confidence_score):
        if g < num_classes:
            pos_.append(c)
        else:
            neg_.append(c)
    auroc, fpr95, thresh = get_measures(pos_, neg_)
    # if mode == 'msp':
    #     sorted_confidence = sorted(confidence_msp)
    # elif mode == 'energy':
    #     sorted_confidence = sorted(confidence_energy)
    # else:
    #     raise Exception('Not supported mode...')
    # thresh_idx = int(len(sorted_confidence) * 0.05)
    # thresh = sorted_confidence[thresh_idx]
    print('[DEV] AUROC: {0:.4f} | FPR95: {1:.4f}| Thresh: {2:.4f}'.format(auroc, fpr95, thresh))
    net.train()
    return auroc, fpr95, thresh
```

On the dev set, every label in `gold` is greater than or equal to `num_classes`, so `pos_` ends up empty and the AUROC computation raises an error. How should this be handled?
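For context, AUROC compares scores of positive (in-distribution, `g < num_classes`) samples against negative (OOD) samples, so it is undefined when either group is empty. The sketch below is a minimal illustration under that assumption, not the project's `get_measures`: `split_by_label` and `auroc` are hypothetical helpers, and the AUROC is computed directly as the fraction of (positive, negative) pairs where the positive scores higher, with ties counted as 1/2. It returns `None` instead of failing when one side is empty.

```python
from typing import List, Optional, Sequence, Tuple


def split_by_label(
    gold: Sequence[int], scores: Sequence[float], num_classes: int
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper mirroring the loop in validate():
    labels < num_classes are treated as in-distribution (positive)."""
    pos: List[float] = []
    neg: List[float] = []
    for g, c in zip(gold, scores):
        (pos if g < num_classes else neg).append(c)
    return pos, neg


def auroc(pos: Sequence[float], neg: Sequence[float]) -> Optional[float]:
    """Hypothetical helper: AUROC as P(positive score > negative score),
    ties counted as 0.5. Returns None when either group is empty,
    because the metric is undefined in that case."""
    if not pos or not neg:
        return None
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Dev set with only OOD labels, as described in the issue
pos, neg = split_by_label([5, 6, 7], [0.2, 0.4, 0.3], num_classes=5)
print(len(pos), len(neg), auroc(pos, neg))  # 0 3 None

# Dev set mixing ID and OOD labels
pos, neg = split_by_label([0, 1, 5, 6], [0.9, 0.8, 0.3, 0.85], num_classes=5)
print(auroc(pos, neg))  # 0.75
```

The pairwise loop is O(|pos| * |neg|) and is only meant to make the definition explicit. In `validate`, an analogous check such as `if not pos_ or not neg_:` before calling `get_measures` would let you skip or warn instead of crashing, though the more direct fix is probably a dev split that contains both ID and OOD relations.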
Also, in `train_base.sh`, the line `--train_file data/${dataset}/train.json \` looks like it was never changed to `--train_file data/${dataset}/train_dp.json \`.