Open scalaboy opened 3 years ago
请问下,这个invalid batch 是怎么个逻辑计算的?代码没看明白、
raw_arr = np.column_stack((preds, labels)) raw_arr = sorted(raw_arr, key=lambda d: d[0]) auc = 0. fp1, tp1, fp2, tp2 = 0., 0., 0., 0. for record in raw_arr: fp2 += 1 - record[1] tp2 += record[1] auc += (fp2 - fp1) * (tp2 + tp1) fp1, tp1 = fp2, tp2 threshold = len(preds) - 1e-3 if tp2 > threshold or fp2 > threshold: print('%s invalid batch' % self.name) return
请问下,这个invalid batch 是怎么个逻辑计算的?代码没看明白、