pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Explainer Model is incompatible with FiLMConv Layer #5658

Open fratajcz opened 1 year ago

fratajcz commented 1 year ago

🐛 Describe the bug

Hi!

I use the Explainer that integrates Captum as described in the example as follows:

edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

captum_model = to_captum(model, mask_type='node_and_edge',
                            output_idx=output_idx)

ig = IntegratedGradients(captum_model)

ig_attr_node, ig_attr_edge = ig.attribute(
                (data.x.float().unsqueeze(0), edge_mask.unsqueeze(0)),
                additional_forward_args=(data.edge_index), internal_batch_size=1)

edge_index is a SparseTensor that also contains the information about the edge types (since FiLMConv is for multigraphs).

However, this raises an error because the edge mask is 2-dimensional:

Traceback (most recent call last):
  File "speos/explanation_dummy_film.py", line 108, in <module>
    additional_forward_args=(data.edge_index), internal_batch_size=1)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/log/__init__.py", line 35, in wrapper
    return func(*args, **kwargs)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/attr/_core/integrated_gradients.py", line 282, in attribute
    method=method,
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/attr/_utils/batching.py", line 79, in _batch_attribution
    **kwargs, n_steps=batch_steps, step_sizes_and_alphas=(step_sizes, alphas)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/attr/_core/integrated_gradients.py", line 354, in _attribute
    additional_forward_args=input_additional_args,
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/_utils/gradient.py", line 112, in compute_gradients
    outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/_utils/common.py", line 459, in _run_forward
    else inputs
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch_geometric/nn/models/explainer.py", line 78, in forward
    x = self.model(mask[0], *args)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/tmp/florin.ratajczak_pyg/tmph6keus6p.py", line 24, in forward
    x = self.module_6(x, edge_index)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch_geometric/nn/conv/film_conv.py", line 138, in forward
    x=lin(x[0]), beta=beta, gamma=gamma, size=None)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch_geometric/nn/conv/message_passing.py", line 335, in propagate
    edge_mask = torch.cat([edge_mask, loop], dim=0)
RuntimeError: Tensors must have same number of dimensions: got 2 and 1 

If I remove the .unsqueeze(0) from the edge mask to get the requested 1-dimensional shape, I instead get an error from the explainer class:

Traceback (most recent call last):
  File "speos/explanation_dummy_film.py", line 108, in <module>
    additional_forward_args=(data.edge_index), internal_batch_size=1)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/log/__init__.py", line 35, in wrapper
    return func(*args, **kwargs)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/attr/_core/integrated_gradients.py", line 282, in attribute
    method=method,
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/attr/_utils/batching.py", line 79, in _batch_attribution
    **kwargs, n_steps=batch_steps, step_sizes_and_alphas=(step_sizes, alphas)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/attr/_core/integrated_gradients.py", line 354, in _attribute
    additional_forward_args=input_additional_args,
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/_utils/gradient.py", line 112, in compute_gradients
    outputs = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/captum/_utils/common.py", line 459, in _run_forward
    else inputs
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/app/miniconda/envs/pyg/lib/python3.7/site-packages/torch_geometric/nn/models/explainer.py", line 59, in forward
    assert args[0].shape[0] == 1, "Dimension 0 of input should be 1"
AssertionError: Dimension 0 of input should be 1

The whole thing works with TAGConv and GCNConv layers, so I suspect that the culprit is the FiLMConv layer. I will try hacking the layer implementation to see if I can conditionally squeeze the mask.


fratajcz commented 1 year ago

This also happens in the most recent version on pip, 2.1.0.post1. There, the identical line is at line 494, in explain_message.

fratajcz commented 1 year ago

So, the lines in question are:

if inputs.size(self.node_dim) != edge_mask.size(0):
    edge_mask = edge_mask[self._loop_mask]
    loop = edge_mask.new_ones(size_i)
    edge_mask = torch.cat([edge_mask, loop], dim=0)
    assert inputs.size(self.node_dim) == edge_mask.size(0)

When I check this, inputs.size(self.node_dim) evaluates to 158962, which is the number of edges of one of the adjacencies I feed into the network. I don't know why it does that. data.num_edges and edge_mask.shape (before unsqueezing) both give the correct value of 4268876, and data.num_nodes gives 16852. So where does the 158962 come from?

fratajcz commented 1 year ago

I think I see where this error is coming from.

My edge_index is built as follows:

SparseTensor(row=tensor([    0,     0,     0,  ..., 16851, 16851, 16851]),
             col=tensor([   54,   721,  5041,  ..., 16561, 16573, 16676]),
             val=tensor([31., 31.,  0.,  ..., 31., 31.,  0.]),
             size=(16852, 16852), nnz=4268876, density=1.50%)

where the val value encodes from which adjacency the edge is coming. As it happens, the adjacency with value 0 has exactly 158962 edges, so it tries to apply the edge_mask for all edges (4268876) to the edges from the first adjacency (158962) and fails. How is edge_mask supposed to be formatted in case we have multiple types of edges?
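To make the size mismatch concrete, here is a minimal, dependency-free sketch. Plain Python lists stand in for the tensors, the numbers are toy values, and split_mask_by_type is a hypothetical helper, not part of PyG. It only shows that a mask defined over all edges has more entries than any single edge type, and that slicing it per type has to keep track of the global edge positions:

```python
from collections import Counter

# Toy stand-ins: one type id per edge (like `val` in the SparseTensor)
# and one mask entry per edge, defined over ALL edge types at once.
edge_type = [31, 31, 0, 31, 0, 7]
edge_mask = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]


def split_mask_by_type(edge_mask, edge_type):
    """Group mask entries by edge type, remembering each global position."""
    per_type = {}
    for position, (value, etype) in enumerate(zip(edge_mask, edge_type)):
        per_type.setdefault(etype, []).append((position, value))
    return per_type


counts = Counter(edge_type)          # number of edges per type
per_type = split_mask_by_type(edge_mask, edge_type)

# A layer that processes one type at a time sees only counts[t] messages,
# which never equals len(edge_mask) unless there is a single edge type.
for etype, entries in per_type.items():
    print(etype, "edges:", counts[etype], "global positions:",
          [position for position, _ in entries])
```

In this picture, the 158962 from my run corresponds to counts[0], while the mask has len(edge_mask) entries, which is 4268876 in my case.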

fratajcz commented 1 year ago

I have tried overriding MessagePassing.explain_message() to account for the fact that the edge_mask is passed just once but the edge types are processed iteratively. The edge_mask I pass in is now 2-dimensional, with the second dimension holding the edge type, similar to val in the SparseTensor shown in the comment above. I have added a small if clause that tests whether we have a 2-dimensional edge_mask and, if so, picks out the edge mask for the edge type that is currently being processed. Luckily, the edge types are handled in ascending order (from 0 to x, where x is the last edge type), so I can just increment the type I am looking for with each iteration.

def explain_message(self, inputs: Tensor, size_i: int) -> Tensor:
        # NOTE Replace this method in custom explainers per message-passing
        # layer to customize how messages shall be explained, e.g., via:
        # conv.explain_message = explain_message.__get__(conv, MessagePassing)
        # see stackoverflow.com: 394770/override-a-method-at-instance-level

        edge_mask = self._edge_mask

        if edge_mask is None:
            raise ValueError(f"Could not find a pre-defined 'edge_mask' as "
                             f"part of {self.__class__.__name__}.")

        # BEGIN ADDED CODE

        if len(edge_mask.shape) > 1:
            if not hasattr(self, "current_type"):
                self.current_type = 0

            values, types = torch.tensor_split(edge_mask, 2, dim=1)  # separate edge_mask and edge_type again
            unique_types = torch.unique(types)
            actual_type = unique_types[self.current_type]
            edge_mask = values[types == actual_type]
            self.current_type += 1
            if actual_type == types.max():
                self.current_type = 0

        # END ADDED CODE

        if self._apply_sigmoid:
            edge_mask = edge_mask.sigmoid()

        # Some ops add self-loops to `edge_index`. We need to do the same for
        # `edge_mask` (but do not train these entries).

        if inputs.size(self.node_dim) != edge_mask.size(0):
            edge_mask = edge_mask[self._loop_mask]
            loop = edge_mask.new_ones(size_i)
            edge_mask = torch.cat([edge_mask.squeeze(0), loop], dim=0)
            #print(inputs.size(self.node_dim))
            assert inputs.size(self.node_dim) == edge_mask.size(0)

        size = [1] * inputs.dim()
        size[self.node_dim] = -1
        return inputs * edge_mask.view(size)

This runs without errors, but the resulting edge attributions are nonsense. When I visualize them with explainer.visualize_subgraph(), the most influential edges are not even part of the query node's subgraph.

rusty1s commented 1 year ago

Really sorry for the late reply, but I appreciate your detailed description. You are right that FiLMConv currently fails due to the iterative application of self.propagate. We have some plans to revisit this soon via https://github.com/pyg-team/pytorch_geometric/issues/5630.

I think the cleanest approach to fix this, though, would be to get rid of the for-loop altogether. This should be doable by using HeteroLinear. Let me know if you want to work on this.
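As a rough illustration of why a single pass would help, consider the dependency-free sketch below. It is not FiLMConv or HeteroLinear: plain lists stand in for tensors, and masked_messages_single_pass and scale_per_type are hypothetical names, with a scalar per edge type standing in for the per-type transformation. The point is only that when every edge is handled in one pass and its type is looked up by index, entry i of the edge mask always lines up with edge i:

```python
def masked_messages_single_pass(x, src, edge_type, scale_per_type, edge_mask):
    """Compute one masked message per edge in a single pass over all edges.

    Each edge looks up its own per-type scale, so the global edge mask
    stays aligned with the edges without any per-type slicing.
    """
    assert len(src) == len(edge_type) == len(edge_mask)
    return [
        edge_mask[i] * scale_per_type[edge_type[i]] * x[src[i]]
        for i in range(len(src))
    ]


# Toy inputs (assumed values, for illustration only).
x = [1.0, 2.0, 3.0]                  # one scalar feature per node
src = [0, 1, 2, 0]                   # source node of each edge
edge_type = [0, 1, 0, 1]             # type id of each edge
scale_per_type = {0: 2.0, 1: -1.0}   # stand-in for a per-type transform
edge_mask = [1.0, 0.5, 0.0, 1.0]     # one mask entry per edge, all types

messages = masked_messages_single_pass(
    x, src, edge_type, scale_per_type, edge_mask
)
print(messages)
```

With a per-type loop, the same mask would have to be sliced once per type before each call; the single pass avoids that bookkeeping entirely.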