jacobgil / pytorch-grad-cam

Advanced AI Explainability for computer vision. Support for CNNs, Vision Transformers, Classification, Object detection, Segmentation, Image similarity and more.
https://jacobgil.github.io/pytorch-gradcam-book
MIT License

Wrong heatmap of Swin-Transformer #124

Closed: lawsonxwl closed this issue 2 years ago.

lawsonxwl commented 2 years ago

My code:

```python
import ast

import cv2
import numpy as np
import timm
import torch
import torch.nn as nn

from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, \
    XGradCAM, EigenCAM, EigenGradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image

# get_args() and reshape_transform() are defined earlier in swinT_example.py (not shown).

if __name__ == '__main__':
    """ python swinT_example.py --image-path <path_to_image>
    Example usage of using cam-methods on a SwinTransformers network.
    """
    label_map_path = '../../imagenet1000_clsidx_to_labels.txt'
    # The file holds a Python dict literal {class_idx: "label", ...};
    # ast.literal_eval parses it without executing arbitrary code (unlike eval).
    with open(label_map_path, 'r') as file:
        label = ast.literal_eval(file.read())
    args = get_args()
    methods = \
        {"gradcam": GradCAM,
         "scorecam": ScoreCAM,
         "gradcam++": GradCAMPlusPlus,
         "ablationcam": AblationCAM,
         "xgradcam": XGradCAM,
         "eigencam": EigenCAM,
         "eigengradcam": EigenGradCAM}

    if args.method not in methods:
        raise Exception(f"method should be one of {list(methods.keys())}")

    model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True)
    model.eval()

    if args.use_cuda:
        model = model.cuda()

    target_layer = model.layers[-1].blocks[-2].norm1

    cam = methods[args.method](model=model,
                               target_layer=target_layer,
                               use_cuda=args.use_cuda,
                               reshape_transform=reshape_transform)

    # OpenCV loads images as BGR; reversing the channel axis gives RGB.
    rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1]
    rgb_img = cv2.resize(rgb_img, (224, 224))
    rgb_img = np.float32(rgb_img) / 255
    input_tensor = preprocess_image(rgb_img, mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

    # If None, returns the map for the highest scoring category.
    # Otherwise, targets the requested category.
    target_category = None

    # AblationCAM and ScoreCAM have batched implementations.
    # You can override the internal batch size for faster computation.
    cam.batch_size = 32
    print(input_tensor.shape)
    grayscale_cam = cam(input_tensor=input_tensor,
                        target_category=target_category,
                        eigen_smooth=args.eigen_smooth,
                        aug_smooth=args.aug_smooth)

    # Top-1 prediction, only used to print its label. Move the input to the GPU
    # only when the model is there, and skip autograd for this plain forward pass.
    if args.use_cuda:
        input_tensor = input_tensor.cuda()
    with torch.no_grad():
        output = nn.Softmax(dim=1)(model(input_tensor))
    res_ind = output.argmax(dim=1)[0].item()

    # Here grayscale_cam has only one image in the batch
    grayscale_cam = grayscale_cam[0, :]

    cam_image = show_cam_on_image(rgb_img, grayscale_cam)
    print(label[res_ind])
    cv2.imwrite(f'{args.method}_cam_swin.jpg', cam_image)
```

My heatmap:

[image: gradcam_cam_swin]

Is there something wrong with my code?
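The code hands a `reshape_transform` to the CAM object, but its definition is not part of the post. Its job is to turn the Swin block's token sequence back into a CNN-style feature map so the CAM methods can treat it like convolutional activations. Below is a minimal NumPy sketch of that idea, not the example file's actual function; it assumes the target layer emits tokens shaped `(B, 49, C)` laid out row-major over a 7×7 grid, which is what the last stage of `swin_*_patch4_window7_224` produces for a 224×224 input (a PyTorch version would use `reshape` and `permute` the same way).

```python
import numpy as np


def reshape_transform(tokens: np.ndarray, height: int = 7, width: int = 7) -> np.ndarray:
    """Sketch: turn Swin tokens (B, height*width, C) into a CNN-style map (B, C, height, width).

    Assumes the tokens are ordered row-major over a height x width grid.
    """
    batch, num_tokens, channels = tokens.shape
    if num_tokens != height * width:
        raise ValueError(f"expected {height * width} tokens, got {num_tokens}")
    grid = tokens.reshape(batch, height, width, channels)  # (B, H, W, C)
    return grid.transpose(0, 3, 1, 2)                       # channels first: (B, C, H, W)


# One image, 49 tokens, 768 channels (swin_tiny's final embedding width).
tokens = np.random.rand(1, 49, 768).astype(np.float32)
feature_map = reshape_transform(tokens)
print(feature_map.shape)  # (1, 768, 7, 7)
```

If the grid size assumed here does not match what the chosen layer actually emits, the reshape fails loudly instead of silently producing a scrambled map, which is worth checking whenever the model or input resolution changes.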

jacobgil commented 2 years ago

Can you please try with `--method scorecam` or `--method ablationcam` and post the results?
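Score-CAM and Ablation-CAM do not use gradients at all, so they are a useful cross-check when Grad-CAM looks wrong. To illustrate the ablation idea, here is a NumPy sketch following the Ablation-CAM formulation `w_k = (y - y_k) / y`, where `y_k` is the class score with channel `k` of the target layer zeroed out. It is not pytorch-grad-cam's implementation: `score_fn` and the toy linear head are hypothetical stand-ins for "the rest of the network", and the sketch assumes the base score `y` is nonzero.

```python
import numpy as np


def ablation_cam(activations: np.ndarray, score_fn) -> np.ndarray:
    """Ablation-CAM sketch for one image.

    activations: (K, H, W) feature map at the target layer.
    score_fn: hypothetical callable mapping activations to the target-class score.
    """
    base = score_fn(activations)  # assumed nonzero
    weights = np.empty(activations.shape[0])
    for k in range(activations.shape[0]):
        ablated = activations.copy()
        ablated[k] = 0.0                                # knock out channel k
        weights[k] = (base - score_fn(ablated)) / base  # relative score drop
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # ReLU(sum_k w_k A_k)
    return cam / cam.max() if cam.max() > 0 else cam


# Hypothetical toy head: score = head . (spatial mean of each channel).
rng = np.random.default_rng(0)
acts = rng.random((4, 7, 7))
head = np.array([1.0, -0.5, 2.0, 0.0])


def toy_score(a: np.ndarray) -> float:
    return float(head @ a.mean(axis=(1, 2)))


cam = ablation_cam(acts, toy_score)
print(cam.shape)  # (7, 7)
```

The cost is one extra forward pass per channel, which is why the batched implementations mentioned in the code comments (`cam.batch_size`) matter for these methods.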

Grad-CAM seems to give poor results for the Swin Transformer, and the result is very sensitive to the choice of `target_layer`. For the `swin_base_patch4_window7_224` model, this target layer seems to give reasonable results, so give it a try: `target_layer = model.layers[-1].blocks[-1].norm2`
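Grad-CAM weights each channel of the chosen layer by its spatially averaged gradient and sums the weighted activations, so the map depends entirely on what that single layer encodes; this is one plausible reason swapping `norm1` for `norm2`, or one block for another, changes the result so much. A minimal NumPy sketch of that weighting, assuming the activations and gradients have already been captured at the target layer and reshaped to `(K, H, W)`; the toy linear head used to produce gradients is hypothetical.

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Grad-CAM sketch for one image; both inputs are (K, H, W) at the target layer."""
    weights = gradients.mean(axis=(1, 2))                              # alpha_k: mean gradient per channel
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # ReLU(sum_k alpha_k A_k)
    return cam / cam.max() if cam.max() > 0 else cam


# Hypothetical toy head y = v . mean_{h,w}(A), so dy/dA_k = v_k / (H * W) at every location.
K, H, W = 3, 7, 7
acts = np.zeros((K, H, W))
acts[0, 2, 3] = 1.0  # channel 0 fires at row 2, column 3
acts[1, 5, 5] = 1.0  # channel 1 fires at row 5, column 5
v = np.array([2.0, -1.0, 0.5])
grads = np.broadcast_to(v[:, None, None] / (H * W), (K, H, W))

cam = grad_cam(acts, grads)
peak = tuple(int(i) for i in np.unravel_index(cam.argmax(), cam.shape))
print(peak)  # (2, 3)
```

Channel 1 has a negative weight, so its location is clipped to zero by the ReLU; in a real network, a layer whose gradients are small or inconsistent gives weights that carry little class information, and the map degrades accordingly.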

jacobgil commented 2 years ago

Closing the issue for now, please re-open if the suggestion above doesn't help.