divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0

Fix some bugs of accessing items in dictionaries. #188

Closed. Michael1015198808 closed this 1 year ago.

Michael1015198808 commented 1 year ago

Fix several places that treat dictionaries' items as object fields.
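As a rough illustration of the kind of change this PR proposes, here is a minimal sketch. The record and its keys (`coalition`, `P`) are hypothetical and not taken from the DIG code; the point is only the difference between attribute access and item access on a plain dict.

```python
# Hypothetical result record stored as a plain dict; the keys are illustrative only.
record = {"coalition": [0, 1], "P": 0.7}

# Before: treating a dict item as an object field raises AttributeError.
try:
    size = len(record.coalition)
except AttributeError:
    size = None

# After: reading the same value as a dict item works.
fixed_size = len(record["coalition"])
```

Here `size` ends up as `None` because the attribute lookup fails, while `fixed_size` is `2`.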

Oceanusity commented 1 year ago

Hello, thank you for your report. The elements in the results are instances of the MCTSNode class (https://github.com/divelab/DIG/blob/8d7b020bb7146ce4215d287364354a637ff076c9/dig/xgraph/method/subgraphx.py#L389). In this class, each attribute is accessed as an object field, so we will keep this code as it is. Thank you again for your contribution.

Michael1015198808 commented 1 year ago

It seems that the __call__ method of SubgraphX returns the wrong type, which made me think a list was the expected type. In __call__, explanation_results appends the items returned by explain (line 855), but on line 818 of explain each item is converted into a list by write_from_MCTSNode_list. Since there are many if statements and I'm not familiar with your code, I'm not sure whether simply removing line 818 would fix it. Should I open an issue about this bug?
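The type mismatch described in this comment can be sketched in plain Python. This is a simplified stand-in, not DIG code: `Node`, `to_dicts` and `closest` are hypothetical names, the fields `coalition` and `P` are assumptions, and `to_dicts` only imitates the idea of a converter like write_from_MCTSNode_list, assuming (as the PR title suggests) that it yields dictionaries.

```python
from dataclasses import dataclass, asdict
from typing import List


@dataclass
class Node:
    """Hypothetical stand-in for a search-tree node (field names are assumptions)."""
    coalition: List[int]
    P: float


def to_dicts(nodes):
    """Hypothetical converter: turns node objects into plain dictionaries."""
    return [asdict(n) for n in nodes]


def closest(results, max_nodes):
    """Consumer that reads fields as attributes: returns the highest-P result
    whose coalition has at most max_nodes members."""
    candidates = [r for r in results if len(r.coalition) <= max_nodes]
    return max(candidates, key=lambda r: r.P)


nodes = [Node([0], 0.1), Node([0, 1], 0.7), Node([0, 1, 2], 0.9)]

# Works: the consumer receives objects, as it expects.
best = closest(nodes, max_nodes=2)

# Fails: after conversion the consumer receives dicts and r.coalition raises.
try:
    closest(to_dicts(nodes), max_nodes=2)
    mismatch_error = None
except AttributeError as exc:
    mismatch_error = exc
```

The same consumer succeeds on objects and fails on the converted list, which is why the two places in the code disagree about the expected type.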

Michael1015198808 commented 1 year ago

Here is a short Python script (shortened from the original example in DIG's documentation) that reproduces the bug; running it raises an AttributeError.

import torch
from dig.xgraph.models import GCN_2l
from dig.xgraph.method import SubgraphX
from dig.xgraph.method.subgraphx import find_closest_node_result
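# Toy graph: 3 nodes with one scalar feature each, joined in a directed cycle 0->1->2->0
naive_x = torch.tensor([
    [11.],
    [45.],
    [14.],
], dtype=torch.float32)
naive_edge_index = torch.tensor([
    [0, 1, 2],
    [1, 2, 0],
])

# An untrained 2-layer GCN is enough to trigger the error
model = GCN_2l(model_level='node', dim_node=1, dim_hidden=300, num_classes=4)
explainer = SubgraphX(model, num_classes=4, device="cpu", explain_graph=False,
                      reward_method='nc_mc_l_shapley')
max_nodes = 3
node_idx = 0  # any node index works (0, 1 or 2 in this graph)
logits = model(naive_x, naive_edge_index)
prediction = logits[node_idx].argmax(-1).item()

_, explanation_results, related_preds = \
    explainer(naive_x, naive_edge_index, node_idx=node_idx, max_nodes=max_nodes)
# AttributeError is raised here: explain() has already converted the MCTSNode
# results with write_from_MCTSNode_list, but find_closest_node_result still
# reads each item as an object with attributes
result = find_closest_node_result(explanation_results[prediction], max_nodes=max_nodes)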
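Until the return type is settled upstream, one possible caller-side workaround is to normalize the results before passing them on. The sketch below is hypothetical (`as_attr_record` is not part of DIG): it wraps dictionaries in `types.SimpleNamespace` so their keys become attributes, and leaves objects untouched.

```python
from types import SimpleNamespace


def as_attr_record(item):
    """Hypothetical helper: wrap a dict so its keys become attributes;
    return any other object unchanged."""
    if isinstance(item, dict):
        return SimpleNamespace(**item)
    return item


# Mixed input: one dict (as a converter might produce) and one attribute object.
raw = [{"coalition": [0, 1], "P": 0.7}, SimpleNamespace(coalition=[2], P=0.3)]
normalized = [as_attr_record(r) for r in raw]
sizes = [len(r.coalition) for r in normalized]
```

In the script above, this would mean calling `find_closest_node_result([as_attr_record(r) for r in explanation_results[prediction]], max_nodes=max_nodes)`. This has not been checked against DIG, and it only helps if the dictionaries carry every key that function reads as an attribute.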