Closed Joney-Yf closed 2 weeks ago
@Joney-Yf it's expected: `eval` calls https://github.com/pyg-team/pytorch_geometric/blob/dafbd3013b19737ac1511d16d22d6529786a63c4/torch_geometric/nn/models/tgn.py#L181C9-L181C14, which calls a memory update https://github.com/pyg-team/pytorch_geometric/blob/dafbd3013b19737ac1511d16d22d6529786a63c4/torch_geometric/nn/models/tgn.py#L185, which is expected to use more than 40GB of memory given the `num_nodes`, `memory_dim`, and `raw_msg_dim`.
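To illustrate the pattern (a toy sketch, not the PyG implementation): `nn.Module.eval()` is `train(False)`, so a module that overrides `train` to flush buffered state does that work exactly when it leaves training mode. `ToyMemory`, `msg_width`, and the sizes below are hypothetical and are not taken from the issue; the byte count only models a single dense float32 tensor with one row per node.

```python
class ToyMemory:
    """Toy stand-in for a module that flushes per-node state on entering eval."""

    def __init__(self, num_nodes: int, msg_width: int):
        self.num_nodes = num_nodes
        self.msg_width = msg_width      # width of one aggregated message row
        self.training = True
        self.last_flush_bytes = 0       # size of the tensor the last flush would build

    def _update_memory(self, node_ids) -> None:
        # A real update would materialise a [len(node_ids), msg_width] tensor;
        # here we only record its float32 size (4 bytes per element).
        self.last_flush_bytes = len(node_ids) * self.msg_width * 4

    def train(self, mode: bool = True) -> "ToyMemory":
        # Flush for *all* nodes, but only on the training -> eval transition.
        if self.training and not mode:
            self._update_memory(range(self.num_nodes))
        self.training = mode
        return self

    def eval(self) -> "ToyMemory":
        return self.train(False)


# Illustrative sizes only (assumptions, not the reporter's values).
mem = ToyMemory(num_nodes=1_000_000, msg_width=500)
mem.eval()
print(mem.last_flush_bytes / 1e9)  # 2.0 (GB for one such tensor)
```

A real update allocates several tensors of this shape at once (inputs, aggregated messages, recurrent-cell intermediates), so the peak can be a multiple of this single-tensor figure.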
Also, in the repro, the `raw_msg_dim` arg of `IdentityMessage` (`32` in your repro) should match the `raw_msg_dim` arg of `TGNMemory` (`200` in your repro).
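A rough way to see why the two values need to agree, assuming (from my reading of the linked `tgn.py`) that `IdentityMessage` concatenates the source memory, the destination memory, the raw message, and the time encoding, so that its output width is `raw_msg_dim + 2 * memory_dim + time_dim`. The helper name and the `memory_dim`/`time_dim` values of 100 are placeholders, since the repro's values are not shown here.

```python
def identity_message_width(raw_msg_dim: int, memory_dim: int, time_dim: int) -> int:
    """Width of the concatenated message under the assumption stated above."""
    return raw_msg_dim + 2 * memory_dim + time_dim


MEMORY_DIM = TIME_DIM = 100  # placeholder values, not from the issue

# Width the message module declares when built with raw_msg_dim=32 ...
declared = identity_message_width(32, MEMORY_DIM, TIME_DIM)
# ... versus the width of messages that actually carry 200 raw features.
actual = identity_message_width(200, MEMORY_DIM, TIME_DIM)

print(declared, actual)  # 332 500
```

Any layer sized from the declared width would then receive inputs of a different width, which is why passing the same `raw_msg_dim` to both constructors is the safer setup.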
🐛 Describe the bug
I encountered an out-of-memory (OOM) issue during the evaluation phase, whereas the training procedure runs without any problems. I have verified that the OOM issue is caused solely by the eval() function, which should not be the case.
To reproduce the bug more directly, I have prepared the following code:
Additionally, before encountering the OOM bug, I faced another issue with the following error message:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and CPU!
Fortunately, this device assignment problem has been resolved by previously closed issues #7008 and #8926. After resolving the device assignment problem, I ran the above code, and the GPU memory usage exploded from 6GB to more than 40GB.
However, if I comment out the memory.eval() line, the GPU memory usage remains under 10GB. This is unexpected because model.eval() should not cause such a dramatic increase in GPU memory usage. I believe this is a bug.
Thank you for your assistance.
Versions
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==1.13.1
[pip3] torch_cluster==1.6.3
[pip3] torch_geometric==2.5.3
[pip3] torch_scatter==2.1.2
[pip3] torch_sparse==0.6.18
[pip3] torch-spline-conv==1.2.2+pt113cu117
[pip3] torchaudio==0.13.1
[pip3] torchvision==0.14.1
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-service 2.4.0 py39h5eee18b_1
[conda] mkl_fft 1.3.8 py39h5eee18b_0
[conda] mkl_random 1.2.4 py39hdb19cb5_0
[conda] numpy 1.26.4 py39h5f9d8c6_0
[conda] numpy-base 1.26.4 py39hb5e798b_0
[conda] pytorch 1.13.1 py3.9_cuda11.7_cudnn8.5.0_0 pytorch
[conda] pytorch-cuda 11.7 h778d358_5 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch-cluster 1.6.3 pypi_0 pypi
[conda] torch-geometric 2.5.3 pypi_0 pypi
[conda] torch-scatter 2.1.2 pypi_0 pypi
[conda] torch-sparse 0.6.18 pypi_0 pypi
[conda] torch-spline-conv 1.2.2+pt113cu117 pypi_0 pypi
[conda] torchaudio 0.13.1 py39_cu117 pytorch
[conda] torchvision 0.14.1 py39_cu117 pytorch