pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Prediction doesn't match with sum of attribution scores #700

Closed. JinwooMcLee closed this issue 3 years ago.

JinwooMcLee commented 3 years ago

Hi, I built a multi-modal binary classification model. The first modality is encoded by a transformer encoder, and the other by a simple feed-forward network.

The model architecture is as follows:

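import torch
import torch.nn as nn

class InterpretMultiModal(nn.Module):
    def __init__(self, bert_backbone):
        super().__init__()
        # Text branch: BERT backbone whose .logits output is 64-dim
        self.bert_backbone = bert_backbone
        # Metadata branch: class_array and fieldid_array are defined elsewhere
        # in the author's script; together they give a 1216-dim input
        self.meta_linear1 = nn.Linear(len(class_array) + len(fieldid_array), 128)  # 1216 -> 128
        self.meta_linear2 = nn.Linear(128, 16)
        self.relu = nn.ReLU()
        # Fusion head: 64 text features + 16 metadata features -> 2 logits
        self.fc = nn.Linear(
            in_features=64 + 16,
            out_features=2,  # one logit per class so Captum's target can select a class
        )

    def forward(self, seq, meta, attention_mask):
        bert_logit = self.bert_backbone(seq, attention_mask=attention_mask).logits  # (batch, 64)

        meta_lin1 = self.relu(self.meta_linear1(meta))
        meta_lin2 = self.relu(self.meta_linear2(meta_lin1))  # (batch, 16)
        fusion = torch.cat([bert_logit, meta_lin2], dim=1)  # (batch, 80)
        logits = self.fc(fusion)  # (batch, 2) raw logits, no softmax applied

        return logits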

However, when I compute the model's attributions for each modality with LayerIntegratedGradients, the sum of the two modalities' attributions doesn't match the model's prediction.

For example, for one input the model predicts 1 (positive class), and the attributions computed with the positive class as target sum to 1.0056. For other inputs the model predicts 0 (negative class), yet the attributions computed with the positive class as target sum to 1.0606, 1.2314, 1.1387, ..., which is even higher than in the first case.
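A point the thread leaves implicit: Integrated Gradients is designed to satisfy completeness, meaning the attributions for a chosen target sum (up to the approximation error that Captum reports as the convergence delta) to F_target(input) minus F_target(baseline). Because the forward() above returns raw logits, attributions with target=1 are measured in units of the positive-class logit relative to the baseline, not in terms of the predicted label or a probability, so a sum above 1 can coexist with a negative prediction whenever the class-0 logit is still larger. The following is a minimal pure-Python sketch of that effect, not Captum itself: it hand-rolls IG with finite-difference gradients and a midpoint Riemann sum, and the toy model `logits`, its weights, the input `x` and the zero baseline are all hypothetical.

```python
import math


# Hypothetical 2-class toy model standing in for the real network.
# Like forward() in the issue, it returns raw logits [class 0, class 1].
def logits(x):
    h = math.tanh(0.8 * x[0] - 0.5 * x[1])
    return [1.5 + 0.2 * h, 0.3 * x[0] + 1.2 * h]


def grad(f, x, eps=1e-5):
    """Central finite-difference gradient of a scalar function f at x."""
    g = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += eps
        down[i] -= eps
        g.append((f(up) - f(down)) / (2 * eps))
    return g


def integrated_gradients(f, x, baseline, steps=200):
    """Midpoint Riemann approximation of Integrated Gradients for a scalar f."""
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of each sub-interval of [0, 1]
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, gi in enumerate(grad(f, point)):
            total[i] += gi
    # Average gradient along the path, scaled by (input - baseline)
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]


x = [1.0, 0.2]
baseline = [0.0, 0.0]
target = 1  # attribute with respect to the positive-class logit


def target_logit(v):
    return logits(v)[target]


attr = integrated_gradients(target_logit, x, baseline)
predicted = max(range(2), key=lambda c: logits(x)[c])

print("predicted class:", predicted)
print("sum of attributions:", sum(attr))
print("logit1(x) - logit1(baseline):", target_logit(x) - target_logit(baseline))
```

With this made-up model the positive-class attributions sum to just over 1 while the predicted class is 0, because the class-0 logit stays larger at this input. The sum tracks the change in the target logit, which is what the convergence delta returned by `return_convergence_delta=True` is measured against.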

The attributions are calculated as follows:

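import torch
from captum.attr import LayerIntegratedGradients
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper

# model, _input, _meta_info, _attention_mask, _label and device come from the author's script
interpretable_model = ModelInputWrapper(model)

# Attribute to the BERT word embeddings and to the wrapped 'meta' input
lig = LayerIntegratedGradients(interpretable_model,
                               [interpretable_model.module.bert_backbone.bert.embeddings.word_embeddings, interpretable_model.input_maps['meta']])

# Calculate attributions with respect to being predicted as a bot (class 1)
target = torch.ones(size=_label.shape, dtype=torch.int64, device=device).squeeze()
attributions, delta = lig.attribute(inputs=(_input, _meta_info), target=target, additional_forward_args=_attention_mask, return_convergence_delta=True)

# One attribution tensor per layer, in the order the layers were passed
text_attributions = attributions[0].sum(dim=2).squeeze(0)  # sum over the embedding dim -> per-token scores
meta_attributions = attributions[1].squeeze(0)

total = text_attributions.sum() + meta_attributions.sum()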

Am I interpreting the attributions the wrong way? Or are the attributions calculated incorrectly?

Any help would be really appreciated. Thanks!

bilalsal commented 3 years ago

Hi @JinwooMcLee,

Apologies for the late response.

It may not be very useful to compare attribution sums computed for different inputs. The attributions computed by Captum are meant to give you an idea of how likely an input feature is to be part of the reason the model predicts a specific output (e.g. the positive or the negative class). The sum of the attributions for an input can be high if several input features "vote" for the positive class, even if the model ultimately predicts the negative class. If you compute the attributions with target=torch.zeros(size=_label.shape, dtype=torch.int64, device=device), you get an idea of which input features instead "vote" for the negative class. You can sum up the attributions for those features and compare that sum with the one you got under target=torch.ones(size=_label.shape, dtype=torch.int64, device=device) for the exact same input. Captum is not really designed for that kind of comparison either, but it is likely more informative than comparing attribution sums computed for different inputs.
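The comparison suggested above (same input, target=1 versus target=0) can be sketched in plain Python. This is an illustration under stated assumptions, not Captum code: it uses a hypothetical linear two-class model, for which Integrated Gradients has the exact closed form attr_i = (x_i - baseline_i) * W[target][i], and the weights, input and zero baseline are made up.

```python
# Hypothetical linear 2-class model (not the model from the issue):
# logit_c(x) = W[c] . x + bias[c]
W = [
    [0.4, -1.0, 0.3],  # class 0 (negative)
    [0.9, 0.5, -0.2],  # class 1 (positive)
]
bias = [0.5, -0.3]


def logits(x):
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(W, bias)]


def integrated_gradients_linear(x, baseline, target):
    # For a linear model the gradient is constant along the path, so IG is exact:
    # attr_i = (x_i - baseline_i) * W[target][i]
    return [(xi - bi) * w for xi, bi, w in zip(x, baseline, W[target])]


x = [1.0, 0.5, 2.0]
baseline = [0.0, 0.0, 0.0]

attr_pos = integrated_gradients_linear(x, baseline, target=1)
attr_neg = integrated_gradients_linear(x, baseline, target=0)
predicted = max(range(2), key=lambda c: logits(x)[c])

print("predicted class:", predicted)
print("sum of attributions, target=1:", sum(attr_pos))
print("sum of attributions, target=0:", sum(attr_neg))
# Per-feature "vote": positive values favour class 1, negative values favour class 0
print("target=1 minus target=0:", [p - n for p, n in zip(attr_pos, attr_neg)])
```

With these made-up numbers the positive-class attributions sum to a positive value while the model predicts class 0, mirroring the situation in the question. Because both targets share the same baseline, the difference between the two sums equals the change in the logit margin (class 1 minus class 0) from the baseline to the input, not the margin itself, so the choice of baseline still affects the comparison.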

Hope this helps

JinwooMcLee commented 3 years ago

Thank you for the detailed explanation, @bilalsal.

I'm closing this issue. Really appreciated!