pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Failed to run the captum_explainability.py #4649

Open YOLO-jbc opened 2 years ago

YOLO-jbc commented 2 years ago

🐛 Describe the bug

Hello! I am trying to run your example captum_explainability.py. When it reached the following call

ig_attr_edge = ig.attribute(edge_mask.unsqueeze(0), target=target,
                            additional_forward_args=(data.x, data.edge_index),
                            internal_batch_size=1)

I got an error: IndexError: index 1 is out of bounds for dimension 0 with size 1

The full traceback is as follows:

Traceback (most recent call last):
  File "d:/Artificial_Intelligence/project/DL/Graph defense/Bacldoor graph defense/Backdoor_graph_defense/test.py", line 56, in <module>
    ig_attr_edge = ig.attribute(edge_mask.unsqueeze(0), target=target,
  File "D:\Artificial_Intelligence\Anaconda\lib\site-packages\captum\attr\_core\integrated_gradients.py", line 278, in attribute
  File "D:\Artificial_Intelligence\Anaconda\lib\site-packages\captum\attr\_utils\common.py", line 500, in _run_forward
    output = forward_func(
  File "D:\Artificial_Intelligence\Anaconda\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "D:\Artificial_Intelligence\Anaconda\lib\site-packages\torch_geometric\nn\models\explainer.py", line 64, in forward
    set_masks(self.model, mask.squeeze(0), args[1],
  File "D:\Artificial_Intelligence\Anaconda\lib\site-packages\torch_geometric\nn\models\explainer.py", line 16, in set_masks
    loop_mask = edge_index[0] != edge_index[1]
IndexError: index 1 is out of bounds for dimension 0 with size 1
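The last frame shows that set_masks received an edge_index whose first dimension has size 1, while PyG stores edges as a [2, num_edges] tensor (row 0 = source nodes, row 1 = target nodes). The sketch below reproduces the same IndexError and shows a clearer early check. The helper self_loop_free_mask is hypothetical and not part of PyG; it only mirrors the loop_mask line from the traceback.

```python
import torch


def self_loop_free_mask(edge_index: torch.Tensor) -> torch.Tensor:
    """Return True for every edge that is not a self-loop.

    Hypothetical helper (not a PyG API): it validates the shape first, then
    applies the same comparison as the traceback line
    `loop_mask = edge_index[0] != edge_index[1]`.
    """
    # Expect COO layout: exactly two rows (sources, targets).
    if edge_index.dim() != 2 or edge_index.size(0) != 2:
        raise ValueError(
            "expected edge_index of shape [2, num_edges], "
            f"got {tuple(edge_index.shape)}"
        )
    return edge_index[0] != edge_index[1]


good = torch.tensor([[0, 1, 2], [1, 1, 0]])  # edge 1 -> 1 is a self-loop
print(self_loop_free_mask(good).tolist())  # [True, False, True]

# First dimension of size 1, as in the traceback.
bad = torch.tensor([[0, 1, 2]])
try:
    bad[1]  # raw indexing raises the IndexError reported in this issue
except IndexError as err:
    print("IndexError:", err)
try:
    self_loop_free_mask(bad)  # the shape check fails earlier, with a clearer message
except ValueError as err:
    print("ValueError:", err)
```

The check does not explain why edge_index arrived with the wrong shape; in this thread the cause turned out to be the installed Captum version.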

Thanks a lot! :)

Environment

rusty1s commented 2 years ago

I cannot reproduce this. Are you operating on a custom dataset? What's your Captum version?

YOLO-jbc commented 2 years ago

Thanks for your reply. I am using the Cora dataset, loaded with dataset = Planetoid('./datasets', 'Cora', transform=T.NormalizeFeatures()). My Captum version is 0.2.0. Otherwise, my code is exactly the same as the example code.
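To answer the version question from inside Python, a minimal sketch using only the standard library (importlib.metadata, Python 3.8+) could look like this. The helper names installed_version and version_tuple are hypothetical; the only fact taken from this thread is that Captum 0.2.0 produced the error. A version newer than 0.2.0 is not guaranteed to fix it; the thread only reports that upgrading worked.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def version_tuple(text: str) -> Tuple[int, ...]:
    """Turn "0.2.0" into (0, 2, 0); stop at the first non-numeric part."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) != len(piece):
            break  # e.g. "0rc1": keep 0, ignore the pre-release suffix
    return tuple(parts)


# Version reported in this thread as failing (not a documented minimum).
REPORTED_VERSION = "0.2.0"

current = installed_version("captum")
if current is None:
    print("captum is not installed")
else:
    newer = version_tuple(current) > version_tuple(REPORTED_VERSION)
    print(f"captum {current}; newer than {REPORTED_VERSION}: {newer}")
```

Comparing integer tuples rather than raw strings avoids the classic mistake where "0.10.0" sorts before "0.2.0".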

rusty1s commented 2 years ago

Can you try to upgrade captum and test again?

pip install --upgrade captum

YOLO-jbc commented 2 years ago

It works! I really appreciate your help. I never thought to check the Captum version. Wishing you all the best ;)