hellopipu / PromptMR

[STACOM@MICCAI 2023] Fill the K-Space and Refine the Image: Prompting for Dynamic and Multi-Contrast MRI Reconstruction (1st@CMRxRecon2023 Challenge)
https://link.springer.com/chapter/10.1007/978-3-031-52448-6_25
MIT License
33 stars, 4 forks

GPU memory for training #5

Closed: ChongWang1024 closed this issue 7 months ago

ChongWang1024 commented 8 months ago

Hi, thanks for sharing the code of this interesting work.

I am trying to train on the fastMRI dataset, but I get a CUDA out-of-memory error even with a batch size of 1. My GPU is an NVIDIA A5000 with 24 GB of memory.

Could you please tell me how much GPU memory is required to train with a batch size of 1?

Also, I noticed that memory usage grows gradually with each iteration (batch). Is that expected? It may be caused by something in the code that I missed.
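One way to tell a genuine leak from normal allocator warm-up is to record allocated memory once per iteration, at the same point in the loop, and check whether it keeps climbing after the first few steps. The sketch below is a pure-Python helper; the readings are assumed to come from something like torch.cuda.memory_allocated() called after each optimizer step. The function name, warm-up length, tolerance and rising-fraction threshold are illustrative assumptions, not part of PromptMR.

```python
from typing import Sequence


def is_steadily_growing(
    readings: Sequence[int],
    warmup: int = 5,
    tolerance_bytes: int = 1 << 20,
    min_rising_fraction: float = 0.8,
) -> bool:
    """Heuristic leak check on per-iteration memory readings (hypothetical helper).

    readings: allocated bytes recorded once per iteration, e.g. by appending
        torch.cuda.memory_allocated() after each training step (assumed usage).
    warmup: initial iterations to ignore, because caching allocators and
        kernel autotuning typically allocate extra memory early on.
    tolerance_bytes: net growth at or below this is treated as noise.
    min_rising_fraction: share of consecutive steps that must not decrease.
    """
    steady = list(readings[warmup:])
    if len(steady) < 2:
        return False  # not enough post-warm-up data to judge a trend
    deltas = [b - a for a, b in zip(steady, steady[1:])]
    net_growth = steady[-1] - steady[0]
    rising_fraction = sum(d >= 0 for d in deltas) / len(deltas)
    # A leak shows up as sustained, mostly monotonic growth after warm-up.
    return net_growth > tolerance_bytes and rising_fraction >= min_rising_fraction
```

Calling the helper every few hundred iterations keeps the overhead negligible; a flat series after warm-up points away from a leak in the training code itself.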

Many thanks! Looking forward to your reply.

hellopipu commented 8 months ago

Hi @ChongWang1024,

Training on the fastMRI knee dataset requires approximately 26 GB of GPU memory. You can reduce the model's feature dimension to fit within your GPU's memory.
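As a rough illustration of why a smaller feature dimension helps: one float32 activation map of shape (batch, channels, height, width) occupies batch × channels × height × width × 4 bytes, so it scales linearly with the channel count. The sketch below only does that arithmetic; the 320 × 320 spatial size and the channel counts 48 and 24 are hypothetical values, not PromptMR's actual configuration, and real training memory also includes weights, gradients, optimizer state and every activation kept for the backward pass.

```python
def activation_bytes(batch: int, channels: int, height: int, width: int,
                     bytes_per_element: int = 4) -> int:
    """Memory of one dense activation tensor (float32 by default)."""
    return batch * channels * height * width * bytes_per_element


def mib(num_bytes: int) -> float:
    """Convert bytes to mebibytes."""
    return num_bytes / (1024 ** 2)


if __name__ == "__main__":
    # Hypothetical feature widths: halving the channels halves each feature map.
    for channels in (48, 24):
        size = activation_bytes(batch=1, channels=channels, height=320, width=320)
        print(f"channels={channels}: {mib(size):.2f} MiB per activation map")
```

Because every layer's stored activations shrink proportionally, reducing the feature dimension lowers peak memory roughly in step, at some cost in model capacity.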

I haven't observed any gradual increase in memory usage on my end. Could you provide more details about this issue?

hellopipu commented 8 months ago

Hi @ChongWang1024,

Please pull the latest code and add `--low_mem` to the training command. This reduces memory usage to about 22 GB without modifying the model.

hellopipu commented 7 months ago

A likely cause of the memory leak is the pip-installed version of the h5py package. You can fix it by running `conda install h5py` or `pip install h5py==3.3`.

References: https://github.com/facebookresearch/fastMRI/pull/217, https://github.com/facebookresearch/fastMRI/issues/215
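Before reinstalling, it can help to confirm which versions are actually installed, since this thread traces the leak to h5py (and, for the reporter, pytorch-lightning). The sketch below uses only the standard library's importlib.metadata. The helper names are hypothetical, and the simplified comparison is an assumption: it keeps only the leading numeric part of a version string and ignores pre-release, post-release and local tags, which the packaging library's Version class handles properly. The 3.3 pin is the one suggested in the comment above.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def _numeric_parts(ver: str) -> Tuple[int, ...]:
    """Parse the leading numeric part, e.g. '3.3.0' -> (3, 3, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unsupported version string: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def matches_pin(installed: str, pinned: str) -> bool:
    """Approximate pip's '==' for plain numeric versions: '3.3' matches '3.3.0'."""
    a, b = _numeric_parts(installed), _numeric_parts(pinned)
    width = max(len(a), len(b))
    # Zero-pad the shorter tuple so that 3.3 and 3.3.0 compare equal.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a == b


if __name__ == "__main__":
    for dist in ("h5py", "pytorch-lightning"):
        print(dist, installed_version(dist) or "not installed")
    h5py_ver = installed_version("h5py")
    if h5py_ver is not None:
        print("h5py matches the 3.3 pin:", matches_pin(h5py_ver, "3.3"))
```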

ChongWang1024 commented 7 months ago

Hi, thanks for your detailed reply. I have figured out the problem: it seems I had the wrong versions of pytorch-lightning and h5py installed.

Many thanks!