pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Incompatible types between `Dataset.__getitem__` and `DataLoader.__init__` #8705

Open NiklasKappel opened 9 months ago

NiklasKappel commented 9 months ago

šŸ› Describe the bug

For the following code (extracted from here)

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root="data/TUDataset", name="MUTAG")
train_dataset = dataset[:150]
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

Pyright raises the error `Argument of type "Dataset | BaseData" cannot be assigned to parameter "dataset" of type "Dataset | Sequence[BaseData] | DatasetAdapter" in function "__init__"`. This seems to be because `Dataset.__getitem__` is annotated to return `Union['Dataset', BaseData]` here, while `DataLoader.__init__` does not accept `BaseData` objects here.

Edit: Bonus question: Should the PyTorch `Subset` type be explicitly supported, so that one can use `random_split` without type checkers complaining?
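To illustrate the `Subset` question without depending on `torch`, here is a minimal, self-contained sketch of the same typing situation: a loader whose `dataset` annotation does not mention `Subset`, so a checker flags passing one even though it works at runtime. All class and function names here are toy stand-ins, not PyG's or PyTorch's actual API.

```python
from typing import Generic, List, Sequence, TypeVar, cast

T = TypeVar("T")


class Subset(Generic[T]):
    """Toy stand-in for torch.utils.data.Subset: a view over selected indices."""

    def __init__(self, data: Sequence[T], indices: List[int]) -> None:
        self.data = data
        self.indices = indices

    def __getitem__(self, i: int) -> T:
        return self.data[self.indices[i]]

    def __len__(self) -> int:
        return len(self.indices)


def make_loader(dataset: Sequence[int]) -> List[int]:
    """Toy loader whose annotation accepts only Sequence, not Subset."""
    return [dataset[i] for i in range(len(dataset))]


full = list(range(6))
train = Subset(full, [0, 2, 4])

# A type checker flags make_loader(train) because Subset is not declared a
# Sequence. Casting silences the complaint; the cleaner fix is widening the
# loader's annotation to include Subset, as the bonus question suggests.
batch = make_loader(cast(Sequence[int], train))
print(batch)  # [0, 2, 4]
```

The `cast` is only a local workaround; supporting `Subset` in the `dataset` parameter's union would remove the need for it entirely.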

Versions

Collecting environment information...
PyTorch version: 2.1.2
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Fedora Linux 39 (Container Image) (x86_64)
GCC version: (GCC) 13.2.1 20231205 (Red Hat 13.2.1-6)
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.38
Python version: 3.11.7 | packaged by conda-forge | (main, Dec 15 2023, 08:38:37) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-6.6.7-200.fc39.x86_64-x86_64-with-glibc2.38
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 8
On-line CPU(s) list: 0-7
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 5 PRO 2400G with Radeon Vega Graphics
CPU family: 23
Model: 17
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 1
Stepping: 0
Frequency boost: enabled
CPU(s) scaling MHz: 55%
CPU max MHz: 3600.0000
CPU min MHz: 1600.0000
BogoMIPS: 7186.04
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb hw_pstate ssbd ibpb vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt sha_ni xsaveopt xsavec xgetbv1 clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif overflow_recov succor smca sev sev_es
Virtualization: AMD-V
L1d cache: 128 KiB (4 instances)
L1i cache: 256 KiB (4 instances)
L2 cache: 2 MiB (4 instances)
L3 cache: 4 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-7
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT vulnerable
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy==1.7.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.2
[pip3] torch==2.1.2
[pip3] torch-cluster==1.6.3
[pip3] torch_geometric==2.4.0
[pip3] torchaudio==2.1.2
[pip3] torchvision==0.16.2
[conda] blas 1.0 mkl conda-forge
[conda] cpuonly 2.0 0 pytorch
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libblas 3.9.0 16_linux64_mkl conda-forge
[conda] libcblas 3.9.0 16_linux64_mkl conda-forge
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] liblapack 3.9.0 16_linux64_mkl conda-forge
[conda] mkl 2022.2.1 h84fe81f_16997 conda-forge
[conda] numpy 1.26.2 py311h64a7726_0 conda-forge
[conda] pyg 2.4.0 py311_torch_2.1.0_cpu pyg
[conda] pytorch 2.1.2 py3.11_cpu_0 pytorch
[conda] pytorch-cluster 1.6.3 py311_torch_2.1.0_cpu pyg
[conda] pytorch-mutex 1.0 cpu pytorch
[conda] torchaudio 2.1.2 py311_cpu pytorch
[conda] torchvision 0.16.2 py311_cpu pytorch
rusty1s commented 9 months ago

If you want to make your code pyright-compatible, you can simply add `assert isinstance(train_dataset, TUDataset)` before passing it to the `DataLoader`. The real fix is to use `@overload` on `Dataset.__getitem__`. We are working on this as part of making the whole of PyG mypy-compatible.
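The `@overload` fix mentioned above can be sketched on a toy dataset without any PyG dependency. The idea is to tell the checker that integer indexing yields a single item while slice indexing yields a dataset, so the `Union` return type never leaks to call sites. `ToyDataset` and `Item` below are illustrative stand-ins, not PyG's actual `Dataset` and `BaseData` classes.

```python
from typing import List, Union, overload


class Item:
    """Toy stand-in for a single graph (PyG's BaseData)."""

    def __init__(self, value: int) -> None:
        self.value = value


class ToyDataset:
    """Dataset whose __getitem__ is overloaded for precise narrowing."""

    def __init__(self, items: List[Item]) -> None:
        self.items = items

    # With these overloads, a checker infers Item for dataset[0] and
    # ToyDataset for dataset[:5], instead of the ambiguous union.
    @overload
    def __getitem__(self, idx: int) -> Item: ...
    @overload
    def __getitem__(self, idx: slice) -> "ToyDataset": ...

    def __getitem__(self, idx: Union[int, slice]) -> Union[Item, "ToyDataset"]:
        if isinstance(idx, slice):
            return ToyDataset(self.items[idx])
        return self.items[idx]

    def __len__(self) -> int:
        return len(self.items)


dataset = ToyDataset([Item(i) for i in range(10)])
train = dataset[:5]   # checker sees ToyDataset, so a loader would accept it
first = dataset[0]    # checker sees Item
print(len(train), first.value)  # 5 0
```

Applied to PyG, the same pattern would let `dataset[:150]` type-check directly as a `Dataset` argument to `DataLoader`, with no `assert isinstance` needed.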