gasharper / PyramidFlow

[CVPR 2023] PyramidFlow: High-Resolution Defect Contrastive Localization using Pyramid Normalizing Flow
MIT License

How to run inference #20

Open Mask0913 opened 3 months ago

Mask0913 commented 3 months ago

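According to the training code, computing the segmentation (anomaly) map requires feat_mean, and feat_mean is computed on the validation set. Without feature maps from the target domain, feat_mean cannot be computed. Is this setup reasonable?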

feat_sum, cnt = [0 for _ in range(num_layer)], 0
for val_dict in val_loader:
    image = val_dict['images'].to(device)
    with torch.no_grad():
        pyramid2 = flow(image)
        cnt += 1
    feat_sum = [p0 + p for p0, p in zip(feat_sum, pyramid2)]
feat_mean = [p / cnt for p in feat_sum]

# test
flow.eval()
diff_list, labels_list = [], []
for test_dict in test_loader:
    image, labels = test_dict['images'].to(device), test_dict['labels']
    with torch.no_grad():
        pyramid2 = flow(image)
        pyramid_diff = [(feat2 - template).abs() for feat2, template in zip(pyramid2, feat_mean)]
        diff = flow.pyramid.compose_pyramid(pyramid_diff).mean(1, keepdim=True)  # b, 1, h, w
        diff_list.append(diff.cpu())
        labels_list.append(labels.cpu() == 1)  # b, 1, h, w
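Reading the snippet above, feat_mean acts as a per-level template: the average of the flow's output pyramid over images assumed to be defect-free, and the anomaly map is the absolute deviation from that template. A minimal NumPy sketch of those two steps follows, under stated assumptions: each pyramid is a list of arrays shaped (b, c, h, w) with level 0 the finest; the template is averaged over images rather than over batches (the snippet divides by the number of batches); and nearest-neighbour upsampling plus summation is a hypothetical stand-in for `flow.pyramid.compose_pyramid`, whose actual composition may differ. The helper names `build_template`, `upsample_nearest` and `anomaly_map` are illustrative only.

```python
import numpy as np


def build_template(pyramids):
    """Average per-level features over a set of (assumed defect-free) images.

    `pyramids` is an iterable of batches; each batch is a list of arrays
    shaped (b, c, h_l, w_l), one per pyramid level (hypothetical layout
    mirroring the `pyramid2` list in the snippet).
    """
    feat_sum, cnt = None, 0
    for batch in pyramids:
        # Sum over the batch axis so batches of different sizes can be mixed.
        sums = [level.sum(axis=0, keepdims=True) for level in batch]
        feat_sum = sums if feat_sum is None else [s0 + s for s0, s in zip(feat_sum, sums)]
        cnt += batch[0].shape[0]  # count images, not batches
    return [s / cnt for s in feat_sum]  # each level: (1, c, h_l, w_l)


def upsample_nearest(x, h, w):
    # Nearest-neighbour upsampling by integer factors (assumes h, w divisible).
    fh, fw = h // x.shape[-2], w // x.shape[-1]
    return x.repeat(fh, axis=-2).repeat(fw, axis=-1)


def anomaly_map(pyramid, template):
    # Per-level absolute deviation from the template (broadcasts over batch).
    diffs = [np.abs(f - t) for f, t in zip(pyramid, template)]
    h, w = diffs[0].shape[-2:]
    # Hypothetical composition: upsample every level to the finest one and sum.
    composed = sum(upsample_nearest(d, h, w) for d in diffs)
    return composed.mean(axis=1, keepdims=True)  # b, 1, h, w


# Usage: two "normal" batches of different sizes, then one test image.
normal_batches = [
    [np.ones((2, 3, 4, 4)), np.ones((2, 3, 2, 2))],
    [np.full((1, 3, 4, 4), 4.0), np.full((1, 3, 2, 2), 4.0)],
]
template = build_template(normal_batches)
test_pyramid = [np.full((1, 3, 4, 4), 2.0), np.full((1, 3, 2, 2), 5.0)]
amap = anomaly_map(test_pyramid, template)
```

Because the template is normalised by the number of images, it can be built from whatever set of normal images is available at inference time, and it does not depend on the last batch having the same size as the others.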