pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

GPU out of memory caused by eval() mode in TGN #9395

Closed Joney-Yf closed 2 weeks ago

Joney-Yf commented 3 weeks ago

🐛 Describe the bug

I encountered an out-of-memory (OOM) issue during the evaluation phase, whereas the training procedure runs without any problems. I have verified that the OOM issue is caused solely by the eval() function, which should not be the case.

To reproduce the bug more directly, I have prepared the following code:

import time

import torch

from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import IdentityMessage, LastAggregator

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
memory_dim = time_dim = 200

memory = TGNMemory(
    5000000,  # num_nodes
    200,      # raw_msg_dim
    memory_dim,
    time_dim,
    message_module=IdentityMessage(32, memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

# Switching to eval mode alone is enough to trigger the memory blow-up.
memory.eval()

# Keep the process alive so GPU memory usage can be inspected (e.g. via nvidia-smi).
time.sleep(3600)
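
To quantify the jump without relying only on nvidia-smi, the CUDA allocator statistics can also be printed right after the eval() call (a minimal sketch; the exact numbers will vary with the setup):

# Optional: inspect the CUDA allocator right after memory.eval() instead of sleeping.
torch.cuda.synchronize(device)
print(f'allocated: {torch.cuda.memory_allocated(device) / 1e9:.2f} GB')
print(f'peak:      {torch.cuda.max_memory_allocated(device) / 1e9:.2f} GB')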

Additionally, before encountering the OOM bug, I faced another issue with the following error message: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and CPU!

Fortunately, this device assignment problem was resolved with the help of the previously closed issues #7008 and #8926. After fixing it, I ran the code above and the GPU memory usage exploded from 6 GB to more than 40 GB.

However, if I comment out the memory.eval() line, the GPU memory usage stays under 10 GB. This is unexpected, because model.eval() should not cause such a dramatic increase in GPU memory usage. I believe this is a bug.

Thank you for your assistance.

Versions

Versions of relevant libraries:

[pip3] numpy==1.26.4
[pip3] torch==1.13.1
[pip3] torch_cluster==1.6.3
[pip3] torch_geometric==2.5.3
[pip3] torch_scatter==2.1.2
[pip3] torch_sparse==0.6.18
[pip3] torch-spline-conv==1.2.2+pt113cu117
[pip3] torchaudio==0.13.1
[pip3] torchvision==0.14.1
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-service 2.4.0 py39h5eee18b_1
[conda] mkl_fft 1.3.8 py39h5eee18b_0
[conda] mkl_random 1.2.4 py39hdb19cb5_0
[conda] numpy 1.26.4 py39h5f9d8c6_0
[conda] numpy-base 1.26.4 py39hb5e798b_0
[conda] pytorch 1.13.1 py3.9_cuda11.7_cudnn8.5.0_0 pytorch
[conda] pytorch-cuda 11.7 h778d358_5 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch-cluster 1.6.3 pypi_0 pypi
[conda] torch-geometric 2.5.3 pypi_0 pypi
[conda] torch-scatter 2.1.2 pypi_0 pypi
[conda] torch-sparse 0.6.18 pypi_0 pypi
[conda] torch-spline-conv 1.2.2+pt113cu117 pypi_0 pypi
[conda] torchaudio 0.13.1 py39_cu117 pytorch
[conda] torchvision 0.14.1 py39_cu117 pytorch

Kh4L commented 3 weeks ago

@Joney-Yf this is expected: eval() calls https://github.com/pyg-team/pytorch_geometric/blob/dafbd3013b19737ac1511d16d22d6529786a63c4/torch_geometric/nn/models/tgn.py#L181C9-L181C14

which in turn triggers a memory update https://github.com/pyg-team/pytorch_geometric/blob/dafbd3013b19737ac1511d16d22d6529786a63c4/torch_geometric/nn/models/tgn.py#L185 that is expected to use more than 40 GB of memory given your num_nodes, memory_dim, and raw_msg_dim.
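
As a rough back-of-envelope sketch of why these three hyper-parameters dominate (an approximation only; the update also allocates intermediate tensors such as GRU gates on top of this, so treat it as an illustration rather than an exact accounting of what tgn.py allocates):

# Approximate sizes of the per-node state and message features for the repro settings.
num_nodes, raw_msg_dim, memory_dim, time_dim = 5_000_000, 200, 200, 200
bytes_per_float = 4  # float32

# Persistent node memory: one memory_dim-sized vector per node.
memory_state = num_nodes * memory_dim * bytes_per_float

# IdentityMessage concatenates source memory, destination memory, the raw
# message and a time encoding, so each materialized message has this width:
msg_dim = 2 * memory_dim + raw_msg_dim + time_dim

# A single message tensor covering all nodes:
messages = num_nodes * msg_dim * bytes_per_float

print(f'node memory ~ {memory_state / 1e9:.1f} GB')  # ~4 GB
print(f'messages    ~ {messages / 1e9:.1f} GB')      # ~16 GB per materialized copy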

Also, in the repro, the raw_msg_dim argument of IdentityMessage (32 in your repro) should match the raw_msg_dim argument of TGNMemory (200 in your repro):
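
For example (a sketch of the corrected call, reusing the names and hyper-parameters from the repro above):

raw_msg_dim = 200  # must be the same value in both places
memory = TGNMemory(
    5000000,
    raw_msg_dim,
    memory_dim,
    time_dim,
    message_module=IdentityMessage(raw_msg_dim, memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)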