gasharper / PyramidFlow

[CVPR 2023] PyramidFlow: High-Resolution Defect Contrastive Localization using Pyramid Normalizing Flow
MIT License

visualizing the localization results #9

Open leejq666 opened 1 year ago

leejq666 commented 1 year ago

Hello, could you please publish the code for visualizing the localization results?

gasharper commented 1 year ago
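```python
import os
import os.path as osp
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# Loading the flow model comes before this and is omitted.
# Assumed to be defined by that setup code: flow, val_loader, device,
# L (number of pyramid levels), x_size, img_mean, img_std, datadir, cls_name.

# Compute the template feat_mean (per-level mean feature over the normal validation images)
flow.eval()
feat_sum, cnt = [0 for _ in range(L)], 0
with torch.no_grad():
    for val_dict in val_loader:
        image = val_dict['images'].to(device)
        pyramid2 = flow(image)
        # Sum over the batch dimension and count images (not batches), so the mean
        # stays correct for batch size > 1 and for a smaller final batch
        feat_sum = [p0 + p.sum(dim=0, keepdim=True) for p0, p in zip(feat_sum, pyramid2)]
        cnt += image.shape[0]
feat_mean = [p / cnt for p in feat_sum]

# Preprocessing
img2tensor = transforms.Compose([transforms.ToPILImage(),
                                 transforms.Resize((x_size, x_size)),
                                 transforms.ToTensor(),
                                 transforms.Normalize(img_mean, img_std)])

for datapath in glob(osp.join(datadir, cls_name, 'test', '*', '*.png')):  # change this to fit your dataset
    save_dir = osp.join('diff', '/'.join(datapath.split('/')[-4:-1]))  # change this to fit your output layout
    os.makedirs(save_dir, exist_ok=True)
    img = np.array(Image.open(datapath).convert('RGB'))  # h, w, c
    img = img2tensor(img).unsqueeze(0).to(device)  # 1, c, h, w
    with torch.no_grad():
        pyramid2 = flow(img)
        # Per-level absolute difference from the template
        pyramid_diff = [(feat2 - template).abs() for feat2, template in zip(pyramid2, feat_mean)]
        # Compose the difference pyramid into one map, then average over channels
        diff = flow.pyramid.compose_pyramid(pyramid_diff).mean(1)  # b, h, w
    np_diff = diff[0].cpu().numpy()
    plt.imshow(np_diff, vmax=0.9 * np_diff.max())  # plot the anomaly map
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    save_path = osp.join(save_dir, datapath.split('/')[-1])
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
    plt.close()  # release the figure so images do not pile up across iterations
```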
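The template-difference idea in the script above can be tried without a trained model. Below is a minimal NumPy sketch on synthetic arrays, not the PyramidFlow implementation: `template_mean`, `compose_nearest` and `anomaly_map` are hypothetical names, and `compose_nearest` (repeat every level up to full resolution, then sum) is only an assumed stand-in for `flow.pyramid.compose_pyramid`, whose real behaviour may differ. It assumes each pyramid is a list of arrays shaped (B, C, H_l, W_l) and that the output size is an integer multiple of every level's size.

```python
import numpy as np


def template_mean(batches):
    """Per-level mean feature over every sample in `batches`.

    Each element of `batches` is a pyramid: a list of arrays shaped
    (B, C, H_l, W_l). Summing over the batch axis and counting samples
    (not batches) keeps the mean correct when batch sizes differ.
    """
    sums, count = None, 0
    for pyramid in batches:
        level_sums = [p.sum(axis=0, keepdims=True) for p in pyramid]
        if sums is None:
            sums = level_sums
        else:
            sums = [s + ls for s, ls in zip(sums, level_sums)]
        count += pyramid[0].shape[0]
    return [s / count for s in sums]  # each level shaped (1, C, H_l, W_l)


def compose_nearest(levels, out_hw):
    """Hypothetical stand-in for flow.pyramid.compose_pyramid: upsample every
    level to out_hw by pixel repetition and sum them. Assumes out_hw is an
    integer multiple of each level's size. Not PyramidFlow's own code."""
    H, W = out_hw
    out = np.zeros(levels[0].shape[:2] + (H, W))
    for lv in levels:
        h, w = lv.shape[-2:]
        out += lv.repeat(H // h, axis=2).repeat(W // w, axis=3)
    return out


def anomaly_map(test_pyramid, template, out_hw):
    """Per-level |feature - template|, composed into one map, averaged over channels."""
    diffs = [np.abs(f - t) for f, t in zip(test_pyramid, template)]
    return compose_nearest(diffs, out_hw).mean(axis=1)  # (B, H, W)


# Synthetic demo: three pyramid levels with 3 channels each
rng = np.random.default_rng(0)
sizes = [(8, 8), (4, 4), (2, 2)]


def make_pyramid(batch):
    return [rng.normal(size=(batch, 3, h, w)) for h, w in sizes]


normal_batches = [make_pyramid(4), make_pyramid(4), make_pyramid(1)]  # uneven last batch
template = template_mean(normal_batches)

test = make_pyramid(1)
test[0][:, :, 2:4, 2:4] += 10.0  # inject a synthetic "defect" into the finest level
amap = anomaly_map(test, template, (8, 8))
print(amap.shape)  # (1, 8, 8)
```

Because the template is averaged per image rather than per batch, the smaller last batch does not skew it, and the injected defect stands out as the brightest region of `amap[0]`. For saving, `matplotlib.pyplot.imsave(path, amap[0], vmax=0.9 * amap[0].max())` writes the array directly as an image at its own resolution, without axes or an open figure, as an alternative to the `imshow`/`savefig` combination.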
leejq666 commented 1 year ago

Thanks!!! That was a really quick reply. Good luck with your research!

(Sent as an email reply to the GitHub notification of gasharper's comment, dated Wednesday, September 20, 2023, 5:27 PM; subject: Re: [gasharper/PyramidFlow] visualizing the localization results (Issue #9).)

