divelab / DIG

A library for graph deep learning research
https://diveintographs.readthedocs.io/
GNU General Public License v3.0
1.87k stars 283 forks source link

Tutorial issue #28

Closed: gui-li closed this issue 3 years ago

gui-li commented 3 years ago

The following error occurs when running your tutorial. The code block:

# --- Create data collector and explanation processor ---
# Reproduction script from this issue. It relies on `dataset`, `model`,
# `explainer` (a SubgraphX explainer) and `device`, which are created in
# earlier tutorial cells not shown here.
from dig.xgraph.evaluation import XCollector, ExplanationProcessor
x_collector = XCollector()

index = -1
data = dataset[0]
# Candidate nodes to explain: test-mask nodes whose label is non-zero.
node_indices = torch.where(data.test_mask * data.y != 0)[0].tolist()

from dig.xgraph.method.subgraphx import PlotUtils
from dig.xgraph.method.subgraphx import find_closest_node_result, k_hop_subgraph_with_default_whole_graph
# A single PlotUtils instance is created once and reused for the visualization below.
plotutils = PlotUtils(dataset_name='ba_shapes')

# Visualization
max_nodes = 5
node_idx = node_indices[6]
print(f'explain graph node {node_idx}')
# Rebind the result instead of relying on `Data.to` mutating the object in place.
data = data.to(device)
# Inference only: no autograd graph is needed just to read off the predicted class.
with torch.no_grad():
    logits = model(data.x, data.edge_index)
prediction = logits[node_idx].argmax(-1).item()

# NOTE: with `data` on a CUDA device, the DIG version in this report raised
# "cannot pin 'torch.cuda.LongTensor'" inside this call, because SubgraphX's
# internal DataLoader pinned memory; the maintainers fixed it by disabling
# pin_memory in that DataLoader.
_, explanation_results, related_preds = \
    explainer(data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)
result = find_closest_node_result(explanation_results[prediction], max_nodes=max_nodes)

explainer.visualization(explanation_results,
                        prediction,
                        max_nodes=max_nodes,
                        plot_utils=plotutils,
                        y=data.y)

The error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-7e5867836373> in <module>
     20 
     21 _, explanation_results, related_preds = \
---> 22     explainer(data.x, data.edge_index, node_idx=node_idx, max_nodes=max_nodes)
     23 result = find_closest_node_result(explanation_results[prediction], max_nodes=max_nodes)
     24 

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in __call__(self, x, edge_index, **kwargs)
    671                 payoff_func = self.get_reward_func(value_func, node_idx=self.mcts_state_map.node_idx)
    672                 self.mcts_state_map.set_score_func(payoff_func)
--> 673                 results = self.mcts_state_map.mcts(verbose=False)
    674 
    675                 tree_node_x = find_closest_node_result(results, max_nodes=max_nodes)

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in mcts(self, verbose)
    465             print(f"The nodes in graph is {self.graph.number_of_nodes()}")
    466         for rollout_idx in range(self.n_rollout):
--> 467             self.mcts_rollout(self.root)
    468             if verbose:
    469                 print(f"At the {rollout_idx} rollout, {len(self.state_map)} states that have been explored.")

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in mcts_rollout(self, tree_node)
    450                     tree_node.children.append(new_node)
    451 
--> 452             scores = compute_scores(self.score_func, tree_node.children)
    453             for child, score in zip(tree_node.children, scores):
    454                 child.P = score

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/subgraphx.py in compute_scores(score_func, children)
    163     for child in children:
    164         if child.P == 0:
--> 165             score = score_func(child.coalition, child.data)
    166         else:
    167             score = child.P

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/shapley.py in mc_l_shapley(coalition, data, local_raduis, value_func, subgraph_building_method, sample_num)
    216     include_mask = np.stack(set_include_masks, axis=0)
    217     marginal_contributions = \
--> 218         marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)
    219 
    220     mc_l_shapley_value = (marginal_contributions).mean().item()

~/anaconda3/envs/dig/lib/python3.8/site-packages/dig/xgraph/method/shapley.py in marginal_contribution(data, exclude_mask, include_mask, value_func, subgraph_build_func)
     73     marginal_contribution_list = []
     74 
---> 75     for exclude_data, include_data in dataloader:
     76         exclude_values = value_func(exclude_data)
     77         include_values = value_func(include_data)

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __next__(self)
    515             if self._sampler_iter is None:
    516                 self._reset()
--> 517             data = self._next_data()
    518             self._num_yielded += 1
    519             if self._dataset_kind == _DatasetKind.Iterable and \

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    557         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    558         if self._pin_memory:
--> 559             data = _utils.pin_memory.pin_memory(data)
    560         return data
    561 

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py in pin_memory(data)
     53         return type(data)(*(pin_memory(sample) for sample in data))
     54     elif isinstance(data, container_abcs.Sequence):
