teddykoker / torchsort

Fast, differentiable sorting and ranking in PyTorch
https://pypi.org/project/torchsort/
Apache License 2.0

Regularization CUDA Memory Leak #15

Closed nmichlo closed 3 years ago

nmichlo commented 3 years ago

Computing soft_rank over a CUDA tensor that requires gradients leaks memory when a regularization other than "l2" is chosen. Under the same conditions, soft_sort seems to work correctly.

import subprocess as sp
import torch
import torchsort

# stored on GPU
# there does not seem to be a memory leak when requires_grad=False
pred = torch.randn(256, 3*64*64, requires_grad=True).cuda()

# eventually our program will run out of memory
for i in range(100000):
    # problematic line, works when regularization="l2"
    torchsort.soft_rank(pred, regularization="l1")
    # check the current GPU free memory
    print(i, ':', sp.check_output("nvidia-smi --query-gpu=memory.free --format=csv,noheader".split()).decode().strip())
teddykoker commented 3 years ago

Thank you for reporting this! I am able to reproduce this behavior exactly, and I am working on debugging this.

I should also note, the only options for regularization are "l2" and "kl". I will change the code to only allow these values as well.

teddykoker commented 3 years ago

I believe I have fixed the issue. It would be much appreciated if you could verify locally by installing torchsort with:

pip install "torchsort>=0.1.3"

I am no longer encountering the issue on my hardware. If I don't hear back soon I will release the next version regardless. Thank you again for providing a detailed account of the issue and a minimal example! 😄

For future reference (to anyone out there): the leak was caused by storing a tensor directly on the ctx object in a torch.autograd.Function. Using ctx.save_for_backward() instead properly frees the memory once it is no longer needed.
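The difference can be illustrated with a minimal torch.autograd.Function (a toy doubling op for illustration only, not torchsort's actual kernel):

```python
import torch


class LeakyDouble(torch.autograd.Function):
    # Anti-pattern: stashing the tensor directly on ctx bypasses
    # autograd's lifetime management, so the saved tensor (and any
    # graph it references) can outlive the backward pass.
    @staticmethod
    def forward(ctx, x):
        y = x * 2
        ctx.y = y  # leaks memory on repeated calls
        return y

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2


class FixedDouble(torch.autograd.Function):
    # Correct pattern: save_for_backward lets autograd free the
    # tensor as soon as the backward pass no longer needs it.
    @staticmethod
    def forward(ctx, x):
        y = x * 2
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors  # retrieved, then released by autograd
        return grad_output * 2


x = torch.randn(4, requires_grad=True)
out = FixedDouble.apply(x)
out.sum().backward()
print(x.grad)  # gradient of 2*x is 2 everywhere
```

Both versions compute the same gradients; the difference only shows up as retained GPU memory when the op is called in a loop, which matches the repro above.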

nmichlo commented 3 years ago

I have tested it locally and it is now working on my system. 🎉

Thank you for your prompt response, fixes and explanation!