pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License
20.65k stars 3.58k forks source link

Storing sparse tensors inside dataset causes runtime error #5416

Open hongwangg opened 1 year ago

hongwangg commented 1 year ago

🐛 Describe the bug

Minimal example to reproduce the error:

import torch
from torch_geometric.data import InMemoryDataset, HeteroData

class TestDataset(InMemoryDataset):
    """Minimal in-memory dataset that reproduces the sparse-tensor indexing error.

    Two ``HeteroData`` samples are built. Each has a dense ``x`` on node type
    ``'test1'`` and a sparse COO ``x`` on node type ``'test2'``. Building and
    collating the dataset succeeds. Indexing it afterwards (``dataset[0]``)
    raises ``RuntimeError: Tensors of type SparseTensorImpl do not have
    strides``, because ``separate()`` calls ``narrow()`` on the sparse
    attribute.

    NOTE(review): the maintainer-suggested workaround is to store a
    ``torch_sparse.SparseTensor`` instead of a ``torch.sparse_coo_tensor``.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        """Initialise the dataset and load the ``(data, slices)`` pair saved by :meth:`process`."""
        # NOTE(review): the base class presumably calls process() when the
        # processed file does not exist yet (PyG Dataset behavior, not shown here).
        super().__init__(root, transform, pre_transform, pre_filter)
        # Unpacks the (data, slices) tuple that process() passes to torch.save.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        """File name(s) expected in the processed directory.

        ``processed_paths[0]``, used for saving and loading, is presumably
        derived from ``'data.pt'``.
        """
        return ['data.pt']

    def process(self):
        """Build two identical ``HeteroData`` samples, collate them, and save the result."""
        data1 = HeteroData()
        # Dense 1-D feature tensor for node type 'test1'.
        data1['test1'].x = torch.tensor([1.0, 2, 3])
        # Sparse COO tensor of size (3,), built with the size-only constructor
        # (no indices or values passed). This is the attribute that later fails
        # in separate(), where narrow() is called on it.
        data1['test2'].x = torch.sparse_coo_tensor((3,))

        # Second sample, identical to the first.
        data2 = HeteroData()
        data2['test1'].x = torch.tensor([1.0, 2, 3])
        data2['test2'].x = torch.sparse_coo_tensor((3,))

        data_list = [data1, data2]
        # collate() merges the samples into one object plus per-attribute
        # slice offsets; separate() reads those offsets (slices[idx],
        # slices[idx + 1]) when a single sample is requested.
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

# Building the dataset (including the collate step) completes without error.
dataset = TestDataset('../test_data/')
# Indexing triggers separate() -> narrow() on the sparse attribute, which raises
# RuntimeError: Tensors of type SparseTensorImpl do not have strides
dataset[0]

Error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [1], in <cell line: 27>()
     24         torch.save((data, slices), self.processed_paths[0])
     26 dataset = TestDataset('../test_data/')
---> 27 dataset[0]

File ~\miniconda3\envs\pyg\lib\site-packages\torch_geometric\data\dataset.py:197, in Dataset.__getitem__(self, idx)
    187 r"""In case :obj:`idx` is of type integer, will return the data object
    188 at index :obj:`idx` (and transforms it in case :obj:`transform` is
    189 present).
    190 In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
    191 tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type long or
    192 bool, will return a subset of the dataset at the specified indices."""
    193 if (isinstance(idx, (int, np.integer))
    194         or (isinstance(idx, Tensor) and idx.dim() == 0)
    195         or (isinstance(idx, np.ndarray) and np.isscalar(idx))):
--> 197     data = self.get(self.indices()[idx])
    198     data = data if self.transform is None else self.transform(data)
    199     return data

File ~\miniconda3\envs\pyg\lib\site-packages\torch_geometric\data\in_memory_dataset.py:84, in InMemoryDataset.get(self, idx)
     81 elif self._data_list[idx] is not None:
     82     return copy.copy(self._data_list[idx])
---> 84 data = separate(
     85     cls=self.data.__class__,
     86     batch=self.data,
     87     idx=idx,
     88     slice_dict=self.slices,
     89     decrement=False,
     90 )
     92 self._data_list[idx] = copy.copy(data)
     94 return data

File ~\miniconda3\envs\pyg\lib\site-packages\torch_geometric\data\separate.py:37, in separate(cls, batch, idx, slice_dict, inc_dict, decrement)
     35         slices = slice_dict[attr]
     36         incs = inc_dict[attr] if decrement else None
---> 37     data_store[attr] = _separate(attr, batch_store[attr], idx, slices,
     38                                  incs, batch, batch_store, decrement)
     40 # The `num_nodes` attribute needs special treatment, as we cannot infer
     41 # the real number of nodes from the total number of nodes alone:
     42 if hasattr(batch_store, '_num_nodes'):

File ~\miniconda3\envs\pyg\lib\site-packages\torch_geometric\data\separate.py:65, in _separate(key, value, idx, slices, incs, batch, store, decrement)
     63 cat_dim = batch.__cat_dim__(key, value, store)
     64 start, end = int(slices[idx]), int(slices[idx + 1])
---> 65 value = value.narrow(cat_dim or 0, start, end - start)
     66 value = value.squeeze(0) if cat_dim is None else value
     67 if decrement and (incs.dim() > 1 or int(incs[idx]) != 0):

RuntimeError: Tensors of type SparseTensorImpl do not have strides

Environment

rusty1s commented 1 year ago

This is indeed not currently supported, but you can replace the PyTorch `sparse_coo_tensor` with a `torch_sparse.SparseTensor`. We are working on better `sparse_coo_tensor` support over the next month.