---> 55         return [pin_memory(sample) for sample in data]
     56     elif hasattr(data, "pin_memory"):
     57         return data.pin_memory()

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py in <listcomp>(.0)
     53         return type(data)(*(pin_memory(sample) for sample in data))
     54     elif isinstance(data, container_abcs.Sequence):
---> 55         return [pin_memory(sample) for sample in data]
     56     elif hasattr(data, "pin_memory"):
     57         return data.pin_memory()

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py in pin_memory(data)
     55         return [pin_memory(sample) for sample in data]
     56     elif hasattr(data, "pin_memory"):
---> 57         return data.pin_memory()
     58     else:
     59         return data

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in pin_memory(self, *keys)
    363         If :obj:`*keys` is not given, the conversion is applied to all present
    364         attributes."""
--> 365         return self.apply(lambda x: x.pin_memory(), *keys)
    366 
    367     def debug(self):

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in apply(self, func, *keys)
    324         """
    325         for key, item in self(*keys):
--> 326             self[key] = self.__apply__(item, func)
    327         return self
    328 

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in __apply__(self, item, func)
    303     def __apply__(self, item, func):
    304         if torch.is_tensor(item):
--> 305             return func(item)
    306         elif isinstance(item, SparseTensor):
    307             # Not all apply methods are supported for `SparseTensor`, e.g.,

