divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

PGExplainer - Visualization for Graphs Explanations #140

Closed: matekenya closed this issue 2 years ago

matekenya commented 2 years ago

I'd like to create visualisations of the explanations for graph-classification predictions. I have followed the documentation, but I am running into the following problem.

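import torch
from dig.xgraph.method import PGExplainer
from dig.xgraph.evaluation import XCollector
from dig.xgraph.method.pgexplainer import PlotUtils

# get_datasets, get_models, dim_classes and dim_node are the user's own helpers/values (not shown in the issue)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

plotutils = PlotUtils(dataset_name='bbbp', is_show=True)
data = get_datasets()['bbbp_dataset'][0]  # first molecule of BBBP
x_collector = XCollector()
model = get_models(dim_classes, dim_node)['GIN_2L']
explainer = PGExplainer(model, in_channels=256, device=device, explain_graph=True)
with torch.no_grad():
    # masks: edge masks from the explainer; related_preds: prediction records for the collector
    _, masks, related_preds = explainer(data.x, data.edge_index)
    x_collector.collect_data(masks, related_preds)

print('Data: ', data)
print('Plotutils: ', plotutils)

# Plot the molecule, highlighting the 3 highest-scoring edges of the first mask
explainer.visualization(data=data, edge_mask=masks[0], top_k=3, plot_utils=plotutils)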
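In PyTorch Geometric layout, edge_index holds one column per directed edge, and an edge mask assigns one importance score per column. A top_k argument therefore most plausibly means "keep the k highest-scoring edges". The sketch below shows that selection in plain Python under that assumption. top_k_edges is a hypothetical helper, not part of DIG, and the toy graph and scores are made up for illustration.

```python
from typing import List, Sequence, Tuple


def top_k_edges(
    edge_index: Sequence[Sequence[int]],
    edge_mask: Sequence[float],
    top_k: int,
) -> Tuple[List[Tuple[int, int]], List[int]]:
    """Hypothetical helper: return the top_k highest-scoring edges and the nodes they touch.

    edge_index uses PyG layout (row 0 = sources, row 1 = targets);
    edge_mask holds one score per column of edge_index.
    """
    sources, targets = edge_index
    if not (len(sources) == len(targets) == len(edge_mask)):
        raise ValueError("edge_index columns and edge_mask must have equal length")
    # Rank edge positions by score, highest first (stable: ties keep original order).
    order = sorted(range(len(edge_mask)), key=lambda i: edge_mask[i], reverse=True)
    edges = [(sources[i], targets[i]) for i in order[:top_k]]
    # Collect the distinct endpoints in first-seen order.
    nodes: List[int] = []
    for u, v in edges:
        for n in (u, v):
            if n not in nodes:
                nodes.append(n)
    return edges, nodes


# Toy path graph 0-1-2-3, stored with both directions per bond as PyG does.
edge_index = [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]
edge_mask = [0.9, 0.8, 0.1, 0.2, 0.7, 0.6]
edges, nodes = top_k_edges(edge_index, edge_mask, top_k=3)
print(edges)  # [(0, 1), (1, 0), (2, 3)]
print(nodes)  # [0, 1, 2, 3]
```

Because undirected molecular graphs store every bond twice, a small top_k can spend two of its slots on the same bond, as (0, 1) and (1, 0) do here. This is one reason top_k=3 can highlight fewer distinct bonds than expected.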

[Screenshot of the resulting output not preserved]

Oceanusity commented 2 years ago

Thank you for reporting this. I have updated the visualization code and added an x hyperparameter.
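For molecular datasets such as BBBP, edge_index only says which nodes are bonded. The atom types needed to label a molecule plot live in the node features x. That is one plausible reason a plotting routine would take x. The sketch below assumes, as in PyTorch Geometric's MoleculeNet features, that column 0 of x stores the atomic number. atom_labels and the partial SYMBOLS table are hypothetical illustrations, not DIG's implementation.

```python
from typing import Dict, Sequence

# Hypothetical partial periodic table for illustration only; a real plot
# would use a complete table (e.g. from RDKit) instead.
SYMBOLS: Dict[int, str] = {
    1: "H", 6: "C", 7: "N", 8: "O", 9: "F",
    15: "P", 16: "S", 17: "Cl", 35: "Br", 53: "I",
}


def atom_labels(x: Sequence[Sequence[float]], atomic_num_col: int = 0) -> Dict[int, str]:
    """Map node index -> element symbol, assuming column atomic_num_col holds the atomic number."""
    labels: Dict[int, str] = {}
    for node, row in enumerate(x):
        z = int(row[atomic_num_col])
        # Fall back to "Z=<n>" for elements missing from the partial table.
        labels[node] = SYMBOLS.get(z, f"Z={z}")
    return labels


# Toy feature matrix: [atomic number, <other features>] per node.
x = [[6, 0, 4], [8, 0, 2], [7, 0, 3], [11, 0, 1]]
print(atom_labels(x))  # {0: 'C', 1: 'O', 2: 'N', 3: 'Z=11'}
```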

matekenya commented 2 years ago

Hello @Oceanusity, it appears this issue has not been resolved: I am still unable to produce the graph classification visualisation.