pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

How can I solve the AttributeError: 'Partition' object has no attribute 'indptr' with the Flickr dataset #9251

Open yuhaoxu99 opened 4 months ago

yuhaoxu99 commented 4 months ago

🐛 Describe the bug

Create the Flickr data:

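import os

import torch_geometric.datasets
from torch_geometric.transforms import ToSparseTensor

# `path` and `to_sparse` come from the surrounding training script.
dataset_name = "Flickr"
data_path = os.path.join(path, dataset_name)
T = ToSparseTensor() if to_sparse else lambda x: x
dataset_class = getattr(torch_geometric.datasets, dataset_name)
dataset = dataset_class(data_path, transform=T)
processed_dir = dataset.processed_dir
data = dataset[0]

# Boolean node masks for the predefined train/validation/test splits.
split_masks = {}
split_masks["train"] = data.train_mask
split_masks["valid"] = data.val_mask
split_masks["test"] = data.test_mask
x = data.x
y = data.y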

Use ClusterLoader to create the subgraphs:

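from torch_geometric.data import ClusterData, ClusterLoader

# `args` holds the script's command-line options.
# Number of clusters merged into one mini-batch, chosen so that each batch
# holds roughly args.batch_size nodes.
sample_size = max(1, int(args.batch_size / (data.num_nodes / args.num_parts)))
cluster_data = ClusterData(
    data,
    num_parts=args.num_parts,
    recursive=False,
    save_dir=processed_dir,  # directory where the METIS partition is cached
)
train_loader = ClusterLoader(
    cluster_data,
    batch_size=sample_size,
    shuffle=True,
)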
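The `sample_size` expression picks how many METIS clusters to merge into one mini-batch: with `num_parts` clusters, each one holds about `num_nodes / num_parts` nodes on average, so dividing the target batch size by that average gives the number of clusters per batch (at least one). A minimal pure-Python sketch of the arithmetic follows; the helper name `clusters_per_batch` and the `batch_size`/`num_parts` values are hypothetical, while 89,250 is the node count of the Flickr graph.

```python
def clusters_per_batch(batch_size: int, num_nodes: int, num_parts: int) -> int:
    """Hypothetical helper mirroring the issue's `sample_size` expression."""
    avg_nodes_per_cluster = num_nodes / num_parts
    # Merge enough clusters to reach ~batch_size nodes, but never fewer than one.
    return max(1, int(batch_size / avg_nodes_per_cluster))


# Hypothetical settings; 89,250 is the number of nodes in Flickr.
sample_size = clusters_per_batch(batch_size=20000, num_nodes=89250, num_parts=1500)
print(sample_size)  # 336 clusters per batch, about 336 * 59.5 = 19,992 nodes

# A target smaller than one cluster still yields one cluster per batch.
print(clusters_per_batch(batch_size=10, num_nodes=89250, num_parts=1500))  # 1
```

Because `int()` truncates, a batch usually holds slightly fewer nodes than `batch_size`, and the actual count varies with the sizes of the clusters drawn.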

When training on the subgraphs, I iterate over the loader to get each batch:

for batch in train_loader:

This step raises the following error:

for batch in train_loader:
  File "D:\package\Anaconda\envs\Graph118\lib\site-packages\torch\utils\data\dataloader.py", line 630, in __next__
    data = self._next_data()
  File "D:\package\Anaconda\envs\Graph118\lib\site-packages\torch\utils\data\dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "D:\package\Anaconda\envs\Graph118\lib\site-packages\torch\utils\data\_utils\fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "D:\package\Anaconda\envs\Graph118\lib\site-packages\torch_geometric\loader\cluster.py", line 263, in _collate
    global_indptr = self.cluster_data.partition.indptr
AttributeError: 'Partition' object has no attribute 'indptr'

The same code works for other datasets such as Reddit, but not for Flickr. Could you help me solve this problem?
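The traceback shows that `ClusterLoader._collate` reads `self.cluster_data.partition.indptr`. Since the `ClusterData` call above passes a `save_dir`, one possible explanation (an assumption, not confirmed in this thread) is that a partition file cached in that directory was pickled by a PyG release whose `Partition` class stored different fields. Unpickling restores the old instance dictionary onto the current class, so the newer attribute is simply missing. The pure-Python sketch below reproduces that failure mode with stand-in classes; the module name `cached_partition_demo` and the fields `rowptr`/`col` are hypothetical and are not PyG's actual code.

```python
import pickle
import sys
import types
from dataclasses import dataclass

# Stand-in module so pickle can look classes up by name (hypothetical, not PyG).
demo = types.ModuleType("cached_partition_demo")
sys.modules[demo.__name__] = demo


def register(cls):
    """Publish `cls` in the demo module under the shared name `Partition`."""
    cls.__module__ = demo.__name__
    cls.__name__ = cls.__qualname__ = "Partition"
    demo.Partition = cls
    return cls


@register
@dataclass
class OldPartition:  # layout written by an "older" version (hypothetical fields)
    rowptr: list
    col: list


# Simulates a partition file cached in save_dir by the older version.
cached_bytes = pickle.dumps(OldPartition(rowptr=[0, 2, 3], col=[1, 0, 0]))


@register
@dataclass
class NewPartition:  # layout expected by the "newer" loader code
    indptr: list
    index: list


# Loading the stale cache builds a *new* Partition from the *old* field dict.
stale = pickle.loads(cached_bytes)
print(type(stale) is NewPartition)  # True
print(hasattr(stale, "indptr"))     # False

message = ""
try:
    _ = stale.indptr
except AttributeError as err:
    message = str(err)
print(message)  # 'Partition' object has no attribute 'indptr'
```

If this is the cause, omitting `save_dir` or pointing it at an empty directory would force `ClusterData` to recompute the partition instead of loading the cached file.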

Versions

Collecting environment information...
PyTorch version: 2.1.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Could not collect
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.10.14 | packaged by Anaconda, Inc. | (main, Mar 21 2024, 16:20:14) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19045-SP0
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Quadro RTX 5000
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:

Revision=21767

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] torch==2.1.0+cu118
[pip3] torch-cluster==1.6.3+pt21cu118
[pip3] torch_geometric==2.5.3
[pip3] torch-scatter==2.1.2+pt21cu118
[pip3] torch-sparse==0.6.18+pt21cu118
[pip3] torch-spline-conv==1.2.2+pt21cu118
[pip3] torchaudio==2.1.0+cu118
[pip3] torchdata==0.7.1
[pip3] torchvision==0.16.0+cu118
[conda] numpy 1.26.3 pypi_0 pypi
[conda] torch 2.1.0+cu118 pypi_0 pypi
[conda] torch-cluster 1.6.3+pt21cu118 pypi_0 pypi
[conda] torch-geometric 2.5.3 pypi_0 pypi
[conda] torch-scatter 2.1.2+pt21cu118 pypi_0 pypi
[conda] torch-sparse 0.6.18+pt21cu118 pypi_0 pypi
[conda] torch-spline-conv 1.2.2+pt21cu118 pypi_0 pypi
[conda] torchaudio 2.1.0+cu118 pypi_0 pypi
[conda] torchdata 0.7.1 pypi_0 pypi
[conda] torchvision 0.16.0+cu118 pypi_0 pypi

rusty1s commented 4 months ago

The following works for me:

from torch_geometric.data import ClusterData, ClusterLoader
from torch_geometric.datasets import Flickr

dataset = Flickr('/tmp/Flickr')
data = dataset[0]

cluster_data = ClusterData(data, num_parts=10, recursive=False)
train_loader = ClusterLoader(
    cluster_data,
    batch_size=2,
    shuffle=True,
)

for batch in train_loader:
    print(batch)

Can you confirm?

yuhaoxu99 commented 4 months ago

Thank you for your reply! It works for me now. It's strange, though: sometimes it works and sometimes it doesn't.