pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Captum `DeepLift` attribution method throws `AssertionError` #6389

Open · IlyaTyagin opened this issue 1 year ago

IlyaTyagin commented 1 year ago

πŸ› Describe the bug

When I try to use the DeepLift explainability method from Captum, I get an AssertionError related to the dimensionality of the input mask.

Code to reproduce the error is taken from the captum_explainability example:

Training part:

import os.path as osp

import torch
import torch.nn.functional as F
from captum.attr import IntegratedGradients, DeepLift

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, to_captum_model

dataset = 'Cora'
path = osp.join('..', 'data', 'Planetoid')
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()
    log_logits = model(data.x, data.edge_index)
    loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
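Side note (not part of the original example): the model is left in training mode after this loop, so dropout stays active during attribution. A minimal addition makes repeated attribution calls deterministic:

# Dropout stays active in training mode; switch to eval mode so that
# attributions are computed on a deterministic forward pass.
model.eval()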

Integrated Gradients works just fine:

output_idx = 10
target = int(data.y[output_idx])

captum_model = to_captum_model(model, mask_type='edge', output_idx=output_idx)
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

ig = IntegratedGradients(captum_model)
ig_attr_edge = ig.attribute(edge_mask.unsqueeze(0), target=target,
                            additional_forward_args=(data.x, data.edge_index),
                            internal_batch_size=1)
> tensor([[0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0', dtype=torch.float64,
       grad_fn=<AddBackward0>)
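For reference, the captum_explainability example post-processes the raw attributions before visualizing them; a minimal sketch of that step:

# Drop the batch dimension and normalize the edge attributions to [0, 1].
ig_attr_edge = ig_attr_edge.squeeze(0).abs()
ig_attr_edge /= ig_attr_edge.max()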

DeepLift part (doesn't work):

dl = DeepLift(captum_model)
dl_attr_edge = dl.attribute(edge_mask.unsqueeze(0), target=target,
                            additional_forward_args=(data.x, data.edge_index),
                            )
> AssertionError: Dimension 0 of input should be 1

Full traceback (DeepLift evaluates the input and its baseline together as one batch, while CaptumModel.forward asserts a batch dimension of exactly 1):

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Input In [15], in <cell line: 2>()
      1 dl = DeepLift(captum_model)
----> 2 dl_attr_edge = dl.attribute(edge_mask.unsqueeze(0), target=target,
      3                             additional_forward_args=(data.x, data.edge_index),
      4                             )

File /lustre/acslab/users/2288/anaconda3/envs/pyg_0822/lib/python3.8/site-packages/captum/log/__init__.py:35, in log_usage.<locals>._log_usage.<locals>.wrapper(*args, **kwargs)
     33 @wraps(func)
     34 def wrapper(*args, **kwargs):
---> 35     return func(*args, **kwargs)

File /lustre/acslab/users/2288/anaconda3/envs/pyg_0822/lib/python3.8/site-packages/captum/attr/_core/deep_lift.py:362, in DeepLift.attribute(self, inputs, baselines, target, additional_forward_args, return_convergence_delta, custom_attribution_func)
    352 expanded_target = _expand_target(
    353     target, 2, expansion_type=ExpansionTypes.repeat
    354 )
    356 wrapped_forward_func = self._construct_forward_func(
    357     self.model,
    358     (inputs, baselines),
    359     expanded_target,
    360     additional_forward_args,
    361 )
--> 362 gradients = self.gradient_func(wrapped_forward_func, inputs)
    363 if custom_attribution_func is None:
    364     if self.multiplies_by_inputs:

File /lustre/acslab/users/2288/anaconda3/envs/pyg_0822/lib/python3.8/site-packages/captum/_utils/gradient.py:112, in compute_gradients(forward_fn, inputs, target_ind, additional_forward_args)
     94 r"""
     95 Computes gradients of the output with respect to inputs for an
     96 arbitrary forward function.
   (...)
    108                 arguments) if no additional arguments are required
    109 """
    110 with torch.autograd.set_grad_enabled(True):
    111     # runs forward pass
--> 112     outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
    113     assert outputs[0].numel() == 1, (
    114         "Target not provided when necessary, cannot"
    115         " take gradient with respect to multiple outputs."
    116     )
    117     # torch.unbind(forward_out) is a list of scalar tensor tuples and
    118     # contains batch_size * #steps elements

File /lustre/acslab/users/2288/anaconda3/envs/pyg_0822/lib/python3.8/site-packages/captum/_utils/common.py:448, in _run_forward(forward_func, inputs, target, additional_forward_args)
    446 forward_func_args = signature(forward_func).parameters
    447 if len(forward_func_args) == 0:
--> 448     output = forward_func()
    449     return output if target is None else _select_targets(output, target)
    451 # make everything a tuple so that it is easy to unpack without
    452 # using if-statements

File /lustre/acslab/users/2288/anaconda3/envs/pyg_0822/lib/python3.8/site-packages/captum/attr/_core/deep_lift.py:401, in DeepLift._construct_forward_func.<locals>.forward_fn()
    400 def forward_fn():
--> 401     model_out = _run_forward(
    402         forward_func, inputs, None, additional_forward_args
    403     )
    404     return _select_targets(
    405         torch.cat((model_out[:, 0], model_out[:, 1])), target
    406     )

File /lustre/acslab/users/2288/anaconda3/envs/pyg_0822/lib/python3.8/site-packages/captum/_utils/common.py:456, in _run_forward(forward_func, inputs, target, additional_forward_args)
    453 inputs = _format_input(inputs)
    454 additional_forward_args = _format_additional_forward_args(additional_forward_args)
--> 456 output = forward_func(
    457     *(*inputs, *additional_forward_args)
    458     if additional_forward_args is not None
    459     else inputs
    460 )
    461 return _select_targets(output, target)

File /lustre/acslab/users/2288/anaconda3/envs/pyg_0822/lib/python3.8/site-packages/torch/nn/modules/module.py:1148, in Module._call_impl(self, *input, **kwargs)
   1145     bw_hook = hooks.BackwardHook(self, full_backward_hooks)
   1146     input = bw_hook.setup_input_hook(input)
-> 1148 result = forward_call(*input, **kwargs)
   1149 if _global_forward_hooks or self._forward_hooks:
   1150     for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):

File /lustre/acslab/users/2288/anaconda3/envs/pyg_0822/lib/python3.8/site-packages/torch_geometric/nn/models/captum.py:36, in CaptumModel.forward(self, mask, *args)
     31 """"""
     32 # The mask tensor, which comes from Captum's attribution methods,
     33 # contains the number of samples in dimension 0. Since we are
     34 # working with only one sample, we squeeze the tensors belows.
---> 36 assert mask.shape[0] == 1, "Dimension 0 of input should be 1"
     37 if self.mask_type == "edge":
     38     assert len(args) >= 2, "Expects at least x and edge_index as args."

AssertionError: Dimension 0 of input should be 1

Environment

rusty1s commented 1 year ago

cc @RBendias

RBendias commented 1 year ago

At the moment, only the methods Saliency, InputXGradient, Deconvolution, FeatureAblation, ShapleyValueSampling, IntegratedGradients, GradientShap, Occlusion, GuidedBackprop, KernelShap, and Lime work. For DeepLift, we need batch support, which we are currently working on.
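In the meantime, any of the supported gradient-based methods can serve as a drop-in replacement, since they share the same attribute() interface; for example, a sketch with Saliency, reusing captum_model and edge_mask from above:

from captum.attr import Saliency

# Saliency is among the currently supported methods and takes the same
# arguments as the DeepLift call above, minus internal batching.
sal = Saliency(captum_model)
sal_attr_edge = sal.attribute(edge_mask.unsqueeze(0), target=target,
                              additional_forward_args=(data.x, data.edge_index))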

IlyaTyagin commented 1 year ago

Got it, thanks. I'm particularly interested in trying DeepLift because, according to the Captum description, it is claimed to run faster than IntegratedGradients while producing comparable results. I'm running large-scale explainability experiments, so runtime is crucial.

EmilieDel commented 1 year ago

Hello, just to mention that I am also interested in this feature (specifically GradCAM). Do you have an idea of when this could be handled? Thank you very much!

rusty1s commented 1 year ago

We are trying to make CaptumExplainer feature-complete by PyG 2.3 (March 21).

EmilieDel commented 1 year ago

Hello! Sorry to bother you again with this, but is there any news here? 😄

rusty1s commented 1 year ago

You can find the set of currently supported explainer methods here. Sadly, we didn't manage to integrate all of them into our PyG 2.3 release, but we definitely want to extend the coverage. Any help is appreciated.
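For reference, a minimal sketch of how CaptumExplainer is wired into the unified Explainer interface, with argument values chosen to match the GCN example above:

from torch_geometric.explain import Explainer, CaptumExplainer

# CaptumExplainer wraps a Captum attribution method behind PyG's
# unified explanation interface.
explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)
explanation = explainer(data.x, data.edge_index, index=10)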