pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Request: example with multilabel attribution #171

Closed · eburling closed this issue 4 years ago

eburling commented 4 years ago

The provided vision examples and documentation are excellent for single-label classification, but I am struggling to implement a multi-label use case.

For my use case, the input is a single-channel image of a cell nucleus. The target is a tensor that describes whether or not the cell is positive for each of 22 different protein markers, e.g. tensor([0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=torch.float64). That is, each cell can be positive for multiple markers, not only one. This is a simple multi-label classification task, where my model is the boilerplate torchvision.models.resnet18 with a custom final layer that accommodates the desired output.
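For reference, a minimal sketch of such a setup (the 22-marker count comes from the description above; the single-channel stem change and the loss choice are assumptions):

import torch.nn as nn
import torchvision

NUM_MARKERS = 22  # one output logit per protein marker

model = torchvision.models.resnet18()
# assumption: adapt the stem for single-channel (grayscale) nucleus images
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# custom final layer: one logit per marker, trained with a multi-label loss
# such as nn.BCEWithLogitsLoss
model.fc = nn.Linear(model.fc.in_features, NUM_MARKERS)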

I use the CIFAR vision example as a starting point as follows:

[screenshot: code adapted from the CIFAR tutorial]

But I get AssertionError: Tensor target dimension torch.Size([22]) is not valid. I see from the docstring for saliency.attribute that targets for outputs with more than two dimensions should be passed as tuples, but when I pass tuple(labels[ind]) instead, I get AssertionError: Cannot choose target column with output shape torch.Size([1, 22]).
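Roughly, the failing attempt looks like this (a sketch; the variable names follow the CIFAR tutorial and are assumptions, with input a single [1, 1, H, W] image and labels[ind] its 22-element multi-hot label):

from captum.attr import Saliency

saliency = Saliency(model)
# passing the whole multi-hot label tensor as target raises:
# AssertionError: Tensor target dimension torch.Size([22]) is not valid
attr = saliency.attribute(input, target=labels[ind])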

Ideally, I'd like to set up an AttributionVisualizer that looks like the following mock-up:

image

...where I can click each element of the prediction (e.g. CK19) and see the corresponding attribution image for that marker.

Any chance that a multi-label classification example like this could be supplied?

Much thanks!

NarineK commented 4 years ago

Hi @eburling, thank you for the question. Insights should ideally be able to support multi-label; that was our original plan.
In the multi-label case, could you have a for-loop that iterates over the indices that are ones and passes each of those indices as the target?

Would something like this work? target = (labels[ind], index_where_cell_positive)

In this case, target = (labels[ind], 2) or target = (labels[ind], 7). Note that you need to do a forward pass first to get the positions of the 1s.

This is not ideal, but on the other hand, attributing to all classes at once can be time-consuming when there are many classes...
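A minimal sketch of that loop (assuming model, a single input image, and its multi-hot label tensor label from the discussion above):

import torch
from captum.attr import Saliency

saliency = Saliency(model)

# indices of the markers this cell is positive for, e.g. tensor([2, 7, 13, 17])
positive_indices = torch.nonzero(label, as_tuple=True)[0]

# one attribution map per positive marker
attributions = {}
for idx in positive_indices:
    attributions[idx.item()] = saliency.attribute(input, target=idx.item())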

vivekmig commented 4 years ago

Hi @eburling, as @NarineK mentioned, the right way to obtain the attribution for a particular class is to pass each target class index independently, which gives you the importance of each input pixel for the prediction of that particular protein marker.

Alternatively, if you want to attribute all the markers together, essentially asking for the importance of each input pixel to all of the target markers combined, you could do something like this:

import torch
from captum.attr import Saliency

def model_wrapper(inputs, targets):
    output = model(inputs)
    # element-wise multiply the outputs with the multi-hot encoded targets
    # and sum over the class dimension (dim=1), yielding one scalar per example:
    # the summed prediction over all markers present in that cell
    return torch.sum(output * targets, dim=1)

# To compute attributions
saliency = Saliency(model_wrapper)
# We pass targets as an additional forward arg since the wrapper needs it,
# but we don't want attributions for it.
# No target argument is needed since the output per example is now a scalar.
attr = saliency.attribute(inputs, additional_forward_args=(targets,))

For many of the methods, these output attributions should be effectively equal to the sum of the attributions for each of the target markers. But if you need them independently, it is necessary to pass each class index separately.

In general, attributions need to be computed with respect to a scalar value per example, which is what causes this error. The documentation regarding tuples applies when the model output has more than two dimensions; in this case the model output is still 2D.

As @NarineK mentioned, Insights doesn't currently support passing multiple targets, but we will look into adding this in the future. For now, as a workaround, you can pass a single target index as the label for each batch; the UI will still let you choose each predicted marker and see the attribution for that marker. This won't show all of the true labels / marker names in the same view, but it will show the prediction for each marker, and the attribution for each can be visualized.
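A sketch of that workaround, modeled on the CIFAR Insights tutorial (the feature name, transforms, marker-name list, and dataloader are assumptions for illustration):

import torch
from captum.insights import AttributionVisualizer, Batch
from captum.insights.attr_vis.features import ImageFeature

# placeholder names; replace with the real 22 markers, e.g. "CK19"
MARKER_NAMES = [f"marker_{i}" for i in range(22)]

def baseline_func(input):
    return input * 0  # all-zero baseline, as in the CIFAR tutorial

def formatted_data_iter():
    for inputs, labels in dataloader:  # assumed existing DataLoader
        # workaround: a single target index per example
        # (argmax picks one of the positive markers)
        yield Batch(inputs=inputs, labels=torch.argmax(labels, dim=1))

visualizer = AttributionVisualizer(
    models=[model],
    score_func=lambda out: torch.sigmoid(out),  # multi-label scores
    classes=MARKER_NAMES,
    features=[
        ImageFeature(
            "Nucleus",  # hypothetical feature name
            baseline_transforms=[baseline_func],
            input_transforms=[],
        )
    ],
    dataset=formatted_data_iter(),
)
visualizer.render()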

eburling commented 4 years ago

@vivekmig and @NarineK, thank you both for your responses! They were sufficient for my use case.