pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

NeighborLoader breaks when cat_dim is a tuple #8709

Open viktor-ktorvi opened 9 months ago

viktor-ktorvi commented 9 months ago

🐛 Describe the bug

Hi,

I batch my sparse matrices block-diagonally, so cat_dim is a tuple. I want to use NeighborLoader, but internally it assumes cat_dim is an integer and raises a TypeError when it gets a tuple.

I should add that block-diagonal batching and NeighborLoader serve unrelated tasks here. I usually work with smaller graphs and load them with the regular DataLoader. I happen to have one larger graph that I want to sample many smaller graphs from, so I'm using NeighborLoader for that. It's an unusual use case, but it seems reasonable enough.
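To illustrate why block-diagonal batching needs a tuple for cat_dim, here is a minimal sketch using plain nested lists instead of sparse tensors. The helper block_diag is hypothetical and only for illustration; it is not part of PyG.

```python
def block_diag(a, b):
    """Place two dense matrices (lists of rows) on a shared diagonal.

    Hypothetical helper for illustration only, not a PyG function.
    """
    a_cols = len(a[0]) if a else 0
    b_cols = len(b[0]) if b else 0
    # Pad rows of `a` on the right and rows of `b` on the left with zeros.
    top = [row + [0] * b_cols for row in a]
    bottom = [[0] * a_cols + row for row in b]
    return top + bottom


a = [[1, 2], [3, 4]]  # shape (2, 2)
b = [[5]]             # shape (1, 1)
batched = block_diag(a, b)
# The result grows along both dimensions at once, which is what
# returning (0, 1) from __cat_dim__ expresses.
shape = (len(batched), len(batched[0]))
```

Both the row count and the column count of the result are sums of the inputs' sizes, so a single integer dimension cannot describe the concatenation.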

Traceback (most recent call last):
  File "C:\Users\todos\PycharmProjects\mlpf\examples\graph_sampling_bug.py", line 59, in <module>
    main()
  File "C:\Users\todos\PycharmProjects\mlpf\examples\graph_sampling_bug.py", line 55, in main
    subgraph = next(iter(subgraph_loader))
  File "C:\Users\todos\anaconda3\envs\mlpfenv_cu121_pyg_newest\lib\site-packages\torch\utils\data\dataloader.py", line 630, in __next__
    data = self._next_data()
  File "C:\Users\todos\anaconda3\envs\mlpfenv_cu121_pyg_newest\lib\site-packages\torch\utils\data\dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "C:\Users\todos\anaconda3\envs\mlpfenv_cu121_pyg_newest\lib\site-packages\torch\utils\data\_utils\fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "C:\Users\todos\anaconda3\envs\mlpfenv_cu121_pyg_newest\lib\site-packages\torch_geometric\loader\node_loader.py", line 148, in collate_fn
    out = self.filter_fn(out)
  File "C:\Users\todos\anaconda3\envs\mlpfenv_cu121_pyg_newest\lib\site-packages\torch_geometric\loader\node_loader.py", line 165, in filter_fn
    data = filter_data(  #
  File "C:\Users\todos\anaconda3\envs\mlpfenv_cu121_pyg_newest\lib\site-packages\torch_geometric\loader\utils.py", line 156, in filter_data
    filter_node_store_(data._store, out._store, node)
  File "C:\Users\todos\anaconda3\envs\mlpfenv_cu121_pyg_newest\lib\site-packages\torch_geometric\loader\utils.py", line 92, in filter_node_store_
    elif store.is_node_attr(key):
  File "C:\Users\todos\anaconda3\envs\mlpfenv_cu121_pyg_newest\lib\site-packages\torch_geometric\data\storage.py", line 808, in is_node_attr
    if value.shape[cat_dim] != num_nodes:
TypeError: tuple indices must be integers or slices, not tuple
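
The last frame comes down to indexing a shape with a tuple. Here is a minimal sketch of that in isolation, using a plain tuple as a stand-in for value.shape (torch.Size is a tuple subclass). The sizes and the helper index_shape are made up for illustration.

```python
shape = (120, 42)  # stand-in for value.shape; sizes are illustrative
cat_dim = (0, 1)   # what __cat_dim__ returns for block-diagonal batching


def index_shape(shape, cat_dim):
    # Mirrors the failing expression `value.shape[cat_dim]` from the traceback.
    return shape[cat_dim]


try:
    index_shape(shape, cat_dim)
    message = None
except TypeError as exc:
    message = str(exc)

print(message)
```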

The following snippet recreates the error:

from typing import Tuple

import numpy as np
import scipy
import torch
from scipy.sparse import csr_matrix
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import is_sparse, erdos_renyi_graph

def csr_to_sparse_tensor(csr: csr_matrix, size: Tuple[int, ...]):
    """
    Convert a scipy csr matrix to a torch csr tensor.

    :param csr: csr matrix
    :param size: Size.
    :return: csr tensor.
    """
    return torch.sparse_csr_tensor(crow_indices=torch.LongTensor(csr.indptr),
                                   col_indices=torch.LongTensor(csr.indices),
                                   values=torch.tensor(csr.data),
                                   size=size)

class DataSparseBlockDiagonal(Data):
    def __cat_dim__(self, key, value, *args, **kwargs):
        """
        Batch sparse matrices block diagonally.
        :param key:
        :param value:
        :param args:
        :param kwargs:
        :return:
        """
        if is_sparse(value):
            return 0, 1  # concatenate block-diagonally
        return super().__cat_dim__(key, value, *args, **kwargs)

def main():
    num_nodes = 42
    edge_index = erdos_renyi_graph(num_nodes, 0.2, directed=True)
    num_rows, num_columns = edge_index.shape[1], num_nodes

    random_csr_tensor = csr_to_sparse_tensor(csr=scipy.sparse.random(num_rows, num_columns, format='csr', dtype=np.float32),
                                             size=(num_rows, num_columns))

    data = DataSparseBlockDiagonal(
        edge_index=edge_index,
        sparse_tensor=random_csr_tensor
    )

    subgraph_loader = NeighborLoader(data, num_neighbors=[-1], replace=True, subgraph_type="bidirectional")
    subgraph = next(iter(subgraph_loader))

if __name__ == "__main__":
    main()

I feel like this is an issue.

Versions

I'm on Windows, so I couldn't get the above to run, but here is my pip freeze:

pip freeze
aiohttp==3.9.1
aiosignal==1.3.1
antlr4-python3-runtime==4.9.3
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.1.0
certifi==2023.7.22
charset-normalizer==3.3.0
click==8.1.7
colorama==0.4.6
contourpy==1.1.1
cycler==0.12.1
deepdiff==6.6.0
docker-pycreds==0.4.0
filelock==3.12.4
fonttools==4.43.1
frozenlist==1.4.1
fsspec==2023.9.2
gitdb==4.0.10
GitPython==3.1.37
hydra-core==1.3.2
idna==3.4
importlib-resources==6.1.0
Jinja2==3.1.2
joblib==1.3.2
kiwisolver==1.4.5
lightning-utilities==0.10.0
MarkupSafe==2.1.3
matplotlib==3.7.1
mpmath==1.3.0
multidict==6.0.4
networkx==3.1
numpy==1.24.3
omegaconf==2.3.0
ordered-set==4.1.0
packaging==23.2
pandapower==2.12.1
pandas==1.4.4
pathtools==0.1.2
Pillow==10.0.1
protobuf==4.24.4
psutil==5.9.5
pyg-nightly==2.4.0.dev20231028
pyparsing==3.1.1
PYPOWER==5.1.16
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0.1
requests==2.31.0
scikit-learn==1.2.2
scipy==1.10.1
sentry-sdk==1.32.0
setproctitle==1.3.3
simbench==1.4.0
six==1.16.0
smmap==5.0.1
sympy==1.12
threadpoolctl==3.2.0
torch==2.1.0+cu121
torch-cluster==1.6.3+pt21cu121
torch-scatter==2.1.2+pt21cu121
torch-sparse==0.6.18+pt21cu121
torch-spline-conv==1.2.2+pt21cu121
torch_geometric @ git+https://github.com/pyg-team/pytorch_geometric.git@0e526ab546b135dd8d5fbd55174d74da1e4028be
torchaudio==2.1.0+cu121
torchmetrics==1.2.1
torchvision==0.16.0+cu121
tqdm==4.65.0
typing_extensions==4.8.0
urllib3==2.0.6
wandb==0.15.12
yarl==1.9.4
zipp==3.17.0

rusty1s commented 9 months ago

Sparse tensors are not yet supported as part of feature fetching in NeighborLoader. We can fix the issue with cat_dim being a tuple, but it will still crash downstream since torch.sparse does not support index_select yet :(
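
For reference, a rough sketch of what tolerating a tuple cat_dim in the node-attribute check could look like. This is not PyG's actual code: the function matches_num_nodes and the choice to compare only the first concatenation dimension are assumptions made for illustration, and the sketch does nothing about the downstream index_select limitation.

```python
from typing import Optional, Sequence, Tuple, Union

CatDim = Optional[Union[int, Tuple[int, ...]]]


def matches_num_nodes(shape: Sequence[int], cat_dim: CatDim,
                      num_nodes: int) -> bool:
    """Hypothetical guard: does `shape` line up with `num_nodes`?

    Assumption: for a tuple cat_dim, only its first entry is compared.
    """
    if cat_dim is None or len(shape) == 0:
        return False
    # Normalize an int to a 1-tuple so both cases share one code path.
    dims = cat_dim if isinstance(cat_dim, tuple) else (cat_dim,)
    return shape[dims[0]] == num_nodes


# Dense node features: 42 nodes along dim 0.
dense_ok = matches_num_nodes((42, 16), 0, 42)
# Sparse matrix batched block-diagonally; sizes are illustrative.
sparse_ok = matches_num_nodes((120, 42), (0, 1), 42)
```

With a guard like this the check would return a plain boolean for the tuple case instead of raising, but filtering the sparse tensor itself would still need separate support.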