pyg-team / pytorch_geometric


AssertionError with Captum and Attentive FP #6956

Closed · Takaogahara closed this 4 months ago

Takaogahara commented 1 year ago

🐛 Describe the bug

I'm receiving an AssertionError when explaining AttentiveFP with Captum using the node, edge, and node_and_edge mask types.

After investigating, I saw that line 153 of attentive_fp.py changes the size of edge_index, which seems to trigger the assertion in explain_message in message_passing.py (line 559 in the traceback below).
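For context, the molecule-embedding phase of AttentiveFP.forward builds a fresh edge_index, paraphrased below from attentive_fp.py, so the edge count seen by mol_conv no longer matches the input edge_index that the Captum edge mask was sized for:

# Paraphrased from AttentiveFP.forward (molecule embedding phase):
# every node is connected to its graph's super node, replacing the
# bond-level edge_index the mask was created for.
row = torch.arange(batch.size(0), device=batch.device)
edge_index = torch.stack([row, batch], dim=0)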

I adapted the examples from AttentiveFP and Captum to reproduce the described error.

import torch
import os.path as osp
import torch.nn as nn
from captum.attr import IntegratedGradients
from torch_geometric.nn import AttentiveFP, to_captum_model
from torch_geometric.datasets import MoleculeNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset_name = 'bbbp'
path = osp.join(osp.dirname(osp.realpath(__file__)), 'MoleculeNet')
dataset = MoleculeNet(path, dataset_name)

model = AttentiveFP(in_channels=dataset.num_node_features,
                    hidden_channels=64,
                    out_channels=1,
                    edge_dim=dataset.num_edge_features,
                    num_layers=1,
                    num_timesteps=1,
                    dropout=0.2).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
loss_fn = nn.BCEWithLogitsLoss()

data = dataset[0].to(device)
x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
dummy_batch = torch.zeros(x.shape[0], dtype=torch.long, device=device)  # single-graph batch vector
x = x.float()  # MoleculeNet stores node features as integers

for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()
    logit = model(x=x, edge_index=edge_index,
                  edge_attr=edge_attr, batch=dummy_batch)

    loss = loss_fn(logit, data.y)
    loss.backward()
    optimizer.step()

# With out_channels=1, the only valid Captum target indices are 0 and -1,
# so flipping the sign keeps a label of 1 in range.
target = int(data.y)
target *= -1

# Edge explainability
# ===================
captum_model = to_captum_model(model, mask_type='edge', output_idx=None)
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

ig = IntegratedGradients(captum_model)
attr_edge = ig.attribute(edge_mask.unsqueeze(0),
                         target=target,
                         additional_forward_args=(x, edge_index,
                                                  edge_attr, dummy_batch),
                         internal_batch_size=1)

# Node and edge explainability
# ============================
captum_model = to_captum_model(model, mask_type='node_and_edge',
                               output_idx=None)
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

ig = IntegratedGradients(captum_model)
inputs = (x.unsqueeze(0), edge_mask.unsqueeze(0))
attr_node, attr_edge = ig.attribute(inputs, target=target,
                                    additional_forward_args=(edge_index,
                                                             edge_attr,
                                                             dummy_batch),
                                    internal_batch_size=1)

# Node explainability
# ===================
captum_model = to_captum_model(model, mask_type='node', output_idx=None)

ig = IntegratedGradients(captum_model)
attr_node = ig.attribute(x.unsqueeze(0), target=target,
                         additional_forward_args=(edge_index,
                                                  edge_attr,
                                                  dummy_batch),
                         internal_batch_size=1)

This is the traceback I received:

Traceback (most recent call last):
  File "<string>", line 6, in <module>
  File "/home/takaogahara/virtualenvs/gnn-toolkit/lib/python3.8/site-packages/captum/log/__init__.py", line 35, in wrapper
    return func(*args, **kwargs)
  File "/home/takaogahara/virtualenvs/gnn-toolkit/lib/python3.8/site-packages/captum/attr/_core/integrated_gradients.py", line 273, in attribute
    attributions = _batch_attribution(
  File "/home/takaogahara/virtualenvs/gnn-toolkit/lib/python3.8/site-packages/captum/attr/_utils/batching.py", line 78, in _batch_attribution
    current_attr = attr_method._attribute(
  File "/home/takaogahara/virtualenvs/gnn-toolkit/lib/python3.8/site-packages/captum/attr/_core/integrated_gradients.py", line 350, in _attribute
    grads = self.gradient_func(
  File "/home/takaogahara/virtualenvs/gnn-toolkit/lib/python3.8/site-packages/captum/_utils/gradient.py", line 112, in compute_gradients
    outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
  File "/home/takaogahara/virtualenvs/gnn-toolkit/lib/python3.8/site-packages/captum/_utils/common.py", line 456, in _run_forward
    output = forward_func(
  File "/home/takaogahara/virtualenvs/gnn-toolkit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/takaogahara/virtualenvs/gnn-toolkit/lib/python3.8/site-packages/torch_geometric/nn/models/captum.py", line 54, in forward
    x = self.model(*args)
  File "/home/takaogahara/virtualenvs/gnn-toolkit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/takaogahara/virtualenvs/gnn-toolkit/lib/python3.8/site-packages/torch_geometric/nn/models/attentive_fp.py", line 157, in forward
    h = F.elu_(self.mol_conv((x, out), edge_index))
  File "/home/takaogahara/virtualenvs/gnn-toolkit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/takaogahara/virtualenvs/gnn-toolkit/lib/python3.8/site-packages/torch_geometric/nn/conv/gat_conv.py", line 241, in forward
    out = self.propagate(edge_index, x=x, alpha=alpha, size=size)
  File "/home/takaogahara/virtualenvs/gnn-toolkit/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py", line 446, in propagate
    out = self.explain_message(out, **explain_msg_kwargs)
  File "/home/takaogahara/virtualenvs/gnn-toolkit/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py", line 559, in explain_message
    assert inputs.size(self.node_dim) == edge_mask.size(0)
AssertionError

Environment

rusty1s commented 1 year ago

Yeah, this is currently expected since all our explainers assume a shared edge_index representation across all utilized GNN layers.

Takaogahara commented 1 year ago

Is there any workaround or "fix" that can be made?

rusty1s commented 1 year ago

The only fix I can think of would be to manually disable explaining for message-passing layers that you know operate on a different edge_index than the input one. You can do this by hacking it into set_masks in torch_geometric/explain/algorithm/utils.py.
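For illustration, here is a minimal sketch of such a hack, assuming set_masks iterates over MessagePassing submodules as in torch_geometric/explain/algorithm/utils.py (the private attribute names vary between PyG releases). EXCLUDE is a hypothetical set of layer names to skip, e.g. AttentiveFP's mol_conv, which runs on an internally built edge_index:

from torch_geometric.nn import MessagePassing

# Hypothetical: names of submodules whose edge_index differs from the input.
EXCLUDE = {'mol_conv'}

def set_masks(model, mask, edge_index, apply_sigmoid=True):
    for name, module in model.named_modules():
        if isinstance(module, MessagePassing):
            if any(part in EXCLUDE for part in name.split('.')):
                continue  # leave this layer un-explained
            # Attribute names as of recent releases; older ones used
            # __explain__ / __edge_mask__ instead.
            module.explain = True
            module._edge_mask = mask
            module._apply_sigmoid = apply_sigmoid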

Takaogahara commented 1 year ago

I'm sorry for not getting back to you sooner. Other priorities demanded my attention in March, but I'm actively working to resolve this problem.

I'm using the code described in discussion #7702, which uses the explainer module.

I tried changing torch_geometric/explain/algorithm/utils.py as suggested but, to my surprise, the set_masks function is never called.

I looked for another similar function, but I couldn't find it :(
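To narrow it down, here is a small diagnostic helper (hypothetical report_masks, probing both private attribute spellings since they vary across PyG releases) that reports which MessagePassing modules actually have an edge mask attached:

from torch_geometric.nn import MessagePassing

def report_masks(model):
    # Print which message-passing layers currently carry an edge mask.
    for name, module in model.named_modules():
        if isinstance(module, MessagePassing):
            mask = getattr(module, '_edge_mask', None)
            if mask is None:
                mask = getattr(module, '__edge_mask__', None)
            print(name or '<root>', '->', None if mask is None else tuple(mask.shape))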