KidsWithTokens / MedSegDiff

Medical Image Segmentation with Diffusion Model

solve DPM-Solver OOM issue #158

Closed lin-tianyu closed 4 months ago

lin-tianyu commented 4 months ago

Problem Description

Many MedSegDiff users have run into this problem, e.g. https://github.com/KidsWithTokens/MedSegDiff/issues/49 and https://github.com/KidsWithTokens/MedSegDiff/issues/157: when DPM-Solver is used for sampling, every sample adds roughly 2 GB of GPU memory that is never released, eventually ending in a CUDA Out Of Memory error.

Previous Solution

I previously solved this problem by downgrading PyTorch to 1.8.1. However, after some untraceable changes in my Python environment, the issue has come back, and PyTorch 1.8.1 no longer helps.

Problem Solved

After debugging, I realized that some CUDA tensors are not being released from GPU memory. Surprisingly, adding a single line right after DPM-Solver sampling to force tensor detachment solved the problem.
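
For reference, the change is roughly along these lines; the function and variable names below are placeholders I'm using for illustration, not the exact ones in the repository's sampling script:

```python
import torch

def sample_with_cleanup(dpm_solver, x_T, steps=20):
    """Run one DPM-Solver sampling pass and make sure it releases GPU memory.

    `dpm_solver` and `x_T` are placeholders for the solver object and the
    initial noise tensor used in MedSegDiff's sampling loop.
    """
    sample = dpm_solver.sample(x_T, steps=steps)
    # Force detachment: cut any references to the autograd graph and move the
    # result off the GPU so the intermediate tensors can be freed.
    sample = sample.detach().cpu()
    # Optionally return cached blocks to the driver between samples.
    torch.cuda.empty_cache()
    return sample
```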

Since this issue might have troubled a lot of people, I am creating this pull request. Hope it helps.

WuJunde commented 4 months ago

Thank you for your significant contribution, Tianyu. Although I haven't encountered this issue myself, I've seen that it has troubled many users. It appears that in certain PyTorch versions the samples are processed with gradient tracking enabled, which causes the problem, so wrapping the sampling in `torch.no_grad()` or calling `model.eval()` may also resolve it. In any case, I will merge this fix and ask users who have experienced the problem to test it. Once again, thank you for your valuable input.
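
For anyone who wants to try that route, a sketch of the idea is below; `model`, `image`, and `run_dpm_solver_sampling` are placeholders for the actual objects and functions in the sampling script:

```python
import torch

def sample_without_gradients(model, image, run_dpm_solver_sampling):
    """Hypothetical wrapper: disable gradient tracking for the whole sampling pass.

    `run_dpm_solver_sampling` stands in for whatever function in the
    sampling script actually invokes DPM-Solver.
    """
    model.eval()               # inference mode for dropout / normalization layers
    with torch.no_grad():      # no autograd graph is built, so nothing pins GPU memory
        return run_dpm_solver_sampling(model, image)
```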