Open yuhaoxu99 opened 4 months ago
The following works for me:
from torch_geometric.data import ClusterData, ClusterLoader
from torch_geometric.datasets import Flickr
dataset = Flickr('/tmp/Flickr')
data = dataset[0]
cluster_data = ClusterData(data, num_parts=10, recursive=False)
train_loader = ClusterLoader(
    cluster_data,
    batch_size=2,
    shuffle=True,
)
for batch in train_loader:
    print(batch)
Can you confirm?
Thank you for your reply! It works for me now. It is strange, though, that it sometimes works and sometimes doesn't.
🐛 Describe the bug
Create the Flickr data:
import torch_geometric.datasets

data_path = os.path.join(path, Flickr)
T = ToSparseTensor() if to_sparse else lambda x: x
dataset_class = getattr(torch_geometric.datasets, dataset_name)
dataset = dataset_class(data_path, transform=T)
processed_dir = dataset.processed_dir
data = dataset[0]

split_masks = {}
split_masks["train"] = data.train_mask
split_masks["valid"] = data.val_mask
split_masks["test"] = data.test_mask
x = data.x
y = data.y
Use ClusterLoader to create the subgraphs:
from torch_geometric.data import ClusterData, ClusterLoader

sample_size = max(1, int(args.batch_size / (data.num_nodes / args.num_parts)))
cluster_data = ClusterData(
    data, num_parts=args.num_parts, recursive=False, save_dir=self.processed_dir
)
train_loader = ClusterLoader(
    cluster_data, batch_size=sample_size, shuffle=True
)
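To make the `sample_size` expression above easier to follow, here is a small worked sketch of the same arithmetic in plain Python. The function name `clusters_per_batch` and the batch sizes are made up for illustration, and the Flickr node count (89,250) is an assumption taken from the published dataset statistics, not something stated in this report.

```python
def clusters_per_batch(batch_size: int, num_nodes: int, num_parts: int) -> int:
    """Number of clusters to merge so that a batch holds roughly `batch_size` nodes.

    Mirrors: max(1, int(batch_size / (num_nodes / num_parts)))
    """
    avg_nodes_per_cluster = num_nodes / num_parts
    # int() truncates toward zero; max(1, ...) guarantees at least one cluster.
    return max(1, int(batch_size / avg_nodes_per_cluster))


# Assumed node count for Flickr (illustrative, not taken from the report).
FLICKR_NUM_NODES = 89_250

# Hypothetical settings: 10 partitions, about 20,000 nodes wanted per batch.
print(clusters_per_batch(20_000, FLICKR_NUM_NODES, 10))

# A target smaller than one average cluster is clamped to a single cluster.
print(clusters_per_batch(512, FLICKR_NUM_NODES, 10))
```

With few partitions, each cluster is large, so a modest `args.batch_size` truncates to zero and the `max(1, ...)` guard is what keeps the loader usable.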
When I train on the subgraph data, I need to use the following loop to get each batch:
for batch in train_loader:
This step raises the following error:
  for batch in train_loader:
  File "D:\package\Anaconda\envs\Graph118\lib\site-packages\torch\utils\data\dataloader.py", line 630, in __next__
    data = self._next_data()
  File "D:\package\Anaconda\envs\Graph118\lib\site-packages\torch\utils\data\dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "D:\package\Anaconda\envs\Graph118\lib\site-packages\torch\utils\data\_utils\fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "D:\package\Anaconda\envs\Graph118\lib\site-packages\torch_geometric\loader\cluster.py", line 263, in _collate
    global_indptr = self.cluster_data.partition.indptr
AttributeError: 'Partition' object has no attribute 'indptr'
The code above works for other datasets, such as Reddit, but not for Flickr. Could you please help me solve this problem?
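One difference between the failing snippet and the one reported as working is the `save_dir` argument, which asks `ClusterData` to reuse a partition stored on disk. Whether a previously saved partition is the cause here is an assumption; nothing in this thread confirms it. As a rough diagnostic sketch, the helper below reports which expected attributes an object lacks. `missing_attributes` and both namedtuples are hypothetical stand-ins for illustration, not part of any library.

```python
from collections import namedtuple


def missing_attributes(obj, required):
    """Return the names in `required` that `obj` does not expose as attributes."""
    return [name for name in required if not hasattr(obj, name)]


# Hypothetical stand-ins; the field lists are invented for this example and
# do not describe the real layout of any library object.
StalePartition = namedtuple("StalePartition", ["partptr", "node_perm"])
FreshPartition = namedtuple("FreshPartition", ["indptr", "partptr", "node_perm"])

# `indptr` is the attribute named in the traceback above.
REQUIRED = ["indptr"]

print(missing_attributes(StalePartition(None, None), REQUIRED))
print(missing_attributes(FreshPartition(None, None, None), REQUIRED))
```

In a real session, the same check could be run on `cluster_data.partition` right after constructing `ClusterData`, once with and once without `save_dir`, to see whether the stored partition is what differs.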
Versions
Collecting environment information...
PyTorch version: 2.1.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Could not collect
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.10.14 | packaged by Anaconda, Inc. | (main, Mar 21 2024, 16:20:14) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19045-SP0
Is CUDA available: True
CUDA runtime version: 12.1.66
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Quadro RTX 5000
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Revision=21767
Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] torch==2.1.0+cu118
[pip3] torch-cluster==1.6.3+pt21cu118
[pip3] torch_geometric==2.5.3
[pip3] torch-scatter==2.1.2+pt21cu118
[pip3] torch-sparse==0.6.18+pt21cu118
[pip3] torch-spline-conv==1.2.2+pt21cu118
[pip3] torchaudio==2.1.0+cu118
[pip3] torchdata==0.7.1
[pip3] torchvision==0.16.0+cu118
[conda] numpy 1.26.3 pypi_0 pypi
[conda] torch 2.1.0+cu118 pypi_0 pypi
[conda] torch-cluster 1.6.3+pt21cu118 pypi_0 pypi
[conda] torch-geometric 2.5.3 pypi_0 pypi
[conda] torch-scatter 2.1.2+pt21cu118 pypi_0 pypi
[conda] torch-sparse 0.6.18+pt21cu118 pypi_0 pypi
[conda] torch-spline-conv 1.2.2+pt21cu118 pypi_0 pypi
[conda] torchaudio 2.1.0+cu118 pypi_0 pypi
[conda] torchdata 0.7.1 pypi_0 pypi
[conda] torchvision 0.16.0+cu118 pypi_0 pypi