rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

Memory leak if there is a circular reference #152

Closed mfkasim1 closed 3 years ago

mfkasim1 commented 3 years ago

I have found a memory leak when there is a circular reference (i.e. using odeint_adjoint inside a torch.nn.Module, where the ODE adjoint takes the module itself as an input and its output is stored back on the module). Here's a minimal example:

import torch
from torchdiffeq import odeint_adjoint as odeint

class SimpleFunction(torch.nn.Module):
    def __init__(self, a):
        super().__init__()
        self.a = torch.nn.Parameter(a)
        x0 = torch.ones_like(a)
        t = torch.linspace(0, 1, 1000, dtype=a.dtype, device=a.device)
        xt = odeint(self, x0, t)  # the module itself is used as the ODE function
        self.xt = xt  # NO_MEMLEAK_IF: this line is removed (it closes the module -> xt -> module reference cycle)

    def forward(self, t, x):
        return -self.a * x

def test_fcn():
    a = torch.ones((300000,), dtype=torch.double, device=torch.device("cuda"))
    model = SimpleFunction(a)

for i in range(5):
    test_fcn()
    torch.cuda.empty_cache()
    print("memory allocated:", float(torch.cuda.memory_allocated() / (1024 ** 2)), "MiB")

Running this produces:

memory allocated: 2291.115234375 MiB
memory allocated: 4582.23046875 MiB
memory allocated: 6873.345703125 MiB
memory allocated: 9164.4609375 MiB
Traceback (most recent call last):
  File "tdiffeqmemtest.py", line 21, in <module>
    test_fcn()
  File "tdiffeqmemtest.py", line 18, in test_fcn
    model = SimpleFunction(a)
  File "tdiffeqmemtest.py", line 10, in __init__
    xt = odeint(self, x0, t)
  File "/home/muhammadkasim/anaconda2/envs/torchdev/lib/python3.8/site-packages/torchdiffeq/_impl/adjoint.py", line 198, in odeint_adjoint
    ans = OdeintAdjointMethod.apply(shapes, func, y0, t, rtol, atol, method, options, event_fn, adjoint_rtol, adjoint_atol,
  File "/home/muhammadkasim/anaconda2/envs/torchdev/lib/python3.8/site-packages/torchdiffeq/_impl/adjoint.py", line 25, in forward
    ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options, event_fn=event_fn)
  File "/home/muhammadkasim/anaconda2/envs/torchdev/lib/python3.8/site-packages/torchdiffeq/_impl/odeint.py", line 77, in odeint
    solution = solver.integrate(t)
  File "/home/muhammadkasim/anaconda2/envs/torchdev/lib/python3.8/site-packages/torchdiffeq/_impl/solvers.py", line 30, in integrate
    solution[i] = self._advance(t[i])
  File "/home/muhammadkasim/anaconda2/envs/torchdev/lib/python3.8/site-packages/torchdiffeq/_impl/rk_common.py", line 194, in _advance
    self.rk_state = self._adaptive_step(self.rk_state)
  File "/home/muhammadkasim/anaconda2/envs/torchdev/lib/python3.8/site-packages/torchdiffeq/_impl/rk_common.py", line 255, in _adaptive_step
    y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, t1, tableau=self.tableau)
  File "/home/muhammadkasim/anaconda2/envs/torchdev/lib/python3.8/site-packages/torchdiffeq/_impl/rk_common.py", line 65, in _runge_kutta_step
    k = torch.empty(*f0.shape, len(tableau.alpha) + 1, dtype=y0.dtype, device=y0.device)
RuntimeError: CUDA out of memory. Tried to allocate 18.00 MiB (GPU 0; 11.91 GiB total capacity; 11.22 GiB already allocated; 20.25 MiB free; 11.24 GiB reserved in total by PyTorch)

Here's my spec:

I've raised an issue on PyTorch's GitHub because I think this might be related: https://github.com/pytorch/pytorch/issues/52140
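
As a stopgap, the cycle can be avoided entirely by not storing the solver output on the module. A minimal sketch (build_and_solve is a hypothetical helper, not part of torchdiffeq):

import torch
from torchdiffeq import odeint_adjoint as odeint

class SimpleFunctionNoCycle(torch.nn.Module):
    def __init__(self, a):
        super().__init__()
        self.a = torch.nn.Parameter(a)

    def forward(self, t, x):
        return -self.a * x

def build_and_solve(a):
    # The trajectory is returned instead of being stored on the module,
    # so the module -> xt -> module reference cycle never forms.
    model = SimpleFunctionNoCycle(a)
    x0 = torch.ones_like(a)
    t = torch.linspace(0, 1, 1000, dtype=a.dtype, device=a.device)
    xt = odeint(model, x0, t)
    return model, xt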

rtqichen commented 3 years ago

Hmm yeah, that should ideally be garbage collected, but it does seem like it's ultimately a problem with the PyTorch backend. Thanks for bringing this up. I'll wait for the PyTorch issue to be resolved to see whether any action needs to be taken here.
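
In the meantime, a possible (unverified) workaround for the reproduction above is to force Python's cycle collector to run between calls, which may reclaim the module/output cycle if it is ordinary reference-cycle garbage:

import gc
import torch

for i in range(5):
    test_fcn()
    gc.collect()               # collect the module <-> xt reference cycle left by the previous call
    torch.cuda.empty_cache()   # return the now-unreferenced CUDA blocks to the driver
    print("memory allocated:", torch.cuda.memory_allocated() / (1024 ** 2), "MiB")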