Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

HausdorffDTLoss leads to GPU memory leak. #7480

Open dimka11 opened 8 months ago

dimka11 commented 8 months ago

Describe the bug

Using this loss with the Trainer from the transformers library (PyTorch) and with YOLOv8 (PyTorch) crashes training shortly after it starts with a CUDA out-of-memory error. The GPU has 16 GB of memory, and the batch size is 1 with 128*128 images. Training crashes after ~100 iterations.

Environment

Kaggle Notebook, Python 3.10.12, latest MONAI version from pip.

Also reproduced this bug under Windows 11 with code from example:

%%time

import torch
import numpy as np
from monai.losses.hausdorff_loss import HausdorffDTLoss
from monai.networks.utils import one_hot

for i in range(0, 30):
    B, C, H, W = 16, 5, 512, 512
    input = torch.rand(B, C, H, W)
    target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
    target = one_hot(target_idx[:, None, ...], num_classes=C)
    self = HausdorffDTLoss(include_background=True ,reduction='none', softmax=True)
    loss = self(input.to('cuda'), target.to('cuda'))
    assert np.broadcast_shapes(loss.shape, input.shape) == input.shape

It consumed about 5 GB of GPU memory. On the GPU consumption graph, it looks like a flat line with several rises.

SarthakJShetty-path commented 8 months ago

+1 to this. I'm running into the exact same issue, where there appears to be a GPU memory leak in MONAI's HausdorffDTLoss.

Things that I tried to alleviate this issue:

  1. Used @torch.no_grad() in metric update function, since I'm using the loss as a metric
  2. Used .detach() for the prediction and target tensors to make sure that they aren't part of the computational graph

KumoLiu commented 8 months ago

Thank you for bringing up this issue. I took the time to delve deeper into the situation using your provided sample code. From my findings, there doesn't appear to be a noticeable GPU memory leak. Please refer to the GPU memory usage graph below:

[image: GPU memory usage graph]

The minor fluctuations in the graph can be attributed to the initialization of the super loss function as seen here. Thank you again for your diligent reporting. Best,

SarthakJShetty-path commented 8 months ago

Thank you for your quick response! I appreciate you taking the time to get back to us.

I modified the code sample above slightly to plot the memory allocated by PyTorch during the process:

import numpy as np
import torch
import matplotlib.pyplot as plt
from monai.networks.utils import one_hot
from monai.losses.hausdorff_loss import HausdorffDTLoss

gpu_consumption = []
steps = []

for i in range(0, 100):
    B, C, H, W = 16, 5, 512, 512
    input = torch.rand(B, C, H, W)
    target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
    target = one_hot(target_idx[:, None, ...], num_classes=C)
    self = HausdorffDTLoss(include_background=True, reduction="none", softmax=True)
    loss = self(input.to("cuda"), target.to("cuda"))
    assert np.broadcast_shapes(loss.shape, input.shape) == input.shape
    memory_consumption = torch.cuda.max_memory_allocated(device=None) / (1e9)
    gpu_consumption.append(memory_consumption)
    steps.append(i)
    print(f"GPU max memory allocated: {memory_consumption} GB")

plt.plot(steps, gpu_consumption)
plt.title("GPU consumption (in GB) vs. Steps")
plt.show()

Which generated the following graph:

[image: GPUConsumptionMonai]

It appears that the memory consumption continuously increases with the number of steps, eventually leading to a crash because the system runs out of available GPU memory.
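One caveat when reading a graph like this: torch.cuda.max_memory_allocated reports the peak since the start of the program (or since the last reset), so the plotted value can never decrease. Below is a minimal sketch, assuming you want to record the current allocation next to a per-step peak. The name `track_memory` and its parameters are hypothetical and not part of PyTorch or MONAI; the memory probes are passed in as callables so the helper itself does not need a GPU.

```python
from typing import Callable, Dict, List, Optional


def track_memory(
    step_fn: Callable[[], None],
    n_steps: int,
    current_fn: Callable[[], int],
    peak_fn: Callable[[], int],
    reset_peak_fn: Optional[Callable[[], None]] = None,
) -> Dict[str, List[float]]:
    """Run step_fn n_steps times and record memory in GB after each step.

    current_fn / peak_fn return a byte count; reset_peak_fn (optional)
    clears the peak counter so that each step reports its own peak.
    """
    history: Dict[str, List[float]] = {"current": [], "peak": []}
    for _ in range(n_steps):
        if reset_peak_fn is not None:
            reset_peak_fn()  # make the peak per-step instead of cumulative
        step_fn()
        history["current"].append(current_fn() / 1e9)
        history["peak"].append(peak_fn() / 1e9)
    return history


# With PyTorch, the probes would be (run_loss_once is a hypothetical
# function wrapping one iteration of the loop above):
#   history = track_memory(
#       run_loss_once, 100,
#       current_fn=torch.cuda.memory_allocated,
#       peak_fn=torch.cuda.max_memory_allocated,
#       reset_peak_fn=torch.cuda.reset_peak_memory_stats,
#   )
```

If the "current" series keeps climbing from step to step, memory is being retained across iterations; if only the cumulative peak rises and then flattens, the curve is just a high-water mark.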

Please let me know if the initialization should be handled differently. The person above instantiates a new loss in every loop iteration, which might be linked to what you said about the super instantiation taking some additional GPU memory, but I've noticed this independently in my training runs as well.

Thanks!

KumoLiu commented 8 months ago

Hi @SarthakJShetty-path, I used the same code you shared, and the graph looks like this:

[image: GPU memory usage graph]

What's your PyTorch and MONAI version?

SarthakJShetty-path commented 8 months ago

Interesting!

Here are the versions:

(venv) ┌─[pc@home] - [~/projects/]
└─[$] <git:(feature_branch*)> python 
Python 3.8.10 (default, Nov 22 2023, 10:22:35) 
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import monai
>>> torch.__version__
'2.1.2+cu121'
>>> monai.__version__
'1.3.0'
>>>

KumoLiu commented 8 months ago

I've conducted tests using a fresh 1.3.0 image and, unfortunately, I've been unable to reproduce the issue you reported. Could you try the same process in a fresh environment on your end? You can pull the appropriate MONAI image with the Docker command below:

docker pull projectmonai/monai:1.3.0

SarthakJShetty-path commented 8 months ago

Sure @KumoLiu I can make a fresh build and recheck this issue in a bit. Just to confirm: You weren't able to replicate the issue even with the torch 2.1.2+cu121 version, or just with monai 1.3.0?

Thank you!

KumoLiu commented 8 months ago

Just to confirm: You weren't able to replicate the issue even with the torch 2.1.2+cu121 version

Yes, with the torch 2.1.2+cu121 version.

dimka11 commented 7 months ago

@SarthakJShetty-path any updates?

SarthakJShetty-path commented 7 months ago

Sorry about the delay with this. I haven't gotten around to pulling the Docker image and trying, but several members on our team are reporting this issue, with approximately the same CUDA + Torch version. I will pull the Docker image by EoD today and get back to you. Thank you, and sorry for the delay.

dimka11 commented 7 months ago

@SarthakJShetty-path I switched to ShapeLoss from https://github.com/Project-MONAI/MONAI/pull/4205, but I'm not sure that it gives the expected results.

SarthakJShetty-path commented 7 months ago

Can you try running this piece of code and posting the results?

import numpy as np
import torch
import matplotlib.pyplot as plt
from monai.networks.utils import one_hot
from monai.losses.hausdorff_loss import HausdorffDTLoss

gpu_consumption = []
steps = []

for i in range(0, 100):
    B, C, H, W = 16, 5, 512, 512
    input = torch.rand(B, C, H, W)
    target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
    target = one_hot(target_idx[:, None, ...], num_classes=C)
    self = HausdorffDTLoss(include_background=True, reduction="none", softmax=True)
    loss = self(input.to("cuda"), target.to("cuda"))
    assert np.broadcast_shapes(loss.shape, input.shape) == input.shape
    memory_consumption = torch.cuda.max_memory_allocated(device=None) / (1e9)
    gpu_consumption.append(memory_consumption)
    steps.append(i)
    print(f"GPU max memory allocated: {memory_consumption} GB")

plt.plot(steps, gpu_consumption)
plt.title("GPU consumption (in GB) vs. Steps")
plt.show()

It looks like @KumoLiu got a much different graph from what I received.

dimka11 commented 7 months ago

@SarthakJShetty-path

Kaggle Notebook: Python 3.10.13, PyTorch 2.1.2, MONAI 1.3.0

GPU max memory allocated (GB), one value per step:
0.575145984, 0.856951808, 1.058018304, 1.546395648, 2.002266112, 2.002266112, 2.002266112, 2.404661248,
2.731032064, 2.731032064, 2.731032064, 2.775597056, 2.775597056, 2.775597056, 3.17851648, 3.546306048,
3.546306048, 3.546306048, 3.587463168, 3.587463168, 3.587463168, 3.990120448, 4.409028096, 4.409028096,
4.409028096, 4.4410112, 4.718360576, 4.718360576, 4.718360576, 4.81351936, 4.81351936, 4.81351936,
5.252877312, 5.533896704, 5.533896704, 5.533896704, 5.67099904, 5.67099904, 5.67099904, 6.073394176,
6.399764992, 6.399764992, 6.399764992, 6.444329984, 6.444329984, 6.444329984, 6.847249408, 7.215038976,
7.215038976, 7.215038976, 7.256196096, 7.256196096, 7.256196096, 7.658853376, 8.077761024, 8.077761024,
8.077761024, 8.109744128, 8.387093504, 8.387093504, 8.387093504, 8.482252288, 8.482252288, 8.482252288,
8.92161024, 9.202629632, 9.202629632, 9.202629632, 9.339731968, 9.339731968, 9.339731968, 9.742127104,
10.06849792, 10.06849792, 10.06849792, 10.113062912, 10.113062912, 10.113062912

Google Colab notebook: Python 3.10.12, PyTorch 2.2.1+cu121, MONAI 1.3.0

GPU max memory allocated (GB), one value per step:
0.58143744, 0.954994176, 1.023153152, 1.54456064, 2.002003968, 2.002003968, 2.002003968, 2.401777664,
2.679651328, 2.679651328, 2.679651328, 2.767994368, 2.767994368, 2.76904448, 3.179825664, 3.179825664,
3.179825664, 3.625475072, 3.903086592, 3.903086592, 3.903086592, 3.991429632, 3.991429632, 3.9922176,
4.403260928, 4.403260928, 4.403260928, 4.812209152, 4.812209152, 4.812209152, 5.251828736, 5.532848128,
5.532848128, 5.532848128, 5.58396672, 5.58396672, 5.621192704, 6.032498176, 6.032498176, 6.032498176,
6.47788544, 6.755759104, 6.755759104, 6.755759104, 6.84515072, 6.84515072, 6.84515072, 7.258817536,
7.258817536, 7.258817536, 7.698174976, 7.978145792, 7.978145792, 7.978145792, 8.06701312, 8.06701312,
8.06701312, 8.523933696, 8.523933696, 8.523933696, 8.924755968, 9.201581056, 9.201581056, 9.201581056,
9.25217536, 9.25217536, 9.289401344, 9.701231104, 9.701231104, 9.701231104, 10.187250688, 10.187250688,
10.187250688, 10.187250688, 10.475086336, 10.475086336, 10.475086336, 10.516767232, 10.516767232, 10.516767232,
10.92440576

Local (Windows 11, RTX3060): PyTorch 2.1.1+cu121, MONAI 1.3.0, Python 3.10.11 (tags/v3.10.11:7d4cc5a, Apr 5 2023, 00:38:17) [MSC v.1929 64 bit (AMD64)]

GPU max memory allocated (GB), one value per step:
0.542377984, 0.990908416, 0.990908416, 1.466966016, 1.95534336, 2.323132928, 2.323132928, 2.323132928,
2.364552192, 2.364552192, 2.364552192, 2.767209472, 3.18611712, 3.18611712, 3.18611712, 3.218100224,
3.495187456, 3.495187456, 3.495187456, 3.591132672, 3.591132672, 3.591132672, 4.029966336, 4.310723584,
4.310723584, 4.310723584, 4.40273664, 4.40273664, 4.40273664, 4.850745344, 5.176067584, 5.176067584,
5.176067584, 5.222467584, 5.222467584, 5.222467584, 5.624338432, 5.991603712, 5.991603712, 5.991603712,
5.99658496, 5.99658496, 5.99658496, 6.398455808, 6.768604672, 6.768604672, 6.768604672, 6.887881728,
7.16365824, 7.16365824, 7.16365824, 7.260389888
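To compare runs like these numerically rather than by eye, the script's print output can be parsed and summarized. This is a minimal sketch that assumes the exact line format printed by the script above ("GPU max memory allocated: <value> GB"); the helper names `parse_peaks` and `growth_summary` are hypothetical.

```python
import re
from typing import Dict, List

# Matches the line format printed by the reproduction script.
LINE_RE = re.compile(r"GPU max memory allocated: ([0-9.]+) GB")


def parse_peaks(log_text: str) -> List[float]:
    """Extract the per-step peak values (in GB) from the script's output."""
    return [float(value) for value in LINE_RE.findall(log_text)]


def growth_summary(peaks: List[float]) -> Dict[str, float]:
    """Return total growth and mean growth per step, both in GB."""
    if len(peaks) < 2:
        return {"total_gb": 0.0, "mean_per_step_gb": 0.0}
    total = peaks[-1] - peaks[0]
    return {"total_gb": total, "mean_per_step_gb": total / (len(peaks) - 1)}


if __name__ == "__main__":
    # First three values of the Kaggle run reported in this thread.
    sample = (
        "GPU max memory allocated: 0.575145984 GB "
        "GPU max memory allocated: 0.856951808 GB "
        "GPU max memory allocated: 1.058018304 GB"
    )
    print(growth_summary(parse_peaks(sample)))
```

A mean growth per step that stays well above zero over many steps is what a leak would look like; a run whose memory is released each iteration should level off after the first few steps.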

SarthakJShetty-path commented 7 months ago

@KumoLiu Looks like even @dimka11 has the same error. The GPU memory seems to strictly increase even on Google Colab, with values very similar to what I posted above.

johnzielke commented 7 months ago

Chiming in here, since I worked on https://github.com/Project-MONAI/MONAI/pull/7008 trying to make the HausdorffLoss work with cucim.

I just tested this myself, and I also get an increase in GPU memory usage with each step when running your script (Windows 11, WSL2, MONAI 1.3.0, PyTorch 2.2.1+cuda12.1, Python 3.11.8). If I add a gc.collect() after the assert, though, the memory usage stays constant, i.e. the script would be:

import gc

import numpy as np
import torch
import matplotlib.pyplot as plt
from monai.networks.utils import one_hot
from monai.losses.hausdorff_loss import HausdorffDTLoss

gpu_consumption = []
steps = []

for i in range(0, 10):
    B, C, H, W = 16, 5, 512, 512
    input = torch.rand(B, C, H, W)
    target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
    target = one_hot(target_idx[:, None, ...], num_classes=C)
    self = HausdorffDTLoss(include_background=True, reduction="none", softmax=True)
    loss = self(input.to("cuda"), target.to("cuda"))
    assert np.broadcast_shapes(loss.shape, input.shape) == input.shape
    gc.collect()
    memory_consumption = torch.cuda.max_memory_allocated(device=None) / (1e9)
    gpu_consumption.append(memory_consumption)
    steps.append(i)
    print(f"GPU max memory allocated: {memory_consumption} GB")

plt.plot(steps, gpu_consumption)
plt.title("GPU consumption (in GB) vs. Steps")
plt.show()

This seems to indicate a problem with recognizing unused tensors. Maybe there is some issue with cupy/cucim interoperability and the weakrefs it creates? I tried to debug this using the PyTorch profiler. Unfortunately, as soon as stack traces are enabled, a bug in the PyTorch/Kineto profiling appears to write malformed JSON traces. For reference, this is what I used to test:

import gc  # only needed if the gc.collect() line below is enabled

import matplotlib.pyplot as plt
import numpy as np
import torch
from monai.losses.hausdorff_loss import HausdorffDTLoss
from monai.networks.utils import one_hot
from torch.profiler import ProfilerActivity, profile

gpu_consumption = []
steps = []


def calculate():
    # One evaluation of the loss on random data moved to the GPU
    B, C, H, W = 16, 5, 512, 512
    input = torch.rand(B, C, H, W)
    target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
    target = one_hot(target_idx[:, None, ...], num_classes=C)
    self = HausdorffDTLoss(include_background=True, reduction="none", softmax=True)
    loss = self(input.to("cuda"), target.to("cuda"))
    assert np.broadcast_shapes(loss.shape, input.shape) == input.shape


with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    # Enabling with stack creates bad traces
    # with_stack=True,
    profile_memory=True,
    record_shapes=True,
    # on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs/memleak'),
) as prof:
    for i in range(0, 50):
        prof.step()
        calculate()
        # Adding this line fixes the memory leak
        # gc.collect()
        memory_consumption = torch.cuda.max_memory_allocated(device=None) / (1e9)
        gpu_consumption.append(memory_consumption)
        steps.append(i)
        print(f"GPU max memory allocated: {memory_consumption} GB")

prof.export_chrome_trace("trace.json")

plt.plot(steps, gpu_consumption)
plt.title("GPU consumption (in GB) vs. Steps")
plt.show()

When running it without stack traces, the memory view looks like this without the gc.collect() call:

[image: profiler memory view]
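The fact that gc.collect() keeps memory flat is consistent with objects being held in reference cycles, which CPython's reference counting cannot free on its own and which only the cyclic garbage collector reclaims. This is one possible explanation, an assumption that the thread does not confirm. The sketch below shows only the general mechanism using the standard library; `Buffer` is a hypothetical stand-in for a large tensor.

```python
import gc
import weakref


class Buffer:
    """Hypothetical stand-in for an object that owns a lot of memory."""

    def __init__(self) -> None:
        self.other = None


def make_cycle() -> "weakref.ref[Buffer]":
    # Two objects that reference each other form a cycle, so their
    # reference counts never reach zero when the local names go away.
    a, b = Buffer(), Buffer()
    a.other, b.other = b, a
    return weakref.ref(a)


gc.disable()  # rule out an automatic collection during the demo
try:
    ref = make_cycle()
    alive_before = ref() is not None  # cycle is unreachable but not freed
    gc.collect()                      # cyclic collector reclaims it
    alive_after = ref() is not None
finally:
    gc.enable()

print(f"alive before gc.collect(): {alive_before}, after: {alive_after}")
```

Automatic collection is triggered by the number of tracked Python objects, not by how much memory they hold, so a handful of cyclically referenced objects that each pin a large GPU buffer could survive for many iterations before the collector runs.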

dimka11 commented 7 months ago

Thanks! I tried with cucim on a Kaggle Notebook (installed via !pip install cucim-cu12) and there is no GPU memory leak, even without gc.collect(). Of course, that's not an option on native Windows.

But doesn't HausdorffDTLoss already use distance_transform_edt, which uses cucim?

johnzielke commented 7 months ago

It does use cucim based on this logic, so both cucim and cupy have to be installed:

    distance_transform_edt, has_cucim = optional_import(
        "cucim.core.operations.morphology", name="distance_transform_edt"
    )
    use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device.type == "cuda"

Sorry I'm not quite following. So before that, you ran it without cucim installed?

johnzielke commented 7 months ago

OK, so I just checked that myself, and it seems you are right. If you have both cucim and cupy installed, there does not seem to be a memory leak anymore.

johnzielke commented 7 months ago

@KumoLiu I was able to reproduce the issue using the projectmonai/monai:1.3.0 container.

If you run pip uninstall cupy-cuda12x before executing the script, the memory leak will occur.

On the other hand, if you hardcode monai/transforms/utils.py:2112 (use_cp) to False with cupy installed, the memory leak does not occur.

So it seems that having cupy installed changes something in memory deallocation or garbage collection, although I am not sure where.

SarthakJShetty-path commented 7 months ago

Thank you for looking into this @johnzielke! So this means that installing cupy should be a "fix" for now?

johnzielke commented 7 months ago

I mean test it on your setup, but I guess so. You should also get a nice ~10x performance boost in the calculation of the loss with both cupy and cucim since it will run on the GPU then. There should probably be a warning or at least some more docs that explain how to get the calculation to the GPU
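Until such a warning exists, a script can check for the two optional packages itself before training. This is a minimal sketch using only the standard library; the function name `gpu_edt_backend_available` is hypothetical, and it only tests whether the packages can be found, not whether MONAI will actually take the GPU path.

```python
import importlib.util
import warnings
from typing import Sequence


def gpu_edt_backend_available(modules: Sequence[str] = ("cupy", "cucim")) -> bool:
    """Return True if every listed top-level package can be found.

    Emits a warning naming the missing packages otherwise. Only
    top-level names are supported, since find_spec imports parent
    packages when given dotted names.
    """
    missing = [name for name in modules if importlib.util.find_spec(name) is None]
    if missing:
        warnings.warn(
            "Optional packages not found: "
            + ", ".join(missing)
            + ". The distance transform may fall back to a non-GPU path.",
            RuntimeWarning,
        )
    return not missing


if __name__ == "__main__":
    print("cupy and cucim found:", gpu_edt_backend_available())
```

Calling this once at startup makes the fallback visible in the logs instead of leaving it to be discovered through slow training or rising memory use.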

KumoLiu commented 7 months ago

Hi @johnzielke, thanks for the detailed report. Your findings are insightful and indeed point to an interaction between CuPy, garbage collection, and memory deallocation that could be the root cause of the memory leak. I agree that looking into this could both resolve the memory leak and potentially offer a substantial performance boost by enabling the loss calculation to run on the GPU.

I tried the 1.3.0 container and the latest container; with cupy-cuda12x installed, both work well. The issue only happens when cupy-cuda12x is uninstalled. So the issue might be due to the interdependencies between CUDA and the CuPy library. Specifically, when CuPy is installed, it links with particular CUDA libraries to perform GPU computations. When CuPy is uninstalled, some CUDA operations may not be performed correctly because they need CuPy to access the GPU.

Currently I don't have time to take a deep look at this issue.

SarthakJShetty-path commented 7 months ago

No worries @KumoLiu thank you (and @johnzielke) for taking a look at this nonetheless. I'll try installing CuPy and double checking that that avoids this GPU leak.