mobaidoctor / med-ddpm

GNU General Public License v3.0

VRAM usage #12

Closed LRpz closed 7 months ago

LRpz commented 7 months ago

Hi,

I am noticing quite high VRAM usage during training, even with a batch size of 1.

Would there be any way to reduce it? Using apex seems to help a bit, but training still exceeds my GPU limit (RTX A4000, 16 GB VRAM).

Thank you!

mobaidoctor commented 7 months ago

Yes, the reason is that the images are 3D volumes of voxels, which require far more memory than 2D images. The current training setup is designed for volumes of size 128x128x128. To fit your hardware capacity, you can reduce the image size to smaller dimensions such as 128x128x96, 96x96x96, or 64x64x64, depending on what works best for you.

Additionally, you can apply optimization techniques to improve performance and memory efficiency; some suggestions can be found at https://towardsdatascience.com/optimize-pytorch-performance-for-speed-and-memory-efficiency-2022-84f453916ea6. However, these optimizations may have limited impact because of the high-dimensional inputs. The most effective solution is to scale your image size down to match your hardware's capabilities.
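As a rough illustration of the resizing suggestion, volumes can be resampled to a smaller grid before they are fed to the model. This is only a minimal sketch, not med-ddpm's own preprocessing pipeline; the function name `resize_volume` and the tensor shapes are assumptions:

```python
# Hedged sketch: trilinearly downsample a 3D volume to cut VRAM usage.
# Not med-ddpm's API; shapes and names here are illustrative only.
import torch
import torch.nn.functional as F

def resize_volume(vol: torch.Tensor, size=(96, 96, 96)) -> torch.Tensor:
    """Resample a (B, C, D, H, W) volume to a smaller spatial grid."""
    return F.interpolate(vol, size=size, mode="trilinear", align_corners=False)

# One single-channel 128^3 scan, as in the default training setup.
vol = torch.randn(1, 1, 128, 128, 128)
small = resize_volume(vol, (96, 96, 96))
print(small.shape)  # torch.Size([1, 1, 96, 96, 96])
```

A 96x96x96 grid holds (96/128)^3, roughly 42%, of the voxels of a 128x128x128 grid, so activation memory during training shrinks by a similar factor.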