pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

[Roadmap] GNN Explainability Support 🚀 #5520

Open RexYing opened 1 year ago

RexYing commented 1 year ago

🚀 The feature, motivation and pitch

Explainability has become an important component for users to probe model behavior, understand feature and structural importance, obtain new knowledge about the underlying task, and extract insights from a GNN model.

Over the years, many papers on GNN explainability have been published (some of which are already integrated into PyG), and many surveys, benchmarks, and evaluation frameworks have been proposed, such as the taxonomic survey and the GraphFramEx multi-aspect evaluation framework. Some of this recent progress raises new challenges in terms of method, evaluation, and visualization, outlined here.

Alternatives

No response

Additional context

There are other explainability functionalities that are still at a relatively early stage, such as concept-/class-wise explanations and counterfactual explanations. There are ongoing research projects that could potentially be integrated in the future.

Padarn commented 1 year ago

This sounds pretty cool, I'd be keen to lend a hand :-)

FYI: There is an open PR that is quite related: https://github.com/pyg-team/pytorch_geometric/pull/4507.

wsad1 commented 1 year ago

Thanks for adding the roadmap. One complexity with explainability support for heterogeneous graphs is as follows: methods like GNNExplainer work by setting a mask on the MessagePassing module (https://github.com/pyg-team/pytorch_geometric/blob/d8d06e1b69bea4e71ab39c141aa00a17c32d72d9/torch_geometric/nn/models/explainer.py#L13). This is an issue for models like HANConv and HGTConv, which internally share a single MessagePassing module across edge types, so we won't be able to set a separate mask for each edge type. One way to resolve this is to refactor these GNNs to use a different message module for each edge type.
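
For context, a rough sketch of that masking mechanism (attribute names vary across PyG versions, so treat this as illustrative rather than the exact API):

import torch
from torch_geometric.nn import MessagePassing

def set_edge_masks(model: torch.nn.Module, edge_mask: torch.Tensor):
    # GNNExplainer-style explainers walk the model and attach a single
    # learnable mask to every MessagePassing module; the mask is then
    # applied to messages during propagation. A module that is shared
    # across edge types therefore receives one mask for all of them.
    for module in model.modules():
        if isinstance(module, MessagePassing):
            module.explain = True          # attribute name differs by version
            module._edge_mask = edge_mask  # a single mask per module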

Padarn commented 1 year ago

That makes sense @wsad1. I think for most common cases, using something like to_hetero, this won't be a problem, but I guess the models that do this today would need to be updated.

Possibly even just setting an 'explainability_disabled' flag on those models would be enough?
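
A minimal sketch of that idea (the flag name and the set_masks helper here are hypothetical):

from torch_geometric.nn import MessagePassing

def set_masks(model, edge_mask):
    for module in model.modules():
        if isinstance(module, MessagePassing):
            # Fail loudly on modules that opted out of explainability:
            if getattr(module, 'explainability_disabled', False):
                raise NotImplementedError(
                    f"'{module.__class__.__name__}' shares one "
                    f"MessagePassing module across edge types and "
                    f"does not support edge masks yet")
            module._edge_mask = edge_mask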

Another rough idea could be to support node and edge input dicts in MessagePassing directly. For example, from HGTConv:

# Iterate over edge-types:
for edge_type, edge_index in edge_index_dict.items():
    src_type, _, dst_type = edge_type
    edge_type = '__'.join(edge_type)

    # Project keys and values with per-edge-type weights:
    a_rel = self.a_rel[edge_type]
    k = (k_dict[src_type].transpose(0, 1) @ a_rel).transpose(1, 0)

    m_rel = self.m_rel[edge_type]
    v = (v_dict[src_type].transpose(0, 1) @ m_rel).transpose(1, 0)

    # propagate_type: (k: Tensor, q: Tensor, v: Tensor, rel: Tensor)
    # The same MessagePassing module handles every edge type:
    out = self.propagate(edge_index, k=k, q=q_dict[dst_type], v=v,
                         rel=self.p_rel[edge_type], size=None)
    out_dict[dst_type].append(out)

to become

out = self.propagate(edge_index_dict, k=k_dict...)
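
A self-contained sketch of how such a dict-aware call might dispatch internally (all names here are illustrative, not a proposed API):

from collections import defaultdict

def propagate_dict(conv, edge_index_dict, kwargs_dict):
    # Hypothetical wrapper: fan a dict of edge indices out to per-type
    # propagate() calls, so that per-edge-type masks could be managed
    # in a single place.
    out_dict = defaultdict(list)
    for edge_type, edge_index in edge_index_dict.items():
        src_type, _, dst_type = edge_type
        out = conv.propagate(edge_index, **kwargs_dict[edge_type])
        out_dict[dst_type].append(out)
    return out_dict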

divyanshugit commented 1 year ago

@RexYing If possible, I would love to work on adding basic evaluation metrics. Please let me know if no one is working on it.
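
For reference, one basic metric proposed in the literature is fidelity+, which measures how much a prediction degrades once the edges an explainer marked as important are removed. A minimal sketch, assuming a node-level classifier (fidelity_plus is a hypothetical helper, not existing PyG API):

import torch

@torch.no_grad()
def fidelity_plus(model, x, edge_index, edge_mask, target, threshold=0.5):
    # Predictions on the full graph vs. on the graph with the
    # "important" edges (mask value above threshold) removed:
    keep = edge_mask < threshold
    pred_full = model(x, edge_index).argmax(dim=-1)
    pred_drop = model(x, edge_index[:, keep]).argmax(dim=-1)
    correct_full = (pred_full == target).float()
    correct_drop = (pred_drop == target).float()
    return (correct_full - correct_drop).mean().item()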

rusty1s commented 1 year ago

No one is working on this yet, so please feel free to go ahead. Much appreciated, thank you!

RexYing commented 1 year ago

@divyanshugit Happy to set up a meeting to discuss this in more detail. I'll reach out to you on PyG Slack. As suggested by @ivaylobah, I'll break this down into several issues. The issue you might be able to start with is https://github.com/pyg-team/pytorch_geometric/issues/5628.

divyanshugit commented 1 year ago

> @divyanshugit Happy to set up a meeting to discuss this in more detail. I'll reach out to you on PyG Slack. As suggested by @ivaylobah, I'll break this down into several issues. The issue you might be able to start with is #5628.

Thank you so much for the reminder. It would be great to discuss this in detail.

fratajcz commented 1 year ago

I'd love to hear more about this. I am currently trying to make the Captum explainability examples work for heterogeneous graphs by passing in a two-dimensional edge_mask whose second dimension indicates the edge type. I then override the explain_message function of the conv layer so that it uses this type information and only applies the edge_mask for the edge type currently being processed. However, weird things happen and I get nonsense results (i.e., the "most influential" edges fall outside the query node's receptive field). See also my bug report; I think someone with more knowledge of the framework internals could make this work easily.
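
A hypothetical sketch of that approach, with explain_message overridden to select the mask column matching the edge type currently being processed (the _edge_mask_2d and _current_edge_type attributes are illustrative bookkeeping, not PyG API, and the hook's exact signature varies across versions):

import torch
from torch_geometric.nn import MessagePassing

class TypeAwareMaskingConv(MessagePassing):
    # _edge_mask_2d: [num_edges, num_edge_types] learnable mask
    # _current_edge_type: set by the caller before each propagate() call
    def explain_message(self, inputs, size_i, **kwargs):
        mask = self._edge_mask_2d[:, self._current_edge_type]
        # Only mask messages belonging to the current edge type:
        return inputs * mask.sigmoid().view(-1, 1)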

musicjae commented 1 year ago

I'm passionate about this plan. It's cool.

daniel-unyi-42 commented 1 year ago

I'd be happy to work on this, especially these two:

rusty1s commented 1 year ago

Thanks @daniel-unyi-42. We are launching an explainability sprint with more fine-grained tasks soon. Please consider participating!