Michael1015198808 closed this 1 year ago.
Hello, thank you for your report. The elements in the results are instances of the MCTSNode class (https://github.com/divelab/DIG/blob/8d7b020bb7146ce4215d287364354a637ff076c9/dig/xgraph/method/subgraphx.py#L389), and each of their attributes is accessed as an object field. We will therefore keep this code as it is. Thank you again for your contributions.
It seems that the `__call__` method of `SubgraphX` returns the wrong type, which made me think `List` was the expected type.

In `__call__`, `explanation_results` appends items returned by `explain` (line 855). However, on line 818 of `explain`, each item is converted into a `List` by the method `write_from_MCTSNode_list`.

Since there are so many `if` statements and I'm not familiar with your code, I'm not sure whether simply removing line 818 would fix it.

Should I open an issue about this bug?
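The type mismatch described above can be reduced to a few lines of plain Python. This is a sketch, not DIG code: `FakeNode` and `to_dicts` are hypothetical stand-ins for `MCTSNode` and `write_from_MCTSNode_list`, and it assumes the conversion turns each node into a dict keyed by the node's field names.

```python
from dataclasses import asdict, dataclass
from typing import List


@dataclass
class FakeNode:
    """Hypothetical stand-in for MCTSNode: fields are read as attributes."""
    coalition: List[int]
    P: float


def to_dicts(nodes):
    """Hypothetical stand-in for write_from_MCTSNode_list: nodes become plain dicts."""
    return [asdict(node) for node in nodes]


nodes = [FakeNode(coalition=[0, 1], P=0.7), FakeNode(coalition=[0], P=0.4)]

# Attribute access works on the node objects themselves.
print([len(n.coalition) for n in nodes])  # [2, 1]

# After conversion every item is a dict, so attribute access fails.
converted = to_dicts(nodes)
try:
    converted[0].coalition
except AttributeError as err:
    print("AttributeError:", err)

# Dict items have to be read by key instead.
print([len(d["coalition"]) for d in converted])  # [2, 1]
```

Any downstream helper written against the node objects (attribute access) will fail in the same way once it receives the converted list.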
Here is a short Python script (shortened from the original example in DIG's documentation) that reproduces the bug. Running it raises an `AttributeError`:
```python
import torch
from dig.xgraph.models import GCN_2l
from dig.xgraph.method import SubgraphX
from dig.xgraph.method.subgraphx import find_closest_node_result

# A tiny 3-node directed cycle with one scalar feature per node
naive_x = torch.tensor([
    [11.],
    [45.],
    [14.],
], dtype=torch.float32)
naive_edge_index = torch.tensor([
    [0, 1, 2],
    [1, 2, 0],
])

model = GCN_2l(model_level='node', dim_node=1, dim_hidden=300, num_classes=4)
explainer = SubgraphX(model, num_classes=4, device="cpu", explain_graph=False,
                      reward_method='nc_mc_l_shapley')

max_nodes = 3
node_idx = 0  # Choose any index
logits = model(naive_x, naive_edge_index)
prediction = logits[node_idx].argmax(-1).item()

_, explanation_results, related_preds = \
    explainer(naive_x, naive_edge_index, node_idx=node_idx, max_nodes=max_nodes)

# explanation_results[prediction] holds dicts rather than MCTSNode objects,
# so this call raises AttributeError
result = find_closest_node_result(explanation_results[prediction], max_nodes=max_nodes)
```
Fix several places that treat dictionaries' items as object fields.
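A minimal sketch of what reading such results by key could look like. Everything here is hypothetical, not DIG's implementation: `field` and `find_closest_result` are illustrative names, the `coalition` and `P` keys assume the dicts keep the node's field names, and the selection rule (highest reward `P` among subgraphs with at most `max_nodes` nodes) is an assumption about what the original helper intends.

```python
from typing import Any, List


def field(item: Any, name: str) -> Any:
    """Hypothetical helper: read `name` from a dict result or an object result."""
    if isinstance(item, dict):
        return item[name]
    return getattr(item, name)


def find_closest_result(results: List[Any], max_nodes: int) -> Any:
    """Hypothetical variant: highest-reward result with at most max_nodes nodes.

    The selection rule and the fallback below are assumptions, not DIG's code.
    """
    candidates = [r for r in results if len(field(r, "coalition")) <= max_nodes]
    if not candidates:
        # Assumed fallback: return the smallest subgraph when none fits the budget.
        return min(results, key=lambda r: len(field(r, "coalition")))
    return max(candidates, key=lambda r: field(r, "P"))


results = [
    {"coalition": [0, 1, 2], "P": 0.9},
    {"coalition": [0, 1], "P": 0.6},
    {"coalition": [0], "P": 0.3},
]
best = find_closest_result(results, max_nodes=2)
print(best)  # {'coalition': [0, 1], 'P': 0.6}
```

Routing every access through one accessor keeps the helper working whether it receives node objects or their dict form, so callers do not need to know which one `explain` returned.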