pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

GNN Explainer Support for SparseTensor #1671

Open bwdeng20 opened 3 years ago

bwdeng20 commented 3 years ago

🚀 Feature

GNN Explainer support in the `MessagePassing` class for a `SparseTensor` adjacency matrix.

Motivation

SparseTensor-based aggregation is more memory-efficient, and designing new custom models on top of it tends to be more concise than the edge_index style.

Injecting the edge_mask can be implemented in a slightly hacky way: given a SparseTensor instance, first read out its edge weights and multiply them by edge_mask to obtain the injected weights, then call SparseTensor.set_value() to build a SparseTensor carrying the injected weights.

    def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor:
        if self.__explain__:
            edge_weight = adj_t.storage.value()
            adj_t = adj_t.set_value(edge_weight * self.__edge_mask__)
        # do something

Additional context

This seems feasible according to my experiments with SparseTensor, in which I found that set_value() does not break the PyTorch computation graph.
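The gradient-flow claim can be checked with plain PyTorch, using a `torch.sparse_coo_tensor` as a stand-in for `torch_sparse.SparseTensor` (a sketch, not the torch_sparse API itself): rebuilding the sparse tensor with masked values is an out-of-place op, so autograd still reaches the mask.

```python
import torch

# Stand-in for SparseTensor: a COO sparse tensor whose values we replace
# out-of-place, mimicking what SparseTensor.set_value() would do.
indices = torch.tensor([[0, 1, 1], [1, 0, 2]])
values = torch.tensor([1.0, 2.0, 3.0])

# The learnable edge mask, as GNNExplainer would register it.
edge_mask = torch.nn.Parameter(torch.zeros(3))

# "Inject" the mask: scale the stored edge weights by the sigmoid of the mask
# and rebuild the sparse tensor -- an out-of-place op that autograd tracks.
masked_values = values * edge_mask.sigmoid()
adj = torch.sparse_coo_tensor(indices, masked_values, (2, 3))

# Aggregate via sparse-dense matmul, analogous to message_and_aggregate.
x = torch.ones(3, 4)
out = torch.sparse.mm(adj, x)

# Gradients flow back to the edge mask through the rebuilt sparse tensor.
out.sum().backward()
print(edge_mask.grad)  # non-None: the computation graph stays intact
```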

rusty1s commented 3 years ago

Mh, I feel like this is a bit more tricky, since it requires us to implement that logic for all GNN ops that make use of message_and_aggregate. I will need to think about it, but since GNNExplainer only operates on the L-hop neighborhood around each node, I don't think memory is that much of a problem.
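The memory argument can be made concrete with a toy sketch (this is not PyG's `k_hop_subgraph` utility, just an illustrative stand-alone function with a made-up name): explaining a single node only ever touches the edges within L hops of it, so the masked adjacency stays small regardless of graph size.

```python
import torch

# A 5-node directed cycle: edge k goes from src[k] to dst[k].
edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 0]])

def l_hop_edge_mask(edge_index, target, num_hops):
    # Walk backwards along message-flow direction (dst <- src) for
    # `num_hops` steps starting from `target`, keeping only touched edges.
    src, dst = edge_index
    nodes = {target}
    keep = torch.zeros(edge_index.size(1), dtype=torch.bool)
    for _ in range(num_hops):
        hit = torch.tensor([d.item() in nodes for d in dst])
        keep |= hit
        nodes |= set(src[hit].tolist())
    return keep

mask = l_hop_edge_mask(edge_index, target=0, num_hops=2)
print(mask)  # only the 2 edges within 2 hops of node 0 are kept
```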