pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Batch explanation with Captum #5917

Open RBendias opened 1 year ago

RBendias commented 1 year ago

🚀 The feature, motivation and pitch

Currently, to_captum can only explain a single prediction at a time. It would be useful to explain predictions in batches as well. This does not work today because Captum assumes that dimension 0 of the inputs equals the number of examples being explained: for batch explanations, Captum splits both the input tensor and the target vector along the first dimension and passes the resulting chunks to the forward function.
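To make the dim-0 assumption concrete, here is a minimal sketch in plain PyTorch (no Captum calls). ToyGNN and BatchFirstWrapper are hypothetical names introduced only for illustration; they are not PyG or Captum classes. The wrapper takes node features of shape [B, N, F], where dim 0 indexes the examples, shares one edge_index across all of them, and returns one output per example. The final check shows the property a batch-first forward function needs: splitting the input along dim 0 and calling forward on each chunk gives the same result as one batched call.

```python
import torch


class ToyGNN(torch.nn.Module):
    """Hypothetical stand-in for a PyG node classifier (not a PyG class).

    Each node adds the mean of its in-neighbour features to its own features,
    then a linear layer maps the result to class logits.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index
        # Sum the messages x[src] into their destination nodes.
        agg = torch.zeros_like(x).index_add_(0, dst, x[src])
        # Count incoming edges per node to turn the sum into a mean.
        deg = torch.zeros(x.size(0), dtype=x.dtype).index_add_(
            0, dst, torch.ones(dst.numel(), dtype=x.dtype)
        )
        agg = agg / deg.clamp(min=1).unsqueeze(-1)
        return self.lin(x + agg)  # [N, C]


class BatchFirstWrapper(torch.nn.Module):
    """Hypothetical wrapper: dim 0 of x (and of the output) indexes examples."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        # x: [B, N, F]. The graph structure is shared by all B examples,
        # so edge_index is used unbatched for every slice x[b].
        return torch.stack([self.model(x_b, edge_index) for x_b in x])  # [B, N, C]


torch.manual_seed(0)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # directed 3-cycle
x = torch.randn(3, 4)  # N=3 nodes, F=4 features
model = BatchFirstWrapper(ToyGNN(4, 2))

batch = torch.stack([x, x * 0.5])  # B=2, e.g. two perturbed copies of x
out = model(batch, edge_index)  # shape [2, 3, 2]

# Splitting along dim 0 and running each chunk separately matches the
# batched call, which is what a dim-0 split of the inputs requires.
chunks = [model(chunk, edge_index) for chunk in batch.split(1, dim=0)]
assert torch.allclose(torch.cat(chunks), out)
```

A single-example explanation corresponds to B=1, i.e. x.unsqueeze(0); batching simply means stacking several examples along that same leading dimension.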

One way to solve this is to provide, alongside to_captum_model, a method to_captum_data that converts the data into the format Captum requires, given the batch size and the output indices. Such a data-transformation method is also needed for heterogeneous graph support. Optionally, both functionalities could also be exposed through to_captum.
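A rough sketch of what the proposed to_captum_data could produce for the homogeneous case, assuming a model output of shape [B, N, C] and Captum's convention that a per-example target is given as a list with one entry per example along dim 0. The function name to_captum_data_sketch and its parameters are hypothetical and not part of PyG; here the batch size is derived from the number of output indices rather than passed separately, and only 2-D node features are handled.

```python
from typing import List, Optional, Sequence, Tuple

import torch


def to_captum_data_sketch(
    x: torch.Tensor,
    output_indices: Sequence[int],
    classes: Optional[Sequence[int]] = None,
) -> Tuple[torch.Tensor, List]:
    """HYPOTHETICAL helper illustrating the proposed to_captum_data idea.

    Returns:
        inputs: [B, N, F] tensor with one copy of the node features per
            example, where B = len(output_indices) is the batch size.
        target: one entry per example along dim 0. With classes given, each
            entry is a (node, class) tuple for a [B, N, C] model output;
            otherwise it is just the node index (for a [B, N] output).
    """
    if x.dim() != 2:
        raise ValueError("this sketch only handles node features of shape [N, F]")
    batch_size = len(output_indices)
    # expand() returns a view that shares memory across the batch; clone()
    # gives every example its own storage so copies can be perturbed separately.
    inputs = x.unsqueeze(0).expand(batch_size, -1, -1).clone()
    if classes is None:
        target = [int(i) for i in output_indices]
    else:
        if len(classes) != batch_size:
            raise ValueError("classes must have one entry per output index")
        target = [(int(i), int(c)) for i, c in zip(output_indices, classes)]
    return inputs, target


x = torch.randn(5, 8)  # N=5 nodes, F=8 features
inputs, target = to_captum_data_sketch(x, output_indices=[0, 3, 4], classes=[1, 0, 1])
# inputs.shape == (3, 5, 8); target == [(0, 1), (3, 0), (4, 1)]
```

Keeping data preparation separate from model wrapping, as the issue proposes, means the same wrapped model can be reused while only the batched inputs and targets change between calls.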

Alternatives

No response

Additional context

No response

IlyaTyagin commented 1 year ago

Thank you very much for opening this issue, this is exactly what I was about to ask.

@RBendias I understand that this pull request is not merged yet, but could you please provide an example of batched Captum explanation? That would be very helpful.

RBendias commented 1 year ago

Hi @IlyaTyagin. We plan to move to_captum to pyg.explain and provide a Captum ExplainerAlgorithm class. Hence, I plan to add an example script alongside the implementation of the Captum ExplainerAlgorithm.

IlyaTyagin commented 1 year ago

Hi @RBendias, thank you for your reply! Sounds good. Do you have an ETA on this?

RBendias commented 1 year ago

Most likely, I will be working on this in January.