scottdd204 opened 7 months ago
Hi, the prediction result is missing. The following code should be added:
node_idx = node_indices[20]
logits = model(data.x, data.edge_index)
prediction = logits[node_idx].argmax(-1).item()
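To make clear what these lines compute, here is a dependency-free sketch that mirrors `logits[node_idx].argmax(-1).item()` using plain Python lists instead of PyTorch tensors. The helper `predict_node`, the toy scores, and the short `node_indices` list are hypothetical. In the issue, `logits` comes from `model(data.x, data.edge_index)` and has one row per node and one column per class.

```python
from typing import Sequence


def predict_node(logits: Sequence[Sequence[float]], node_idx: int) -> int:
    """Return the predicted class of one node from a [num_nodes][num_classes] score table.

    Plain-Python analogue of logits[node_idx].argmax(-1).item(): take the
    node's row of class scores and return the index of the largest score.
    On ties, max() keeps the first (lowest) index.
    """
    row = logits[node_idx]
    return max(range(len(row)), key=lambda c: row[c])


# Toy scores for 3 nodes and 3 classes (hypothetical values, not model output).
logits = [
    [0.1, 2.3, -0.5],   # node 0
    [1.7, 0.2, 0.9],    # node 1
    [-1.0, 0.0, 3.2],   # node 2
]
node_indices = [2, 0, 1]  # hypothetical test-node list; the issue uses node_indices[20]
node_idx = node_indices[0]
prediction = predict_node(logits, node_idx)
print(node_idx, prediction)  # prints "2 2"
```

The `.item()` call in the original turns the 0-dimensional result of `argmax` into a plain Python int, so `prediction` is an ordinary class index rather than a tensor.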
Thanks for pointing it out.
Alternatively, you can run the SubGraphX visualization first, which defines the prediction variable.
See tutorials/KDD2022/xgraph_code_tutorial.ipynb
gnnexplainer_related_preds = \
    gnnexplainer(data.x, data.edge_index, sparsity=sparsity, num_classes=num_classes, node_idx=node_idx)
ax, G = gnnexplainer.visualize_graph(node_idx=node_idx, edge_index=data.edge_index, edge_mask=gnnexplainer_related_preds[1][prediction], y=data.y)
As a result, there is no working visualization example for GNNExplainer.
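The reason `prediction` is needed is the `gnnexplainer_related_preds[1][prediction]` indexing in the `visualize_graph` call: it picks one edge mask out of several. The sketch below illustrates that step with plain lists. It assumes, based only on that indexing and not on DIG's documentation, that element `[1]` of the explainer output holds one edge mask per class. `select_edge_mask`, `top_edges`, and the toy data are hypothetical, and ranking edges by mask value is just one way a visualizer could use the mask, not a description of what `visualize_graph` does internally.

```python
from typing import List, Sequence, Tuple


def select_edge_mask(explainer_output: Sequence, prediction: int) -> List[float]:
    """Pick the edge mask for the predicted class.

    Assumes explainer_output[1] is a list with one edge mask per class,
    as the [1][prediction] indexing in the issue suggests (an assumption,
    not documented DIG behaviour).
    """
    per_class_masks = explainer_output[1]
    if not 0 <= prediction < len(per_class_masks):
        raise IndexError(f"no mask for class {prediction}")
    return list(per_class_masks[prediction])


def top_edges(edge_index: Sequence[Sequence[int]], edge_mask: Sequence[float],
              k: int) -> List[Tuple[int, int]]:
    """Return the k (source, target) pairs with the largest mask values.

    edge_index follows the PyG convention: row 0 holds sources, row 1 targets.
    """
    order = sorted(range(len(edge_mask)), key=lambda e: edge_mask[e], reverse=True)
    return [(edge_index[0][e], edge_index[1][e]) for e in order[:k]]


# Toy graph with 4 directed edges and 2 classes (hypothetical values).
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
explainer_output = (
    None,  # element [0] is not used in this sketch
    [
        [0.9, 0.1, 0.4, 0.2],  # edge mask for class 0
        [0.2, 0.3, 0.8, 0.7],  # edge mask for class 1
    ],
)
prediction = 1
mask = select_edge_mask(explainer_output, prediction)
print(top_edges(edge_index, mask, k=2))  # prints [(1, 2), (2, 1)]
```

Without a defined `prediction`, the mask lookup cannot happen, which is why the GNNExplainer cell fails when run on its own.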