suinleelab / path_explain

A repository for explaining feature attributions and feature interactions in deep neural networks.
MIT License

Extend torch interactions to higher dimensions #10

Open chenblair opened 3 years ago

chenblair commented 3 years ago

Seems like PathExplainerTorch.interactions only supports 2D tensors, unlike the TensorFlow version.

What are the bottlenecks to supporting arbitrarily-sized tensors, and how difficult would the change be? If it's not too bad, I would be interested in making a PR myself.

Would love any input @jjanizek, thanks!

jjanizek commented 3 years ago

Hi -- so the issue is that PyTorch doesn't currently have a function that's a nice equivalent of TF's batch_jacobian function. This function basically just helps avoid a bunch of redundant calculations and is nicely parallelized. There's some discussion of this issue here.

By combining the slow, Python for-loop based implementation suggested by Gsunshine with some of the reshaping suggested here, a PyTorch version of batch_jacobian could be written pretty easily, although it might be slow. Once you had this, you'd replace the part of the PyTorch code in the "interactions" function that manually iterates through each input feature (and assumes a 2D batch x features tensor) with the batch_jacobian function you implemented.
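As a rough illustration of that idea (not the code from the linked discussions, and not part of path_explain), a loop-based batch Jacobian might look like the sketch below. The name `batch_jacobian` and its signature are assumptions made for this example, and the sketch assumes each sample's output depends only on that sample's input, so it would not be valid for layers that mix samples, such as batch norm in training mode.

```python
import torch


def batch_jacobian(func, inputs):
    """Per-sample Jacobian of func, computed with a Python loop over outputs.

    Hypothetical helper for illustration only.
    inputs:  (batch, *in_shape)
    func(inputs): (batch, *out_shape)
    returns: (batch, *out_shape, *in_shape)
    """
    inputs = inputs.detach().requires_grad_(True)
    outputs = func(inputs)

    batch = inputs.shape[0]
    in_shape = inputs.shape[1:]
    out_shape = outputs.shape[1:]

    # Flatten every output dimension so we can loop over a single axis.
    flat_out = outputs.reshape(batch, -1)

    rows = []
    for j in range(flat_out.shape[1]):
        # Summing over the batch is safe under the independence assumption:
        # the gradient w.r.t. inputs[b] only sees flat_out[b, j].
        (grad,) = torch.autograd.grad(
            flat_out[:, j].sum(), inputs, retain_graph=True
        )
        rows.append(grad)  # each grad has shape (batch, *in_shape)

    jac = torch.stack(rows, dim=1)  # (batch, n_out, *in_shape)
    return jac.reshape(batch, *out_shape, *in_shape)
```

To use something like this for interactions, `func` would itself have to return the first-order gradients (computed with `create_graph=True`), so that the Jacobian of those gradients gives the second derivatives. The loop runs one backward pass per output element, which is why this approach can be slow for large inputs.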

Let me know if that sounds reasonable or like too much of a pain to be interesting! It's also possible that by searching around, you may find that someone else has made more progress on the batch_jacobian function in PyTorch recently (e.g. see discussion here).
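For reference, more recent PyTorch releases (2.0 and later) ship `torch.func`, whose `vmap` and `jacrev` transforms can be composed into a batched Jacobian without a Python loop. The sketch below is only an illustration under that version assumption; `batch_jacobian_func` is a hypothetical name, and the function passed in must operate on a single sample with no batch dimension.

```python
import torch
from torch.func import jacrev, vmap


def batch_jacobian_func(per_sample_fn, inputs):
    """Batched Jacobian via torch.func (assumes PyTorch >= 2.0).

    Hypothetical helper for illustration only.
    per_sample_fn: maps one sample of shape in_shape to out_shape
    inputs:        (batch, *in_shape)
    returns:       (batch, *out_shape, *in_shape)
    """
    # jacrev builds the Jacobian for one sample; vmap maps it over dim 0.
    return vmap(jacrev(per_sample_fn))(inputs)
```

A model that expects a leading batch dimension could be wrapped as `lambda x: model(x.unsqueeze(0)).squeeze(0)`, though whether that works depends on the layers involved.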

jjanizek commented 3 years ago

Also, I think it would probably be easy to support higher-dimensional input tensors if the user only wanted to calculate all interactions with a single feature, which we currently do with the "interaction_index" argument. In this case, the slice of the Hessian we need is the same size as the input, because we index into a particular element of the gradient before differentiating again. So this might be an even easier starting place.
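To make the shape argument concrete, here is a minimal sketch of computing one "row" of the Hessian for inputs of arbitrary dimension. It is not the library's implementation: `hessian_row` is a hypothetical helper, passing `interaction_index` as a tuple of per-sample indices is an assumed convention, and the sketch covers only the second-derivative step, not the path integration that the full interaction values require.

```python
import torch


def hessian_row(model_fn, inputs, interaction_index):
    """Second derivatives of the output w.r.t. one chosen feature and all inputs.

    Hypothetical helper for illustration only.
    inputs:            (batch, *in_shape)
    model_fn(inputs):  (batch,), one scalar output per sample
    interaction_index: tuple indexing into in_shape, e.g. (0, 2)
    returns:           tensor with the same shape as inputs
    """
    inputs = inputs.detach().requires_grad_(True)
    out = model_fn(inputs)

    # First-order gradients, keeping the graph so we can differentiate again.
    (grads,) = torch.autograd.grad(out.sum(), inputs, create_graph=True)

    # Pick the gradient w.r.t. the chosen feature for every sample: (batch,)
    selected = grads[(slice(None),) + tuple(interaction_index)]

    # Differentiating that scalar-per-sample quantity gives an input-shaped row.
    (row,) = torch.autograd.grad(selected.sum(), inputs)
    return row
```

Because the result has the same shape as the input, nothing here depends on the input being 2D. One caveat: if the model is linear in the inputs, the selected gradient does not depend on them and the second `autograd.grad` call will raise an error, so a real implementation would need to handle that case.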

jjanizek commented 3 years ago

OK, as an update: we don't yet have a good way to support arbitrary-dimension tensors in torch, but I did add an Embedding Explainer that should work with NLP models:

https://github.com/suinleelab/path_explain/blob/master/path_explain/explainers/embedding_explainer_torch.py