pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

loss.backward() is too slow #9206

Open S-atan opened 3 months ago

S-atan commented 3 months ago

🐛 Describe the bug

Hello, I replaced the ordinary convolution operation with a graph convolution operation for the skeleton joint point data. The graph convolution operation takes up to 20 seconds in the loss.backward() step, while the ordinary convolution operation only requires 0.2 seconds. Is there any method to accelerate this process?

Versions

PyTorch version: 2.2.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Pro
GCC version: (i686-posix-sjlj, built by strawberryperl.com project) 4.9.2
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: N/A

Python version: 3.12.2 | packaged by Anaconda, Inc. | (main, Feb 27 2024, 17:28:07) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19045-SP0
Is CUDA available: True
CUDA runtime version: 10.1.243
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2080 Ti

Nvidia driver version: 536.23
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Revision=21764

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.1+cu121
[pip3] torch_cluster==1.6.3+pt22cu121
[pip3] torch_geometric==2.5.2
[pip3] torch_scatter==2.1.2+pt22cu121
[pip3] torch_sparse==0.6.18+pt22cu121
[pip3] torch_spline_conv==1.2.2+pt22cu121
[pip3] torchaudio==2.2.1+cu121
[pip3] torchvision==0.17.1+cu121
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.2.1+cu121 pypi_0 pypi
[conda] torch-cluster 1.6.3+pt22cu121 pypi_0 pypi
[conda] torch-geometric 2.5.2 pypi_0 pypi
[conda] torch-scatter 2.1.2+pt22cu121 pypi_0 pypi
[conda] torch-sparse 0.6.18+pt22cu121 pypi_0 pypi
[conda] torch-spline-conv 1.2.2+pt22cu121 pypi_0 pypi
[conda] torchaudio 2.2.1+cu121 pypi_0 pypi
[conda] torchvision 0.17.1+cu121 pypi_0 pypi

rusty1s commented 3 months ago

Do you mind sharing a small code snippet to understand this better? If your input is a regular grid, there is no real point in using a GNN.
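
When putting together such a snippet, it may help to confirm that the time is really spent in the backward pass. CUDA kernels are launched asynchronously, so wall-clock timings taken without synchronization can attribute earlier work to whichever call happens to block first. Below is a minimal sketch of a timing helper, not code from this issue: the name `time_training_step` and its parameters are hypothetical, and the helper is kept framework-agnostic so it runs with the standard library alone.

```python
import time
from statistics import mean


def time_training_step(forward, backward, sync=None, warmup=3, iters=10):
    """Return the mean (forward_seconds, backward_seconds) over `iters` runs.

    forward:  callable that returns a loss-like object
    backward: callable that takes that object, e.g. lambda loss: loss.backward()
    sync:     optional callable that blocks until queued device work is done,
              e.g. torch.cuda.synchronize when running on a GPU
    """
    sync = sync or (lambda: None)
    fwd_times, bwd_times = [], []
    for i in range(warmup + iters):
        sync()  # make sure no earlier work leaks into the measurement
        t0 = time.perf_counter()
        loss = forward()
        sync()  # wait for the forward pass to finish
        t1 = time.perf_counter()
        backward(loss)
        sync()  # wait for the backward pass to finish
        t2 = time.perf_counter()
        if i >= warmup:  # discard warm-up iterations
            fwd_times.append(t1 - t0)
            bwd_times.append(t2 - t1)
    return mean(fwd_times), mean(bwd_times)
```

With PyTorch this could be called as `time_training_step(lambda: criterion(model(x, edge_index), y), lambda loss: loss.backward(), sync=torch.cuda.synchronize)`, where `model`, `criterion`, `x`, `edge_index`, and `y` are placeholders for your own objects. Gradients accumulate across iterations in this usage, which should not matter for timing but means the helper is not a substitute for a real training loop.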