dmlc / dgl

Python package built to ease deep learning on graph, on top of existing DL frameworks.
http://dgl.ai
Apache License 2.0

Integration of Captum.ai in DGL #4435

Status: Open (opened by sophiakrix 2 years ago)

sophiakrix commented 2 years ago

🚀 Feature

Integration of explainability methods (Captum) with the data formats DGL uses for mini-batch training.

Motivation

Currently, Captum cannot be used with DGL for link prediction on heterogeneous graphs. A forum post already shows basic usage of Captum with DGL for node classification on homogeneous graphs (https://discuss.dgl.ai/t/explainability-using-saliency-and-integrated-gradients-captum/2278/3). This needs to be extended to heterogeneous graphs and link prediction.

Alternatives

Switching from DGL to PyTorch Geometric.

Additional Context

PyTorch Geometric provides an Explainer class with a function that converts its models into models usable with Captum (https://pytorch-geometric.readthedocs.io/en/2.0.4/_modules/torch_geometric/nn/models/explainer.html). Something similar is needed for DGL: Captum expects the forward function to take input tensors, but in a heterogeneous graph the input features are a dictionary mapping each node type to its feature tensor, as in this example (https://docs.dgl.ai/en/0.6.x/guide/minibatch-link.html):

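```python
import torch

# Model, the feature sizes, etypes, dataloader and compute_loss are defined
# as in the DGL mini-batch link prediction guide linked above.
device = torch.device('cuda')
model = Model(in_features, hidden_features, out_features, num_classes, etypes)
model = model.to(device)
opt = torch.optim.Adam(model.parameters())

for input_nodes, positive_graph, negative_graph, blocks in dataloader:
    # Move the sampled blocks and the positive/negative pair graphs to the GPU.
    blocks = [b.to(device) for b in blocks]
    positive_graph = positive_graph.to(device)
    negative_graph = negative_graph.to(device)
    # On a heterogeneous graph this is a dict {node type: feature tensor},
    # not the plain tensor(s) Captum expects as forward() inputs.
    input_features = blocks[0].srcdata['features']
    pos_score, neg_score = model(positive_graph, negative_graph, blocks, input_features)
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
```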
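
One way to bridge this mismatch, sketched here under assumptions and not as an existing DGL or Captum API, is a small adapter that fixes an order of node types and turns a forward function taking a `{node type: tensor}` dictionary into one taking one positional tensor per node type. The helper names `dict_to_tuple` and `make_tensor_forward` are hypothetical, and the toy forward function below only stands in for a DGL model.

```python
from typing import Any, Callable, Dict, Sequence, Tuple


def dict_to_tuple(feats: Dict[str, Any], ntypes: Sequence[str]) -> Tuple[Any, ...]:
    """Flatten a {node type: features} dict into a tuple ordered by ``ntypes``."""
    return tuple(feats[nt] for nt in ntypes)


def make_tensor_forward(
    forward_fn: Callable[..., Any], ntypes: Sequence[str]
) -> Callable[..., Any]:
    """Hypothetical adapter: wrap ``forward_fn(feat_dict, *extra)`` so it can be
    called as ``forward(feat_1, ..., feat_k, *extra)``, with one positional
    feature input per node type in the order given by ``ntypes``."""
    ntypes = tuple(ntypes)
    n = len(ntypes)

    def forward(*args: Any) -> Any:
        if len(args) < n:
            raise ValueError(f"expected at least {n} feature inputs, got {len(args)}")
        feats, extra = args[:n], args[n:]
        # Rebuild the per-node-type dictionary the wrapped model expects and
        # pass any remaining positional arguments through unchanged.
        return forward_fn(dict(zip(ntypes, feats)), *extra)

    return forward


if __name__ == "__main__":
    # Toy forward function standing in for a DGL model (illustration only).
    def toy_forward(feats, scale=1):
        return (feats["user"] + 10 * feats["item"]) * scale

    ntypes = ["user", "item"]
    fwd = make_tensor_forward(toy_forward, ntypes)
    inputs = dict_to_tuple({"item": 2, "user": 1}, ntypes)
    print(inputs)           # (1, 2)
    print(fwd(*inputs))     # 21
    print(fwd(*inputs, 3))  # 63
```

With the training loop above, the wrapped callable could be passed as Captum's `forward_func` and `dict_to_tuple(input_features, ntypes)` as `inputs`; non-tensor objects such as `positive_graph`, `negative_graph` and `blocks` could go through `additional_forward_args`, which Captum passes positionally after the inputs and which the wrapper forwards unchanged. Reducing the (possibly per-edge-type) link scores to an output Captum can attribute, and handling the batch dimension Captum adds when scaling inputs, are left open in this sketch.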

BarclayII commented 2 years ago

Good suggestion! I think you are requesting several things at once, so I'm separating them:

As for the solution, I posted it in your discussion forum thread: https://discuss.dgl.ai/t/captum-ai-explainability-for-heterogeneous-graphs/3135/3. We can keep the discussion of the solution itself there.

sophiakrix commented 2 years ago

Great, thanks @BarclayII! That sounds like a good plan.