Open bwdeng20 opened 3 years ago
Mh, I feel like this is a bit trickier, since it requires us to implement that logic for all GNN ops that make use of `message_and_aggregate`. I will need to think about it, but since `GNNExplainer` only operates on the L-hop neighborhood around each node, I don't think memory is that much of a problem.
## 🚀 Feature
GNN Explainer support in the `MessagePassing` class for `SparseTensor` adjacency matrices.

## Motivation

`SparseTensor`-based aggregation is more memory-efficient, and designing new customized models on top of it seems more concise than the `edge_index` style. Injecting the `edge_mask` seems implementable in a slightly hacky way: given a `SparseTensor` instance, we can first get its edge weights and multiply them with the `edge_mask` to obtain the injected weights, then invoke `SparseTensor.set_value()` to get a `SparseTensor` with the injected weights.

## Additional context
This seems feasible according to my experiments with `SparseTensor`, in which I found that `set_value` does not break PyTorch's computation graph.
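A minimal sketch of that check, using PyTorch's built-in sparse COO tensors as a stand-in for `torch_sparse.SparseTensor` (the helper name `inject_edge_mask` is hypothetical, not an existing API; the point is only that multiplying the stored edge values by a learnable mask keeps the mask inside the autograd graph):

```python
import torch

def inject_edge_mask(adj: torch.Tensor, edge_mask: torch.Tensor) -> torch.Tensor:
    """Return a new sparse adjacency whose stored edge values are
    multiplied by `edge_mask` (assumes `edge_mask` is aligned with
    the coalesced edge order)."""
    adj = adj.coalesce()
    new_values = adj.values() * edge_mask  # gradient flows through here
    return torch.sparse_coo_tensor(adj.indices(), new_values, adj.shape)

# Toy 2-node graph with two directed edges: (0 -> 1) and (1 -> 0).
indices = torch.tensor([[0, 1], [1, 0]])
values = torch.tensor([1.0, 1.0])
adj = torch.sparse_coo_tensor(indices, values, (2, 2))

edge_mask = torch.tensor([0.5, 0.25], requires_grad=True)
masked = inject_edge_mask(adj, edge_mask)

# SpMM aggregation, analogous to what message_and_aggregate performs.
x = torch.ones(2, 1)
out = torch.sparse.mm(masked, x)
out.sum().backward()

# edge_mask.grad is non-None: injecting the mask via the stored edge
# values does not detach it from the computation graph.
```

The same pattern should carry over to `torch_sparse.SparseTensor` by reading the weights from `adj.storage.value()` and writing them back with `set_value()`.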