frgfm / torch-cam

Class activation maps for your PyTorch models (CAM, Grad-CAM, Grad-CAM++, Smooth Grad-CAM++, Score-CAM, SS-CAM, IS-CAM, XGrad-CAM, Layer-CAM)
https://frgfm.github.io/torch-cam/
Apache License 2.0

Resnet Score-CAM Error: Only integer tensors of a single element can be converted to an index #264

Open neeleshbisht99 opened 1 month ago

neeleshbisht99 commented 1 month ago

First of all, thank you so much @frgfm for this great library (it saves us a ton of time and is very accurate). Great work, keep it going!

I'm trying to generate CAMs for a 3D ResNet-34 model with the library. Grad-CAM and Grad-CAM++ worked really well, but I'm stuck with ScoreCAM.

Below is the code I used for ScoreCAM:

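import numpy as np
import torch
from torchcam.methods import ScoreCAM


class TorchCAM:
    @staticmethod
    def compute_score_cam(img_tensor, model, last_conv_layer_name='module.layer4', target_class_idx=1):
        model.eval()
        # Register the ScoreCAM hooks on the target layer
        score_cam = ScoreCAM(model, last_conv_layer_name)
        # Forward pass; note that this model returns a tuple
        with torch.no_grad():
            output, _ = model(img_tensor)
        # This is the call that raises the TypeError
        score_cams = score_cam(class_idx=target_class_idx)
        heatmap = score_cams[0].cpu().numpy().squeeze()
        heatmap = np.transpose(heatmap, (2, 1, 0))
        return heatmap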

Below is the traceback:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 11
      6 ref_class_idx = 0  # Reference class
      9 ##The CAMs below are obtained through "torch-cam" 3rd party package.
     10 #FIXME
---> 11 torch_cam_score_cam_heatmap = TorchCAM.compute_score_cam(sample, model, last_conv_layer_name_with_module, target_class_idx)
     12 plt.matshow(np.squeeze(torch_cam_score_cam_heatmap[:, :, 3]))
     13 plt.show()

File /local/scratch/v_neelesh_bisht/3d-cnn/cams/resnet_cams/torch_cam.py:34, in TorchCAM.compute_score_cam(img_tensor, model, last_conv_layer_name, target_class_idx)
     32 with torch.no_grad():
     33     output, _ = model(img_tensor)
---> 34 score_cams = score_cam(class_idx=target_class_idx)
     35 heatmap = score_cams[0].cpu().numpy().squeeze()
     36 heatmap = np.transpose(heatmap, (2, 1, 0))

File ~/local_scratch/cmu_research/lib/python3.10/site-packages/torchcam/methods/core.py:169, in _CAM.__call__(self, class_idx, scores, normalized, **kwargs)
    166 self._precheck(class_idx, scores)
    168 # Compute CAM
--> 169 return self.compute_cams(class_idx, scores, normalized, **kwargs)

File ~/local_scratch/cmu_research/lib/python3.10/site-packages/torchcam/methods/core.py:193, in _CAM.compute_cams(self, class_idx, scores, normalized, **kwargs)
    178 """Compute the CAM for a specific output class.
    179
    180 Args:
   (...)
    190         the k-th element of the input batch for class index equal to the k-th element of `class_idx`.
    191 """
    192 # Get map weight & unsqueeze it
--> 193 weights = self._get_weights(class_idx, scores, **kwargs)
    195 cams: List[Tensor] = []
    197 with torch.no_grad():

File ~/local_scratch/cmu_research/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/local_scratch/cmu_research/lib/python3.10/site-packages/torchcam/methods/activation.py:217, in ScoreCAM._get_weights(self, class_idx, *args)
    214 origin_mode = self.model.training
    215 self.model.eval()
--> 217 weights: List[Tensor] = self._get_score_weights(upsampled_a, class_idx)
    219 # Reenable hook updates
    220 self._hooks_enabled = True

File ~/local_scratch/cmu_research/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/local_scratch/cmu_research/lib/python3.10/site-packages/torchcam/methods/activation.py:179, in ScoreCAM._get_score_weights(self, activations, class_idx)
    176 _slice = slice(_idx * self.bs, min((_idx + 1) * self.bs, weights[idx].numel()))
    177 # Get the softmax probabilities of the target class
    178 # (*, M)
--> 179 cic = self.model(scored_input[_slice]) - logits[idcs[_slice]]
    180 if isinstance(class_idx, int):
    181     weights[idx][_slice] = cic[:, class_idx]

TypeError: only integer tensors of a single element can be converted to an index

Can anyone please help?

frgfm commented 1 month ago

Hey @neeleshbisht99 :wave:

Happy to help, but I'd need a minimal runnable snippet to reproduce this. The traceback mentions code from your notebook that isn't shown here. Could you share a self-contained snippet, with all the imports and definitions, that yields this error?
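
For reference, here is a minimal sketch of what a self-contained snippet could look like. It rests on an assumption that is not confirmed in this thread: since the posted code unpacks `output, _ = model(img_tensor)`, the model seems to return a tuple, and line 179 of the traceback indexes `logits` with `idcs[_slice]`. If `logits` happened to be that tuple rather than a tensor, indexing it with a multi-element tensor would raise this exact message. `TupleOutputNet` and `LogitsOnly` are hypothetical names made up for illustration, and the snippet uses plain PyTorch only, so it does not exercise torch-cam itself.

```python
import torch
from torch import nn


class TupleOutputNet(nn.Module):
    """Hypothetical stand-in for a 3D model whose forward returns (logits, extra)."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv3d(1, 4, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        feats = self.conv(x)
        logits = self.fc(self.pool(feats).flatten(1))
        return logits, feats  # tuple output, like `output, _ = model(img_tensor)`


class LogitsOnly(nn.Module):
    """Hypothetical wrapper that keeps only the first element of a tuple output."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, x):
        out = self.model(x)
        return out[0] if isinstance(out, (tuple, list)) else out


model = TupleOutputNet().eval()
x = torch.rand(2, 1, 4, 8, 8)  # (batch, channels, depth, height, width)
idcs = torch.tensor([0, 1])    # multi-element index tensor

with torch.no_grad():
    out = model(x)                       # a Python tuple, not a tensor
    wrapped_out = LogitsOnly(model)(x)   # a plain tensor of logits

# Indexing a tuple with a multi-element tensor raises a TypeError
try:
    out[idcs]
    error_message = None
except TypeError as exc:
    error_message = str(exc)
print(error_message)

# Indexing the logits tensor itself works
selected = wrapped_out[idcs]
print(selected.shape)
```

If this turns out to be the cause, passing a logits-only wrapper of the model to `ScoreCAM` might avoid the error. Note that with a wrapper like the one sketched here, the target layer name would gain the wrapper's attribute prefix (e.g. `model.module.layer4` instead of `module.layer4`).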