pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

how to use layergradcam for a multi-label application? #887

Open amandalucasp opened 2 years ago

amandalucasp commented 2 years ago

I'm working with a multi-label image dataset. My inputs have the following shape: torch.Size([3, 224, 224]), and my targets are 1x33 one-hot encoded tensors, as in the following example: tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0')

I'm trying to use LayerGradCam in the following manner:

layer_gradcam = LayerGradCam(linear_classifier, linear_classifier.sigm)
for batch_inp, batch_target in test_loader:
    batch_inp = batch_inp.cuda(non_blocking=True)
    batch_target = batch_target.cuda(non_blocking=True)
    for inp, target in zip(batch_inp, batch_target):
        attributions_lgc = layer_gradcam.attribute(inputs=inp, target=target)

and keep getting the following error:

RuntimeError: mat1 dim 1 must match mat2 dim 0

I'm currently trying to implement the solution proposed in https://github.com/pytorch/captum/issues/171, but I get a Python error when trying to create the tuple.

layer_gradcam = LayerGradCam(linear_classifier, linear_classifier.sigm)
for batch_inp, batch_target in test_loader:
    batch_inp = batch_inp.cuda(non_blocking=True)
    batch_target = batch_target.cuda(non_blocking=True)
    for inp, target in zip(batch_inp, batch_target):
        targets_idx = torch.nonzero(target)
        inp = inp.unsqueeze(0)
        for idx in targets_idx:
            target = tuple(target.cpu().numpy(), idx)
            attributions_lgc = layer_gradcam.attribute(inputs=inp, target=target)

The error I'm getting says that tuple() expects at most 1 argument, but I'm passing 2.
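(An editor's aside on the tuple() error above: Python's tuple() constructor accepts a single iterable, not two positional arguments, so the call in the snippet raises a TypeError regardless of what Captum expects. A two-element tuple would be written as a literal instead.)

```python
# Why `tuple(target.cpu().numpy(), idx)` fails: tuple() takes a single
# iterable, so two positional arguments raise a TypeError.
try:
    tuple([0.0, 1.0], 3)
except TypeError as e:
    print(e)  # tuple expected at most 1 argument, got 2

# A two-element tuple is written as a literal instead:
pair = ([0.0, 1.0], 3)
```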

I also tried passing the index to attribute():

 attributions_lgc = layer_gradcam.attribute(inputs=inp, target=idx)

In this case, I have an input of inp: torch.Size([1, 3, 224, 224]) and a target of torch.Size([1]). But then I get the same error I was getting on the first snippet of code.

Any suggestions on how to solve this?

Thanks!
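(Editor's note for readers hitting the same error: Captum's target parameter expects an integer class index, or a list/tensor of indices, rather than a one-hot vector. A minimal sketch, assuming a multi-hot target like the one above, of extracting integer indices to attribute one label at a time; the helper name is hypothetical and this is not confirmed as the resolution of this issue.)

```python
import torch

def positive_label_indices(one_hot: torch.Tensor) -> list:
    """Convert a 1-D one-hot / multi-hot target into a list of int class indices."""
    return torch.nonzero(one_hot, as_tuple=True)[0].tolist()

# A target like the one in the question: classes 8, 9, and 21 are positive.
target = torch.zeros(33)
target[[8, 9, 21]] = 1.0

indices = positive_label_indices(target)
print(indices)  # [8, 9, 21]

# With Captum (not executed here), each index would be passed as a plain int,
# and the input would carry an explicit batch dimension:
# for idx in indices:
#     attributions = layer_gradcam.attribute(inputs=inp.unsqueeze(0), target=idx)
```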

99warriors commented 2 years ago

Hi @amandalucasp, thank you for the question! Do you get an error if you do a forward pass, i.e. linear_classifier(inp)?

amandalucasp commented 2 years ago

Hi @99warriors! I don't get any errors doing a forward pass, as I was able to successfully train my classifier.
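(Editor's note: the forward-pass check suggested above can be sketched as follows. The model here is a hypothetical stand-in, since the thread's linear_classifier architecture is not shown. Note that the first snippet in the question passes inp to attribute() without a batch dimension, which is consistent with a matmul shape error inside a linear layer.)

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the thread's linear_classifier:
# a small conv backbone followed by a 33-way sigmoid head.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 33),
    nn.Sigmoid(),
)

inp = torch.randn(3, 224, 224)   # a single sample, shaped as in the question
out = model(inp.unsqueeze(0))    # forward pass with an explicit batch dim
print(out.shape)                 # torch.Size([1, 33])
```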