pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

GNNExplainer: hard_mask is True #7454

Closed avivko closed 1 year ago

avivko commented 1 year ago

🐛 Describe the bug

When running GNNExplainer as follows:

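from torch_geometric.explain import Explainer, GNNExplainer

# `loaded_model` and `somegraph` are defined elsewhere (not shown in this report).
explainer = Explainer(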
    model=loaded_model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        return_type='log_probs',
    ),
)

explanation = explainer(x=somegraph.x, edge_index=somegraph.edge_index)
print(f'Generated explanations in {explanation.available_explanations}')

You get the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[27], line 16
      1 explainer = Explainer(
      2     model=loaded_model,
      3     algorithm=GNNExplainer(epochs=200),
   (...)
     11     ),
     12 )
---> 16 explanation = explainer(x=somegraph.x, edge_index=somegraph.edge_index)
     17 print(f'Generated explanations in {explanation.available_explanations}')
     19 path = 'feature_importance.png'

File ~/repos/pytorch_geometric_egnnexp/torch_geometric/explain/explainer.py:198, in Explainer.__call__(self, x, edge_index, target, index, **kwargs)
    195 training = self.model.training
    196 self.model.eval()
--> 198 explanation = self.algorithm(
    199     self.model,
    200     x,
    201     edge_index,
    202     target=target,
    203     index=index,
    204     **kwargs,
    205 )
    207 self.model.train(training)
    209 # Add explainer objectives to the `Explanation` object:

File [...]/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/repos/pytorch_geometric/torch_geometric/explain/algorithm/gnn_explainer.py:77, in GNNExplainer.forward(self, model, x, edge_index, target, index, **kwargs)
     72     raise ValueError(f"Heterogeneous graphs not yet supported in "
     73                      f"'{self.__class__.__name__}'")
     75 self._train(model, x, edge_index, target=target, index=index, **kwargs)
---> 77 node_mask = self._post_process_mask(
     78     self.node_mask,
     79     self.hard_node_mask,
     80     apply_sigmoid=True,
     81 )
     82 edge_mask = self._post_process_mask(
     83     self.edge_mask,
     84     self.hard_edge_mask,
     85     apply_sigmoid=True,
     86 )
     88 self._clean_model(model)

File ~/repos/pytorch_geometric/torch_geometric/explain/algorithm/base.py:111, in ExplainerAlgorithm._post_process_mask(mask, hard_mask, apply_sigmoid)
    108 if apply_sigmoid:
    109     mask = mask.sigmoid()
--> 111 if hard_mask is not None and mask.size(0) == hard_mask.size(0):
    112     mask[~hard_mask] = 0.
    114 return mask

AttributeError: 'bool' object has no attribute 'size'

The problem seems to be that hard_mask is a bool and not a tensor.

It seems that self._train() initializes the masks via self._initialize_masks() and then defines the hard masks as follows:

            # In the first iteration, we collect the nodes and edges that are
            # involved into making the prediction. These are all the nodes and
            # edges with gradient != 0 (without regularization applied).
            if i == 0 and self.node_mask is not None:
                self.hard_node_mask = self.node_mask.grad != 0.0
            if i == 0 and self.edge_mask is not None:
                self.hard_edge_mask = self.edge_mask.grad != 0.0

This might be the reason why the hard masks get saved as bools, causing this error.

Here's a printout of the node mask / hard mask:

mask: tensor([[-1.8137, -1.8033, -1.8204,  ..., -1.8470, -1.7542, -1.7775],
        [-1.7110, -1.7935, -1.7723,  ..., -1.7946, -1.6870, -1.6377],
        [-1.7446, -1.9176, -1.7605,  ..., -1.7810, -1.8225, -1.8848],
        ...,
        [-1.7572, -1.7603, -1.7157,  ..., -1.7118, -1.6725, -1.7529],
        [-1.7567, -1.7795, -1.8137,  ..., -1.7881, -1.6930, -1.7331],
        [-1.8848, -1.8004, -1.8971,  ..., -1.8504, -1.8012, -1.8312]])
 hard_mask: True

I can try to submit a pull request with a fix if you would like me to.


rusty1s commented 1 year ago

Mh, let me think. self.node_mask.grad should be a tensor, so self.node_mask.grad != 0.0 should be a tensor as well. I assume in your case self.node_mask.grad is None, which converts that to a boolean. Do you know why in your case the gradient is None? I cannot reproduce this on my end.
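
The behavior described above can be reproduced outside of PyG. The following is a minimal sketch using plain PyTorch; the parameter names `used` and `unused` are illustrative and do not come from the library. It assumes that one parameter never contributes to the loss, so its `.grad` stays `None` after `backward()`.

```python
import torch

# Two leaf parameters; only `used` takes part in the loss.
used = torch.nn.Parameter(torch.randn(3))
unused = torch.nn.Parameter(torch.randn(3))  # never enters the computation

loss = (used * 2.0).sum()
loss.backward()

# Elementwise comparison on a tensor gradient yields a boolean tensor.
hard_used = used.grad != 0.0

# `unused.grad` is None, so this is a plain Python comparison (`None != 0.0`)
# and the result is the Python bool `True`, not a tensor.
hard_unused = unused.grad != 0.0

print(type(hard_used), type(hard_unused))
```

Calling `.size(0)` on `hard_unused` would then fail in the same way as the traceback in the report.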

avivko commented 1 year ago

@rusty1s Yes, in the first iteration of the epoch loop, self.node/edge_mask are not None but self.node/edge_mask.grad are None. In the following iterations, after self.hard_node/edge_mask have been set to True, self.node/edge_mask.grad do become tensors. Let me know if you have any ideas or suggestions on how to go about this.

avivko commented 1 year ago

This might have something to do with this change (optimizer.zero_grad(set_to_none=True)): https://github.com/pytorch/pytorch/commit/b90496eef5665bc39828f6c1c522f399bcc62f3f However, setting set_to_none=False doesn't seem to solve the issue.

rusty1s commented 1 year ago

Can you confirm that the examples/explain/gnn_explainer.py example works for you? I believe the issue might be that either your nodes or your edges do not receive a gradient at all.

I pushed a more meaningful error message via https://github.com/pyg-team/pytorch_geometric/pull/7512.
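
For illustration only, a guard of this kind could look like the sketch below. This is not the code from the linked pull request; the helper name `hard_mask_from_grad` and the error wording are hypothetical. It only shows the general idea of checking for a missing gradient before deriving the hard mask.

```python
import torch


def hard_mask_from_grad(mask: torch.Tensor, name: str) -> torch.Tensor:
    """Hypothetical helper: derive a boolean hard mask from `mask.grad`.

    Raises a descriptive error instead of silently returning a Python bool
    when the mask never received a gradient.
    """
    if mask.grad is None:
        raise ValueError(
            f"Could not compute gradients for '{name}'. Please make sure "
            f"that the mask is actually used inside the model's forward pass."
        )
    return mask.grad != 0.0


# A mask that contributes to the loss yields a boolean tensor.
node_mask = torch.nn.Parameter(torch.randn(4))
(node_mask.sigmoid() * torch.ones(4)).sum().backward()
hard_node_mask = hard_mask_from_grad(node_mask, "node_mask")

# A mask that is never used has no gradient and triggers the error.
edge_mask = torch.nn.Parameter(torch.randn(4))
try:
    hard_mask_from_grad(edge_mask, "edge_mask")
except ValueError as err:
    print(err)
```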

avivko commented 1 year ago

It does, and you are right. It was because I was using a custom GNN that didn't use PyG's message passing propagate method, and thus the edges didn't receive a gradient.
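
The root cause can be sketched in plain PyTorch, without PyG. The `forward` function below is a hypothetical stand-in for a custom aggregation step and is not PyG's `propagate` implementation. It assumes that an edge mask only receives a gradient if the messages are actually multiplied by it; when the aggregation ignores the mask, `edge_mask.grad` stays `None` while the node mask still gets a gradient.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
# Row 0 holds source nodes, row 1 holds target nodes.
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])


def make_masks():
    node_mask = torch.nn.Parameter(torch.randn(4, 8))
    edge_mask = torch.nn.Parameter(torch.randn(4))
    return node_mask, edge_mask


def forward(x, edge_index, node_mask, edge_mask, use_edge_mask):
    # Hypothetical custom aggregation: sum of (masked) neighbor features.
    h = x * node_mask.sigmoid()
    src, dst = edge_index
    messages = h[src]
    if use_edge_mask:
        # Only this branch makes the output depend on `edge_mask`.
        messages = messages * edge_mask.sigmoid().view(-1, 1)
    return torch.zeros_like(h).index_add(0, dst, messages)


# Case 1: the aggregation ignores the edge mask, so it gets no gradient.
node_mask, edge_mask = make_masks()
forward(x, edge_index, node_mask, edge_mask, use_edge_mask=False).sum().backward()
ignored_grad = edge_mask.grad
ignored_node_grad = node_mask.grad

# Case 2: the aggregation applies the edge mask, so a gradient is populated.
node_mask, edge_mask = make_masks()
forward(x, edge_index, node_mask, edge_mask, use_edge_mask=True).sum().backward()
applied_grad = edge_mask.grad

print(ignored_grad is None, applied_grad is not None)
```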