~/anaconda3/envs/dig/lib/python3.8/site-packages/torch_geometric/data/data.py in <lambda>(x)
    363         If :obj:`*keys` is not given, the conversion is applied to all present
    364         attributes."""
--> 365         return self.apply(lambda x: x.pin_memory(), *keys)
    366 
    367     def debug(self):

RuntimeError: cannot pin 'torch.cuda.LongTensor' only dense CPU tensors can be pinned

My installed packages:

# packages in environment at /home/*/anaconda3/envs/dig:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main
_openmp_mutex             4.5                       1_gnu
anyio                     3.1.0            py38h578d9bd_0    conda-forge
argon2-cffi               20.1.0           py38h497a2fe_2    conda-forge
ase                       3.21.1                   pypi_0    pypi
async_generator           1.10                       py_0    conda-forge
attrs                     21.2.0             pyhd8ed1ab_0    conda-forge
babel                     2.9.1              pyh44b312d_0    conda-forge
backcall                  0.2.0              pyh9f0ad1d_0    conda-forge
backports                 1.0                        py_2    conda-forge
backports.functools_lru_cache 1.6.4              pyhd8ed1ab_0    conda-forge
blas                      1.0                         mkl
bleach                    3.3.0              pyh44b312d_0    conda-forge
boost                     1.74.0           py38hc10631b_3    conda-forge
boost-cpp                 1.74.0               hc6e9bd1_3    conda-forge
brotlipy                  0.7.0           py38h497a2fe_1001    conda-forge
bzip2                     1.0.8                h7b6447c_0
ca-certificates           2021.5.30            ha878542_0    conda-forge
cairo                     1.16.0            h6cf1ce9_1008    conda-forge
captum                    0.2.0                    pypi_0    pypi
certifi                   2021.5.30        py38h578d9bd_0    conda-forge
cffi                      1.14.5           py38ha65f79e_0    conda-forge
chardet                   4.0.0            py38h578d9bd_1    conda-forge
cilog                     1.2.0                    pypi_0    pypi
cloudpickle               1.6.0                    pypi_0    pypi
cryptography              3.4.7            py38ha5dfef3_0    conda-forge
cudatoolkit               10.1.243             h6bb024c_0
cycler                    0.10.0                     py_2    conda-forge
decorator                 4.4.2                    pypi_0    pypi
defusedxml                0.7.1              pyhd8ed1ab_0    conda-forge
dive-into-graphs          0.0.4                    pypi_0    pypi
entrypoints               0.3             pyhd8ed1ab_1003    conda-forge
et-xmlfile                1.1.0                    pypi_0    pypi
ffmpeg                    4.3                  hf484d3e_0    pytorch
fontconfig                2.13.1            hba837de_1005    conda-forge
freetype                  2.10.4               h5ab3b9f_0
gettext                   0.19.8.1          h0b5b191_1005    conda-forge
gmp                       6.2.1                h2531618_2
gnutls                    3.6.15               he1e5248_0
googledrivedownloader     0.4                      pypi_0    pypi
greenlet                  1.1.0            py38h709712a_0    conda-forge
h5py                      3.2.1                    pypi_0    pypi
icu                       68.1                 h58526e2_0    conda-forge
idna                      2.10               pyh9f0ad1d_0    conda-forge
importlib-metadata        4.5.0            py38h578d9bd_0    conda-forge
intel-openmp              2021.2.0           h06a4308_610
ipykernel                 5.5.5            py38hd0cf306_0    conda-forge
ipython                   7.24.1           py38hd0cf306_0    conda-forge
ipython_genutils          0.2.0                      py_1    conda-forge
isodate                   0.6.0                    pypi_0    pypi
jedi                      0.18.0           py38h578d9bd_2    conda-forge
jinja2                    3.0.1              pyhd8ed1ab_0    conda-forge
joblib                    1.0.1                    pypi_0    pypi
jpeg                      9b                   h024ee3a_2
json5                     0.9.5              pyh9f0ad1d_0    conda-forge
jsonschema                3.2.0              pyhd8ed1ab_3    conda-forge
jupyter_client            6.1.12             pyhd8ed1ab_0    conda-forge
jupyter_core              4.7.1            py38h578d9bd_0    conda-forge
jupyter_server            1.8.0              pyhd8ed1ab_0    conda-forge
jupyterlab                3.0.16             pyhd8ed1ab_0    conda-forge
jupyterlab_pygments       0.1.2              pyh9f0ad1d_0    conda-forge
jupyterlab_server         2.6.0              pyhd8ed1ab_0    conda-forge
kiwisolver                1.3.1            py38h1fd1430_1    conda-forge
lame                      3.100                h7b6447c_0
lcms2                     2.12                 h3be6417_0
ld_impl_linux-64          2.35.1               h7274673_9
libffi                    3.3                  he6710b0_2
libgcc-ng                 9.3.0               h5101ec6_17
libglib                   2.68.3               h3e27bee_0    conda-forge
libgomp                   9.3.0               h5101ec6_17
libiconv                  1.16                 h516909a_0    conda-forge
libidn2                   2.3.1                h27cfd23_0
libpng                    1.6.37               hbc83047_0
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libstdcxx-ng              9.3.0               hd4cf53a_17
libtasn1                  4.16.0               h27cfd23_0
libtiff                   4.2.0                h85742a9_0
libunistring              0.9.10               h27cfd23_0
libuuid                   2.32.1            h7f98852_1000    conda-forge
libuv                     1.40.0               h7b6447c_0
libwebp-base              1.2.0                h27cfd23_0
libxcb                    1.13              h7f98852_1003    conda-forge
libxml2                   2.9.12               h72842e0_0    conda-forge
llvmlite                  0.36.0                   pypi_0    pypi
lz4-c                     1.9.3                h2531618_0
markupsafe                2.0.1            py38h497a2fe_0    conda-forge
matplotlib-base           3.4.2            py38hcc49a3a_0    conda-forge
matplotlib-inline         0.1.2              pyhd8ed1ab_2    conda-forge
mistune                   0.8.4           py38h497a2fe_1003    conda-forge
mkl                       2021.2.0           h06a4308_296
mkl-service               2.3.0            py38h27cfd23_1
mkl_fft                   1.3.0            py38h42c9631_2
mkl_random                1.2.1            py38ha9443f7_2
mypy-extensions           0.4.3                    pypi_0    pypi
nbclassic                 0.3.1              pyhd8ed1ab_1    conda-forge
nbclient                  0.5.3              pyhd8ed1ab_0    conda-forge
nbconvert                 6.0.7            py38h578d9bd_3    conda-forge
nbformat                  5.1.3              pyhd8ed1ab_0    conda-forge
ncurses                   6.2                  he6710b0_1
nest-asyncio              1.5.1              pyhd8ed1ab_0    conda-forge
nettle                    3.7.3                hbbd107a_1
networkx                  2.5.1                    pypi_0    pypi
ninja                     1.10.2               hff7bd54_1
notebook                  6.4.0              pyha770c72_0    conda-forge
numba                     0.53.1                   pypi_0    pypi
numpy                     1.20.2           py38h2d18471_0
numpy-base                1.20.2           py38hfae3a4d_0
olefile                   0.46                       py_0
openh264                  2.1.0                hd408876_0
openpyxl                  3.0.7                    pypi_0    pypi
openssl                   1.1.1k               h7f98852_0    conda-forge
packaging                 20.9               pyh44b312d_0    conda-forge
pandas                    1.2.4                    pypi_0    pypi
pandoc                    2.14.0.1             h7f98852_0    conda-forge
pandocfilters             1.4.2                      py_1    conda-forge
parso                     0.8.2              pyhd8ed1ab_0    conda-forge
pcre                      8.44                 he1b5a44_0    conda-forge
pexpect                   4.8.0              pyh9f0ad1d_2    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    8.2.0            py38he98fc37_0
pip                       21.1.2           py38h06a4308_0
pixman                    0.40.0               h36c2ea0_0    conda-forge
prometheus_client         0.11.0             pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.18             pyha770c72_0    conda-forge
pthread-stubs             0.4               h36c2ea0_1001    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pycairo                   1.20.1           py38hf61ee4a_0    conda-forge
pycparser                 2.20               pyh9f0ad1d_2    conda-forge
pygments                  2.9.0              pyhd8ed1ab_0    conda-forge
pyopenssl                 20.0.1             pyhd8ed1ab_0    conda-forge
pyparsing                 2.4.7              pyh9f0ad1d_0    conda-forge
pyrsistent                0.17.3           py38h497a2fe_2    conda-forge
pysocks                   1.7.1            py38h578d9bd_3    conda-forge
python                    3.8.10               h12debd9_8
python-dateutil           2.8.1                      py_0    conda-forge
python-louvain            0.15                     pypi_0    pypi
python_abi                3.8                      1_cp38    conda-forge
pytorch                   1.8.1           py3.8_cuda10.1_cudnn7.6.3_0    pytorch
pytz                      2021.1             pyhd8ed1ab_0    conda-forge
pyzmq                     22.1.0           py38h2035c66_0    conda-forge
rdflib                    5.0.0                    pypi_0    pypi
rdkit                     2021.03.3        py38hf8acc3d_0    conda-forge
readline                  8.1                  h27cfd23_0
reportlab                 3.5.67           py38hadf75a6_0    conda-forge
requests                  2.25.1             pyhd3deb0d_0    conda-forge
scikit-learn              0.24.2                   pypi_0    pypi
scipy                     1.6.3                    pypi_0    pypi
send2trash                1.5.0                      py_0    conda-forge
setuptools                52.0.0           py38h06a4308_0
shap                      0.39.0                   pypi_0    pypi
six                       1.15.0           py38h06a4308_0
slicer                    0.0.7                    pypi_0    pypi
sniffio                   1.2.0            py38h578d9bd_1    conda-forge
sqlalchemy                1.4.18           py38h497a2fe_0    conda-forge
sqlite                    3.35.4               hdfb4753_0
tabulate                  0.8.9                    pypi_0    pypi
terminado                 0.10.1           py38h578d9bd_0    conda-forge
testpath                  0.5.0              pyhd8ed1ab_0    conda-forge
threadpoolctl             2.1.0                    pypi_0    pypi
tk                        8.6.10               hbc83047_0
torch-cluster             1.5.9                    pypi_0    pypi
torch-geometric           1.7.0                    pypi_0    pypi
torch-scatter             2.0.7                    pypi_0    pypi
torch-sparse              0.6.9                    pypi_0    pypi
torch-spline-conv         1.2.1                    pypi_0    pypi
torchaudio                0.8.1                      py38    pytorch
torchvision               0.9.1                py38_cu101    pytorch
tornado                   6.1              py38h497a2fe_1    conda-forge
tqdm                      4.61.0                   pypi_0    pypi
traitlets                 5.0.5                      py_0    conda-forge
typed-argument-parser     1.5.4                    pypi_0    pypi
typing-inspect            0.7.1                    pypi_0    pypi
typing_extensions         3.7.4.3            pyha847dfd_0
tzdata                    2020f                h52ac0ba_0
urllib3                   1.26.5             pyhd8ed1ab_0    conda-forge
wcwidth                   0.2.5              pyh9f0ad1d_2    conda-forge
webencodings              0.5.1                      py_1    conda-forge
websocket-client          0.57.0           py38h578d9bd_4    conda-forge
wheel                     0.36.2             pyhd3eb1b0_0
xorg-kbproto              1.0.7             h7f98852_1002    conda-forge
xorg-libice               1.0.10               h7f98852_0    conda-forge
xorg-libsm                1.2.3             hd9c2040_1000    conda-forge
xorg-libx11               1.7.2                h7f98852_0    conda-forge
xorg-libxau               1.0.9                h7f98852_0    conda-forge
xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
xorg-libxext              1.3.4                h7f98852_1    conda-forge
xorg-libxrender           0.9.10            h7f98852_1003    conda-forge
xorg-renderproto          0.11.1            h7f98852_1002    conda-forge
xorg-xextproto            7.3.0             h7f98852_1002    conda-forge
xorg-xproto               7.0.31            h7f98852_1007    conda-forge
xz                        5.2.5                h7b6447c_0
zeromq                    4.3.4                h9c3ff4c_0    conda-forge
zipp                      3.4.1              pyhd8ed1ab_0    conda-forge
zlib                      1.2.11               h7b6447c_3
zstd                      1.4.9                haebb681_0

I have installed the latest version of DIG from source.

Oceanusity commented 3 years ago

I have disabled the `pin_memory` flag in the DataLoader, which should solve the problem. Feel free to report any further problems.

gui-li commented 3 years ago

@Oceanusity Thanks again for your work. The library now works fine with the provided tutorial.