dmlc / dgl

Python package built to ease deep learning on graph, on top of existing DL frameworks.
http://dgl.ai
Apache License 2.0

[Feature Request] Support binary classification in GNNExplainer #3846

Open veenapaddy opened 2 years ago

veenapaddy commented 2 years ago

🚀 Feature

Support binary-classifier GNNs (output dimension 1, i.e., a single score) in GNNExplainer.

Motivation

GNNExplainer in DGL assumes that the model it is explaining is a multi-class classifier. If we naively feed it a binary classifier with a single output score, it will treat all data points as belonging to one class and fail to explain any target node correctly: log_softmax over a single output column is 0 for every node, so the prediction term of the loss carries no information. Workarounds are possible for a user who is aware of this, but they require changing the GNN model being explained from output dimension 1 to 2. This is inconvenient for a user who knows what they are doing, and a user who does not dig into the source code will get wrong results without noticing.
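The failure mode above can be reproduced in a few lines of plain PyTorch. This is a minimal sketch that uses a toy score tensor in place of a real GNN, and it assumes the explainer's prediction loss is the negative log-probability of the originally predicted label (the usual GNNExplainer objective); it does not call DGL itself.

```python
import torch

# Toy stand-in for a binary classifier's output: one score per node, shape (N, 1).
logits = torch.tensor([[2.0], [-1.5], [0.3]], requires_grad=True)

# The same reduction GNNExplainer applies to the model output.
log_probs = logits.log_softmax(dim=-1)

# Softmax over a single column is always 1, so every log-probability is 0.0.
print(log_probs)

# argmax over a single column always returns class 0, whatever the score.
pred_label = logits.argmax(dim=-1)
print(pred_label)

# A prediction loss of the form -log p(label) is therefore constant, and its
# gradient with respect to the scores (and hence any edge mask) is zero.
loss = -log_probs[0, pred_label[0]]
loss.backward()
print(logits.grad)
```

Because the gradient through the prediction term vanishes, only the mask regularizers would shape the learned edge mask, which is why the resulting explanation does not reflect the target node's score.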

Alternatives

Workarounds by the user are a possible alternative, but they are less intuitive.
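One non-invasive form of the workaround is to wrap the trained binary model so that it exposes two logits, without retraining or editing the model itself. The sketch below is hypothetical: `BinaryToTwoClass` is an illustrative name, not part of DGL, and a toy `nn.Linear` stands in for the GNN. It relies on the identity softmax([0, s]) = [1 - sigmoid(s), sigmoid(s)].

```python
import torch
import torch.nn as nn


class BinaryToTwoClass(nn.Module):
    """Hypothetical wrapper: exposes a 1-logit binary model as a 2-class model.

    softmax([0, s]) == [1 - sigmoid(s), sigmoid(s)], so the wrapped output gives
    the same class probabilities as applying a sigmoid to the original score.
    """

    def __init__(self, binary_model: nn.Module):
        super().__init__()
        self.binary_model = binary_model

    def forward(self, *args, **kwargs):
        # Forward all arguments (e.g. graph=..., feat=..., eweight=...) unchanged.
        score = self.binary_model(*args, **kwargs)  # shape (N, 1)
        # Prepend a fixed zero logit for the negative class.
        return torch.cat([torch.zeros_like(score), score], dim=-1)  # shape (N, 2)


# Toy binary "model" for demonstration: a linear layer with a single output.
torch.manual_seed(0)
binary = nn.Linear(4, 1)
wrapped = BinaryToTwoClass(binary)

x = torch.randn(5, 4)
probs_two_class = wrapped(x).softmax(dim=-1)[:, 1]
probs_sigmoid = binary(x).sigmoid().squeeze(-1)
print(torch.allclose(probs_two_class, probs_sigmoid))
```

With this wrapper, argmax over the two columns returns 1 exactly when the score is positive, matching a 0.5 sigmoid threshold. Passing the wrapped model to GNNExplainer in place of the original assumes the original forward accepts the keyword arguments the explainer supplies.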

The code segment in question

logits = self.model(graph=sg, feat=h,
                    eweight=edge_mask.sigmoid(), **kwargs)
log_probs = logits.log_softmax(dim=-1)

page: https://docs.dgl.ai/en/latest/_modules/dgl/nn/pytorch/explain/gnnexplainer.html
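A possible in-library fix is to branch on the output width before computing log-probabilities, so that `logits.log_softmax(dim=-1)` in the segment above is only used for multi-class outputs. This is a sketch of one option, not DGL's actual implementation; `to_two_class_log_probs` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F


def to_two_class_log_probs(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return (N, C) log-probabilities for any classifier.

    Multi-class outputs (C >= 2) use log_softmax, as the current code does.
    Single-score outputs (C == 1) are treated as binary logits: column 0 holds
    log(1 - sigmoid(s)) and column 1 holds log(sigmoid(s)).
    """
    if logits.shape[-1] == 1:
        # logsigmoid(-s) == log(1 - sigmoid(s)), computed in a numerically stable way.
        return torch.cat([F.logsigmoid(-logits), F.logsigmoid(logits)], dim=-1)
    return logits.log_softmax(dim=-1)


scores = torch.tensor([[2.0], [-1.5], [0.3]])
log_probs = to_two_class_log_probs(scores)
pred_label = log_probs.argmax(dim=-1)  # 1 where the score is positive
print(pred_label)
print(log_probs.exp().sum(dim=-1))  # each row sums to 1 (up to rounding)
```

For this to work end to end, the predicted label used as the explanation target would also need to come from the same two-column view (or from thresholding the score at 0), since argmax over a single raw column always returns 0.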

avinashsai commented 2 years ago

@jermainewang Can I work on this? Please provide more details.

mufeili commented 2 years ago

Hi @avinashsai, it would be great if you could work on that. Once you have a preliminary implementation, you can open a PR for it. We can then discuss it further in that PR and polish the implementation.