Status: Open — QingXu51820 opened this issue 3 months ago.
def _save_anomaly_map(anomaly_map, output_dir, defect_class, image_path):
    """Write ``anomaly_map`` to ``<output_dir>/<defect_class>/<stem>.tiff``."""
    class_dir = os.path.join(output_dir, defect_class)
    # exist_ok=True avoids the check-then-create race of
    # os.path.exists() followed by os.makedirs().
    os.makedirs(class_dir, exist_ok=True)
    # splitext strips only the final extension ('a.b.png' -> 'a.b'); the old
    # split('.')[0] cut at the first dot, so files like 'a.b.png' and
    # 'a.c.png' were both written to 'a.tiff' and overwrote each other.
    stem = os.path.splitext(os.path.basename(image_path))[0]
    tifffile.imwrite(os.path.join(class_dir, stem + '.tiff'), anomaly_map)


def test(test_set, teacher, student, autoencoder, teacher_mean, teacher_std,
         q_st_start, q_st_end, q_ae_start, q_ae_end, test_output_dir=None,
         desc='Running inference'):
    """Run anomaly inference over ``test_set`` and return image-level AUROC.

    For every sample, the combined anomaly map returned by ``predict`` is
    padded by 4 pixels on each spatial side and bilinearly resized back to
    the input image's original resolution. If requested, it is written out as
    a TIFF. It is then reduced to a single image score (its maximum value).
    Ground-truth labels are taken from the name of the directory that
    contains each image: ``'good'`` means normal (0), and any other
    directory name means anomalous (1).

    Args:
        test_set: Iterable of ``(image, target, path)`` triples. ``image``
            must expose ``width``/``height`` and be accepted by
            ``default_transform`` (presumably a PIL image — confirm).
            ``target`` is not used here.
        teacher, student, autoencoder: Models forwarded unchanged to
            ``predict``.
        teacher_mean, teacher_std: Statistics forwarded unchanged to
            ``predict``.
        q_st_start, q_st_end, q_ae_start, q_ae_end: Bounds forwarded
            unchanged to ``predict``.
        test_output_dir: If not ``None``, each resized anomaly map is saved
            to ``<test_output_dir>/<defect_class>/<image stem>.tiff``.
        desc: Label shown on the ``tqdm`` progress bar.

    Returns:
        Image-level ROC AUC (``roc_auc_score``) multiplied by 100.

    Raises:
        ValueError: From ``roc_auc_score`` when the collected labels contain
            only one class (e.g. no ``'good'`` images or no defective ones).
    """
    y_true = []
    y_score = []
    for image, _, path in tqdm(test_set, desc=desc):
        orig_width = image.width
        orig_height = image.height
        image = default_transform(image)
        image = image[None]  # add a leading batch dimension of size 1
        if on_gpu:
            image = image.cuda()
        map_combined, _, _ = predict(
            image=image, teacher=teacher, student=student,
            autoencoder=autoencoder, teacher_mean=teacher_mean,
            teacher_std=teacher_std, q_st_start=q_st_start,
            q_st_end=q_st_end, q_ae_start=q_ae_start, q_ae_end=q_ae_end)
        # Pad 4 px on each side of the last two (spatial) dimensions before
        # resizing. NOTE(review): presumably restores the border lost by the
        # networks' unpadded convolutions — confirm against the model code.
        map_combined = torch.nn.functional.pad(map_combined, (4, 4, 4, 4))
        map_combined = torch.nn.functional.interpolate(
            map_combined, (orig_height, orig_width), mode='bilinear')
        # Drop the batch and channel dims -> (orig_height, orig_width) array.
        map_combined = map_combined[0, 0].cpu().numpy()

        # The parent directory name doubles as the defect class / label.
        defect_class = os.path.basename(os.path.dirname(path))
        if test_output_dir is not None:
            _save_anomaly_map(map_combined, test_output_dir, defect_class,
                              path)

        y_true.append(0 if defect_class == 'good' else 1)
        # Image-level score: the strongest pixel response in the map.
        y_score.append(np.max(map_combined))
    auc = roc_auc_score(y_true=y_true, y_score=y_score)
    return auc * 100
(Code quoted from efficientad.py, lines 262–297.)