pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

[Explainability Evaluation] - Fidelity +/- metrics #5958

Closed BlazStojanovic closed 1 year ago

BlazStojanovic commented 1 year ago

🚀 The feature, motivation and pitch

Implement Fidelity +/- metrics (Fid+ and Fid-) for evaluating explainability. This feature request is part of a broader set of explainability evaluation metrics (parent issue #5628), which will be crucial for the new revamp of explainability (see roadmap https://github.com/pyg-team/pytorch_geometric/issues/5520).

Brief overview of Fidelity+/-

Fidelity measures check explanations for their faithfulness to the model, i.e. whether the inputs (nodes/edges/node features) that the explanation marks as important are actually discriminative for the model.

Formally, fidelity is given by the following equations (from A.6 of [3], phenomenon focus):

$$\mathrm{Fid}_+ = \frac{1}{N} \sum_{i=1}^{N} \left( \mathbb{1}(\hat{y}_i = y_i) - \mathbb{1}(\hat{y}_i^{G_{C \setminus S}} = y_i) \right)$$

$$\mathrm{Fid}_- = \frac{1}{N} \sum_{i=1}^{N} \left( \mathbb{1}(\hat{y}_i = y_i) - \mathbb{1}(\hat{y}_i^{G_S} = y_i) \right)$$

where $\hat{y}_i$ is the prediction on the full graph, $\hat{y}_i^{G_S}$ the prediction on the explanation subgraph, and $\hat{y}_i^{G_{C \setminus S}}$ the prediction on its complement. For details and explanations see the references, particularly [2-3].
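As a concrete illustration, here is a minimal numeric sketch of the Fid+/Fid- definitions (phenomenon focus) on hypothetical toy labels and predictions; this is an illustration of the formulas only, not the PyG implementation:

```python
# Sketch of Fid+/Fid- (phenomenon focus) on toy 0/1 predictions.
# All inputs below are made-up example values, not real model outputs.
def fidelity(y, y_full, y_expl, y_compl):
    n = len(y)
    # Fid+: how much agreement with y drops when the explanation is removed
    fid_plus = sum(int(yf == yi) - int(yc == yi)
                   for yi, yf, yc in zip(y, y_full, y_compl)) / n
    # Fid-: how much agreement with y drops when only the explanation is kept
    fid_minus = sum(int(yf == yi) - int(ye == yi)
                    for yi, yf, ye in zip(y, y_full, y_expl)) / n
    return fid_plus, fid_minus

y       = [0, 1, 1, 0]   # ground-truth labels
y_full  = [0, 1, 1, 0]   # predictions on the full graphs
y_expl  = [0, 1, 0, 0]   # predictions on the explanation subgraphs G_S
y_compl = [1, 0, 1, 0]   # predictions on the complements G_{C \ S}
print(fidelity(y, y_full, y_expl, y_compl))  # (0.5, 0.25)
```

A good explanation has high Fid+ (removing it hurts) and low Fid- (keeping only it suffices).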

Implementation checklist

References

  1. Explainability Methods for Graph Convolutional Neural Networks
  2. Explainability in Graph Neural Networks: A Taxonomic Survey
  3. GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks
AbdullahWasTaken commented 1 year ago

I am implementing the fidelity measures and had doubts about interpreting specific terms.

Just to confirm, in the equations, $\hat{y}_i^{G_s}$ represents the predictions obtained over the masked input (edge_mask used as edge_weights, node_mask multiplied to each row of feature matrix x and node_feat_mask multiplied to each column of the feature matrix)?

Explanations can be formed by node_mask, node_feat_mask, and edge_mask. But in case one of these is missing, what should be the default behavior? Should the mask contain all ones or all zeros?
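For reference, the masking interpretation described above can be sketched as follows (this is my reading of the semantics, not PyG's exact internals):

```python
import torch

# Hypothetical sketch: node_mask scales rows of x, node_feat_mask scales
# columns of x, and edge_mask would be passed to the model as edge_weight.
x = torch.ones(3, 4)                                  # 3 nodes, 4 features
node_mask = torch.tensor([1.0, 0.0, 1.0])             # node 1 masked out
node_feat_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])   # feature 2 masked out

x_masked = x * node_mask.view(-1, 1) * node_feat_mask.view(1, -1)
# Row 1 and column 2 of x_masked are now all zeros.
```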

BlazStojanovic commented 1 year ago

Hey @AbdullahWasTaken, thanks for the question, the behaviour will depend on the input graph and which masks are missing! Let me try to clarify:

You have two types of predictions: predictions on the explanation subgraph $G_S$ (used for Fid-), and predictions on its complement $G_{C \setminus S}$ (used for Fid+).

How do we get the graphs? Well, as you pointed out, our explanations come in the form of four masks (assume hard masks): node_mask, edge_mask, node_feat_mask, and edge_feat_mask.

We need to be careful when obtaining the explanatory graph to include/exclude only the right nodes/edges/features (and analogously for the complement graph, but with the complementary masks). There are many cases to consider, so let me just highlight two:

  1. We have a full graph (nodes, edges, node features, edge features), but the Explanation only has a node_feat_mask. The connectivity of the graph doesn’t change due to this mask; we obtain the new graph by updating the node features, $\mathbf{X}_N = \mathbf{M}_{NF} \odot \mathbf{X}$, where $\mathbf{X}$ are the features.
  2. Let’s take the same graph, but now with a node_mask alongside the node_feat_mask. The node mask removes both nodes and their features, and the edges also need to be adjusted, since there cannot be edges incident to removed nodes.
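To make case 2 concrete, a hard node_mask implies filtering out every edge whose endpoints include a removed node. A minimal sketch (hypothetical helper, not the Explanation API):

```python
import torch

def apply_node_mask(edge_index, node_mask):
    """Keep only edges whose both endpoints survive the hard node mask.

    Hypothetical helper for illustration; PyG's Explanation handles this
    internally.
    """
    keep = node_mask[edge_index[0]] & node_mask[edge_index[1]]
    return edge_index[:, keep]

edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])           # edges 0->1, 1->2, 2->0
node_mask = torch.tensor([True, True, False])    # node 2 removed
print(apply_node_mask(edge_index, node_mask))    # only edge 0->1 survives
```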

In any case this should be handled inside Explanation; for now you can just assume that you can obtain the graphs via Explanation.get_explanation_graph() or something along these lines.

BlazStojanovic commented 1 year ago

@AbdullahWasTaken We’ve implemented some additional functionality in Explanation, you can now use:

CXX1113 commented 1 year ago

In fact, I found that the Fidelity metric in the GraphFramEx paper is wrong: $S$ should have the subscript $i$ (that is, $S_i$), instead of using the same mask for the prediction of each $i$. In the code of GraphFramEx, $S$ is actually treated as $S_i$. The code link is as follows: https://github.com/DS3Lab/GraphFramEx/blob/main/code/evaluate/fidelity.py#L69C1-L73C42

BlazStojanovic commented 1 year ago

I'm not absolutely sure what you mean by $S_i$, but the explanations are of course specific to each node, i.e. their explanations differ, and so do the respective subgraphs we use to evaluate fidelity. In this sense you are right that $S$ is actually treated as $S_i$, and we treat it as such in PyG: https://github.com/pyg-team/pytorch_geometric/blob/40d4718713142bd60a9f65e4d66a80789290ae84/torch_geometric/explain/metric/fidelity.py#L75C4-L75C4

As for the notation in the paper, I think the authors just wanted to avoid clutter and not use the index twice:

[equation image from the paper]

I think the definition from the paper indicates quite clearly that the masked graphs (and complements) are node specific.