getkeops / keops

KErnel OPerationS, on CPUs and GPUs, with autodiff and without memory overflows
https://www.kernel-operations.io
MIT License
1.03k stars 65 forks source link

Memory leak #284

Open sitzikbs opened 1 year ago

sitzikbs commented 1 year ago

Hi,

Attached below is a minimum working example of using KeOps to match two sets of point clouds using the sinkhorn algorithm. Over time, the GPU memory usage increases until it eventually crashes.

tested on Ubuntu 20.04 with an A5000 GPU.

Is there a way to avoid this leak? What is causing it?


# minimum working example to debug the memory leak
from typing import Union
from scipy.spatial import KDTree
import pykeops.torch as keops
import torch
import tqdm
import numpy as np
import nvidia_smi

def sinkhorn(x: torch.Tensor, y: torch.Tensor, p: int = 2,
             w_x: Union[torch.Tensor, None] = None,
             w_y: Union[torch.Tensor, None] = None,
             eps: float = 1e-3,
             max_iters: int = 100, stop_thresh: float = 1e-5,
             verbose=False):
    """
    Compute the Entropy-Regularized p-Wasserstein Distance between two d-dimensional point clouds
    using the Sinkhorn scaling algorithm. This code will use the GPU if you pass in GPU tensors.
    Note that this algorithm can be backpropped through
    (though this may be slow if using many iterations).
    :param x: A [n, d] tensor representing a d-dimensional point cloud with n points (one per row)
    :param y: A [m, d] tensor representing a d-dimensional point cloud with m points (one per row)
    :param p: Which norm to use. Must be an integer greater than 0.
    :param w_x: A [n,] shaped tensor of optional weights for the points x (None for uniform weights). Note that these must sum to the same value as w_y. Default is None.
    :param w_y: A [m,] shaped tensor of optional weights for the points y (None for uniform weights). Note that these must sum to the same value as w_x. Default is None.
    :param eps: The reciprocal of the sinkhorn entropy regularization parameter.
    :param max_iters: The maximum number of Sinkhorn iterations to perform. Must be an integer greater than 0.
    :param stop_thresh: Stop if the maximum change in the parameters is below this amount. Must be a float.
    :param verbose: Print iterations
    :return: a triple (d, corrs_x_to_y, corr_y_to_x) where:
      * d is the approximate p-wasserstein distance between point clouds x and y
      * corrs_x_to_y is a [n,]-shaped tensor where corrs_x_to_y[i] is the index of the approximate correspondence in point cloud y of point x[i] (i.e. x[i] and y[corrs_x_to_y[i]] are a corresponding pair)
      * corrs_y_to_x is a [m,]-shaped tensor where corrs_y_to_x[j] is the index of the approximate correspondence in point cloud x of point y[j] (i.e. y[j] and x[corrs_y_to_x[j]] are a corresponding pair)
    :raises TypeError: if p or max_iters is not an int, or stop_thresh is not a float
    :raises ValueError: if p, eps or max_iters is not positive, if the shapes of x, y, w_x, w_y are
      inconsistent, if only one of w_x / w_y is given, or if the weights do not sum to the same value
    """

    if not isinstance(p, int):
        raise TypeError(f"p must be an integer greater than 0, got {p}")
    if p <= 0:
        raise ValueError(f"p must be an integer greater than 0, got {p}")

    if eps <= 0:
        raise ValueError("Entropy regularization term eps must be > 0")

    if not isinstance(max_iters, int):
        raise TypeError(f"max_iters must be an integer > 0, got {max_iters}")
    if max_iters <= 0:
        raise ValueError(f"max_iters must be an integer > 0, got {max_iters}")

    if not isinstance(stop_thresh, float):
        raise TypeError(f"stop_thresh must be a float, got {stop_thresh}")

    if len(x.shape) != 2:
        raise ValueError(f"x must be an [n, d] tensor but got shape {x.shape}")
    if len(y.shape) != 2:
        raise ValueError(f"y must be an [m, d] tensor but got shape {y.shape}")
    if x.shape[1] != y.shape[1]:
        raise ValueError(f"x and y must match in the last dimension (i.e. x.shape=[n, d], "
                         f"y.shape[m, d]) but got x.shape = {x.shape}, y.shape={y.shape}")

    if w_x is not None:
        if w_y is None:
            raise ValueError("If w_x is not None, w_y must also be not None")
        if len(w_x.shape) > 1:
            w_x = w_x.squeeze()
        if len(w_x.shape) != 1:
            raise ValueError(f"w_x must have shape [n,] or [n, 1] "
                             f"where x.shape = [n, d], but got w_x.shape = {w_x.shape}")
        if w_x.shape[0] != x.shape[0]:
            raise ValueError(f"w_x must match the shape of x in dimension 0 but got "
                             f"x.shape = {x.shape} and w_x.shape = {w_x.shape}")
    if w_y is not None:
        if w_x is None:
            raise ValueError("If w_y is not None, w_x must also be not None")
        if len(w_y.shape) > 1:
            w_y = w_y.squeeze()
        if len(w_y.shape) != 1:
            raise ValueError(f"w_y must have shape [m,] or [m, 1] "
                             f"where y.shape = [m, d], but got w_y.shape = {w_y.shape}")
        if w_y.shape[0] != y.shape[0]:
            raise ValueError(f"w_y must match the shape of y in dimension 0 but got "
                             f"y.shape = {y.shape} and w_y.shape = {w_y.shape}")

    # Distance matrix [n, m] (symbolic: never materialised as a dense tensor)
    x_i = keops.Vi(x)  # [n, 1, d]
    y_j = keops.Vj(y)  # [1, m, d]
    diff_ij = x_i - y_j
    if p % 2 != 0:
        # For odd p the signed difference raised to p can be negative, which
        # would turn the 1/p root below into NaN; take |x_i - y_j| first.
        # Even p is left untouched so the generated formula is unchanged.
        diff_ij = diff_ij.abs()
    M_ij = (diff_ij ** p).sum(dim=2)  # [n, m]
    if p != 1:
        M_ij = M_ij ** (1.0 / p)

    # Weights [n,] and [m,]
    if w_x is None and w_y is None:
        # Uniform weights; each vector sums to 1, so the two sums agree
        # even when n != m.
        w_x = torch.ones(x.shape[0]).to(x) / x.shape[0]
        w_y = torch.ones(y.shape[0]).to(x) / y.shape[0]

    sum_w_x = w_x.sum().item()
    sum_w_y = w_y.sum().item()
    if abs(sum_w_x - sum_w_y) > 1e-5:
        raise ValueError(f"Weights w_x and w_y do not sum to the same value, "
                         f"got w_x.sum() = {sum_w_x} and w_y.sum() = {sum_w_y} "
                         f"(absolute difference = {abs(sum_w_x - sum_w_y)})")

    log_a = torch.log(w_x)  # [n]
    log_b = torch.log(w_y)  # [m]

    # Initialize the iteration with the change of variable
    u = torch.zeros_like(w_x)
    v = eps * torch.log(w_y)

    u_i = keops.Vi(u.unsqueeze(-1))
    v_j = keops.Vj(v.unsqueeze(-1))

    if verbose:
        pbar = tqdm.trange(max_iters)
    else:
        pbar = range(max_iters)

    for _ in pbar:
        u_prev = u
        v_prev = v

        # Update u from the current v, then v from the freshly updated u.
        summand_u = (-M_ij + v_j) / eps
        u = eps * (log_a - summand_u.logsumexp(dim=1).squeeze())
        u_i = keops.Vi(u.unsqueeze(-1))

        summand_v = (-M_ij + u_i) / eps
        v = eps * (log_b - summand_v.logsumexp(dim=0).squeeze())
        v_j = keops.Vj(v.unsqueeze(-1))

        max_err_u = torch.max(torch.abs(u_prev-u))
        max_err_v = torch.max(torch.abs(v_prev-v))
        if verbose:
            pbar.set_postfix({"Current Max Error": max(max_err_u, max_err_v).item()})
        if max_err_u < stop_thresh and max_err_v < stop_thresh:
            break

    # Transport plan from the final potentials (still symbolic).
    P_ij = ((-M_ij + u_i + v_j) / eps).exp()

    approx_corr_1 = P_ij.argmax(dim=1).squeeze(-1)
    approx_corr_2 = P_ij.argmax(dim=0).squeeze(-1)

    # Reduce over dim=1 when x has more points than y, over dim=0 otherwise,
    # then sum the partial results.
    if u.shape[0] > v.shape[0]:
        distance = (P_ij * M_ij).sum(dim=1).sum()
    else:
        distance = (P_ij * M_ij).sum(dim=0).sum()
    return distance, approx_corr_1, approx_corr_2

class NoiseGenerator(torch.utils.data.Dataset):
    """Dataset of pre-generated random 3D point clouds.

    Every sample is an ``(n_points, 3)`` float32 NumPy array of Gaussian
    noise scaled by ``sigma`` and clipped to ``[-1, 1]``. All samples are
    drawn once, in the constructor, and kept in ``self.points``.
    """

    def __init__(self, n_points, radius=0.5, n_samples=1, sigma=0.3):
        # ``radius`` is only stored; nothing in this class reads it back.
        self.radius = radius
        self.n_points = n_points
        self.points = [self.get_noisy_points(sigma) for _ in range(n_samples)]

    def get_noisy_points(self, sigma):
        """Draw one clipped Gaussian cloud of shape ``(self.n_points, 3)``."""
        # TODO add local distortion (use kdtree and fixed motion vec)
        scaled_noise = np.random.randn(self.n_points, 3) * sigma
        clipped = np.clip(scaled_noise, -1, 1)
        return clipped.astype(np.float32)

    def __len__(self):
        """Number of pre-generated samples."""
        return len(self.points)

    def __getitem__(self, idx):
        """Return sample ``idx`` wrapped in a dict under the key 'points'."""
        return {'points': self.points[idx]}

def local_distort(points, r=0.1, ratio=0.15, sigma=0.05):
    """Return a copy of ``points`` with one local neighbourhood shifted per cloud.

    For every cloud in the batch, one point is picked at random, its
    ``int(ratio * n)`` nearest neighbours (found with a KD-tree) are
    selected, and all of them are translated by the same random vector drawn
    uniformly from ``[0, sigma)`` per coordinate.

    The input tensor is left untouched: the work is done on a private NumPy
    copy. (``Tensor.numpy()`` shares storage with a CPU tensor, so editing
    that array in place would also distort the caller's ``points``.)

    :param points: A [b, n, 3] tensor holding b point clouds of n points each.
        One distinct seed index is drawn per cloud from a permutation of
        ``range(n)``, so b must not exceed n.
    :param r: Unused; kept for interface compatibility.
    :param ratio: Fraction of the n points that get displaced in each cloud.
    :param sigma: Upper bound of the uniform range of each translation coordinate.
    :return: A new CPU tensor with the same shape as ``points``.
    """
    b, n, _ = points.size()
    n_ratio = int(ratio * n)

    # Private copy: never write through the view returned by ``.numpy()``.
    points = points.cpu().numpy().copy()
    if n_ratio < 1:
        # Nothing to displace (and no neighbours to query for).
        return torch.tensor(points)

    subset = torch.randperm(n)[:b].tolist()  # one seed-point index per cloud
    translation_vec = np.random.rand(b, 3) * sigma

    for i, pts in enumerate(points):
        tree = KDTree(pts)
        _, nn_idx = tree.query(pts[subset[i]], k=n_ratio)  # distort knn
        points[i, nn_idx, :] += translation_vec[i]

    return torch.tensor(points)

import numba
if __name__ == "__main__":
    # Leak-reproduction driver: repeatedly run `sinkhorn` on a random cloud
    # and its locally distorted copy, printing the NVML memory counters of
    # GPU 0 after every batch.
    nvidia_smi.nvmlInit()
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)

    batch_size = 1
    n_points = 16384
    n_epochs = 10000
    # model = SinkhornCorr(max_iters=1000)
    dataset = NoiseGenerator(n_points, n_samples=10000)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )

    memory_report = "Memory : ({:.2f}% free): {}(total), {} (free), {} (used)"
    for epoch in range(n_epochs):
        for batch_ind, data in enumerate(dataloader):
            print("Epoch {}, Batch {}".format(epoch, batch_ind))
            points1 = data['points']
            points2 = local_distort(points1)
            # The GPU copies are created inline so that no extra reference
            # to them outlives the call.
            with torch.no_grad():
                output = sinkhorn(points1.squeeze().cuda(), points2.squeeze().cuda())

            info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
            percent_free = 100 * info.free / info.total
            print(memory_report.format(percent_free, info.total, info.free, info.used))
joanglaunes commented 1 year ago

Hello @sitzikbs, I just tried your script, and for the first few minutes of running it there appears to be no problem: the memory usage remains constant. However, it was still running epoch 0 after two or three minutes during my try, and you have set n_epochs to 10000... Could you tell us more precisely how long the script must run before one can notice the memory leak, and whether there is a simpler setting where the leak appears more quickly?

sitzikbs commented 1 year ago

Hi @joanglaunes, for me, you can see a ~2 MB increase in GPU memory every 500 iterations or so (not epochs). Once the GPU memory fills up, the script crashes. This happened to me on multiple machines. The number of epochs is set so large to show that at some point the memory fills up and the script crashes. In my "non-minimal-working-example-code" there is also a large model being trained, so it crashes much earlier (since there is less GPU memory available for it to use).

jeanfeydy commented 1 year ago

Hi @sitzikbs , @joanglaunes ,

I can confirm that at least on Google Colab, there seems to be a memory leak indeed.

However, the script is fairly large at the moment and I currently don't have time to identify the line that is causing the issue (and thus to understand if the problem is actually related to KeOps): @sitzikbs, if you could reduce it to a minimum "non-working" example, this would be very helpful.

If you are interested in optimal transport solvers, please also note that GeomLoss provides a KeOps-based implementation of the Sinkhorn loop with several critical improvements that result in x100 speed-ups and improved numerical stability.

Best regards, Jean

sitzikbs commented 1 year ago

@jeanfeydy Thanks for your reply. I will make sure to check GeomLoss for compatibility.

To boil it all down, the problem is in the sinkhorn function which uses KeOps. So here is a non-working example below:

def sinkhorn(x: torch.Tensor, y: torch.Tensor, p: float = 2,
             w_x: Union[torch.Tensor, None] = None,
             w_y: Union[torch.Tensor, None] = None,
             eps: float = 1e-3,
             max_iters: int = 100, stop_thresh: float = 1e-5,
             verbose=False):
    """
    Stripped-down Sinkhorn loop (no input validation) used to reproduce the leak.

    :param x: A [n, d] tensor, one point per row
    :param y: A [m, d] tensor, one point per row
    :param p: Which norm to use for the ground cost
    :param w_x: Optional [n,] weights for x. Pass both w_x and w_y or neither;
                when both are None, uniform weights summing to 1 are used.
    :param w_y: Optional [m,] weights for y; must sum to the same value as w_x.
    :param eps: Scale applied to the potentials (see the update formulas below)
    :param max_iters: The maximum number of Sinkhorn iterations to perform
    :param stop_thresh: Stop once the largest change of both u and v is below this value
    :param verbose: Show a tqdm progress bar with the current maximum error
    :return: a triple (distance, approx_corr_1, approx_corr_2): the transport cost
             and the argmax of the plan along dim=1 and dim=0 respectively
    :raises ValueError: if w_x and w_y do not sum to the same value
    """

    # Distance matrix [n, m] (symbolic: never materialised as a dense tensor)
    x_i = keops.Vi(x)  # [n, 1, d]
    y_j = keops.Vj(y)  # [1, m, d]
    diff_ij = x_i - y_j
    if p % 2 != 0:
        # Unless p is an even integer, the signed difference raised to p can
        # be negative or undefined, turning the 1/p root below into NaN;
        # take |x_i - y_j| first. Even p keeps the original formula.
        diff_ij = diff_ij.abs()
    M_ij = (diff_ij ** p).sum(dim=2)  # [n, m]
    if p != 1:
        M_ij = M_ij ** (1.0 / p)

    # Weights [n,] and [m,]
    if w_x is None and w_y is None:
        # Uniform weights; each vector sums to 1, so the two sums agree
        # even when n != m.
        w_x = torch.ones(x.shape[0]).to(x) / x.shape[0]
        w_y = torch.ones(y.shape[0]).to(x) / y.shape[0]

    sum_w_x = w_x.sum().item()
    sum_w_y = w_y.sum().item()
    if abs(sum_w_x - sum_w_y) > 1e-5:
        raise ValueError(f"Weights w_x and w_y do not sum to the same value, "
                         f"got w_x.sum() = {sum_w_x} and w_y.sum() = {sum_w_y} "
                         f"(absolute difference = {abs(sum_w_x - sum_w_y)})")

    log_a = torch.log(w_x)  # [n]
    log_b = torch.log(w_y)  # [m]

    # Initialize the iteration with the change of variable
    u = torch.zeros_like(w_x)
    v = eps * torch.log(w_y)

    # Also needed after the loop when max_iters <= 0 and the loop never runs.
    u_i = keops.Vi(u.unsqueeze(-1))
    v_j = keops.Vj(v.unsqueeze(-1))

    if verbose:
        pbar = tqdm.trange(max_iters)
    else:
        pbar = range(max_iters)

    for _ in pbar:
        u_prev = u
        v_prev = v

        # Update u from the current v, then v from the freshly updated u.
        summand_u = (-M_ij + v_j) / eps
        u = eps * (log_a - summand_u.logsumexp(dim=1).squeeze())
        u_i = keops.Vi(u.unsqueeze(-1))

        summand_v = (-M_ij + u_i) / eps
        v = eps * (log_b - summand_v.logsumexp(dim=0).squeeze())
        v_j = keops.Vj(v.unsqueeze(-1))

        max_err_u = torch.max(torch.abs(u_prev-u))
        max_err_v = torch.max(torch.abs(v_prev-v))
        if verbose:
            pbar.set_postfix({"Current Max Error": max(max_err_u, max_err_v).item()})
        if max_err_u < stop_thresh and max_err_v < stop_thresh:
            break

    # Transport plan from the final potentials (still symbolic).
    P_ij = ((-M_ij + u_i + v_j) / eps).exp()

    approx_corr_1 = P_ij.argmax(dim=1).squeeze(-1)
    approx_corr_2 = P_ij.argmax(dim=0).squeeze(-1)

    if u.shape[0] > v.shape[0]:
        distance = (P_ij * M_ij).sum(dim=1).sum()
    else:
        distance = (P_ij * M_ij).sum(dim=0).sum()
    return distance, approx_corr_1, approx_corr_2
JCBrouwer commented 1 year ago

Hi there, I'll chime in that I'm also running into a similar memory leak, maybe the added context can help narrow it down (or I can move it to another issue if it's unrelated).

I'm training a SaShiMi model which relies on the S4 Module which has a key subroutine implemented with PyKeOps.

After about 12 hours of training, the script crashes with the following out-of-memory error. Over the course of those 12 hours ~4GB of VRAM is getting eaten up (starting at 20GB, crashing when over 24GB). The s4.py in the trace below should be almost identical to the one linked above (maybe a couple lines offset due to slightly different imports / docstring).

[KeOps] error: cuMemAlloc((CUdeviceptr * ) & offsets_d, sizeof(int) * nblocks * sizevars) failed with error CUDA_ERROR_OUT_OF_MEMORY
Traceback (most recent call last):
  File "/env/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/env/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/src/trainer.py", line 432, in <module>
  File "/src/trainer.py", line 344, in main
  File "/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/src/networks/timeseries.py", line 144, in forward
  File "/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/src/networks/external/sashimi.py", line 337, in forward
  File "/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/src/networks/external/sashimi.py", line 177, in forward
  File "/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/src/networks/external/s4.py", line 1558, in forward
  File "/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/src/networks/external/s4.py", line 1375, in forward
  File "/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/src/networks/external/s4.py", line 829, in forward
  File "/src/networks/external/s4.py", line 83, in cauchy_conj
  File "/env/lib/python3.10/site-packages/pykeops/torch/generic/generic_red.py", line 627, in __call__
    out = GenredAutograd.apply(
  File "/env/lib/python3.10/site-packages/pykeops/torch/generic/generic_red.py", line 117, in forward
    result = myconv.genred_pytorch(
  File "/env/lib/python3.10/site-packages/pykeops/common/keops_io/LoadKeOps.py", line 232, in genred
    self.call_keops(nx, ny)
  File "/env/lib/python3.10/site-packages/pykeops/common/keops_io/LoadKeOps_nvrtc.py", line 42, in call_keops
    self.launch_keops(
RuntimeError: [KeOps] Cuda error.

Let me know if this is useful here or should be moved to another issue. I don't have a simple repro atm, but I suspect just training (forward, backward, and step) the Sashimi model on arbitrary inputs will run into this sooner or later.

Version info:

Ubuntu 20.04.5 LTS

keopscore==2.1.1
pykeops==2.1.1
torch==1.13.0
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
jeanfeydy commented 1 year ago

Hi @sitzikbs , @JCBrouwer ,

Thanks for your reports: I will investigate a bit to narrow it down and keep you updated when I find something. Hopefully, this is just a typo on our side (e.g. I forgot to free a small buffer somewhere in the main CUDA kernel) and not some PyTorch dark magic!

Best regards, Jean

sitzikbs commented 1 year ago

@jeanfeydy Thanks! I have experimented further and can confirm that the memory leak also happens with much simpler code for finding k nearest neighbors as well:

        # K-nearest-neighbour search written as a symbolic KeOps reduction.
        # Symbolic variables: argument 0 is indexed by i and argument 1 by j,
        # each with the size of the last axis of its input.
        X_i = Vi(0, x1.shape[-1])
        X_j = Vj(1, x2.shape[-1])
        # Squared differences summed over the last (coordinate) axis.
        D_ij = ((X_i - X_j) ** 2).sum(-1)
        # Reduction returning the k smallest values together with their indices.
        KNN_fun = D_ij.Kmin_argKmin(k, dim=1)
        # Evaluate the reduction on contiguous copies of the actual tensors.
        distances, idxs = KNN_fun(x1.contiguous(), x2.contiguous())

here x1, x2 are 3D arrays of size [b, N, 3] where b is batch size and N is the number of points.

sitzikbs commented 1 year ago

Hi @jeanfeydy, Any updates on this?

I benchmark the k nearest neighbors search with keops for 3D point clouds and it is by far the fastest (compared to faiss cpu, gpu, and pytorch geometric). This memory leak makes my runs crash after a while but I use it anyway because the speedup is worth it.

pearl-rabbit commented 1 year ago

My code has also experienced a memory leak. When I inspect it with "memory_profiler", the program's memory usage increases when executing the following statement. However, this is the first time I have used memory_profiler, so I'm not sure whether this is the cause.

# Squared differences between X_i and X_j, summed over the last axis.
# NOTE(review): X_i / X_j are presumably KeOps symbolic variables or LazyTensors — confirm.
D_ij = ((X_i - X_j) ** 2).sum(-1)

bcharlier commented 1 year ago

Hi everyone, the examples above seem to use ranges (in a hidden manner, through batches). I just read this with @joanglaunes at https://pytorch.org/docs/stable/generated/torch.autograd.function.FunctionCtx.save_for_backward.html

All tensors intended to be used in the backward pass should be saved with 
save_for_backward (as opposed to directly on ctx) to prevent incorrect gradients and memory leaks,
 and enable the application of saved tensor hooks. See [torch.autograd.graph.saved_tensors_hooks](https://pytorch.org/docs/stable/autograd.html#torch.autograd.graph.saved_tensors_hooks).

do you think putting https://github.com/getkeops/keops/blob/34a8bef4293e3017075aa4286bec94c6e59cb6fd/pykeops/pykeops/torch/generic/generic_red.py#L97 in https://github.com/getkeops/keops/blob/34a8bef4293e3017075aa4286bec94c6e59cb6fd/pykeops/pykeops/torch/generic/generic_red.py#L122 can help ?

pearl-rabbit commented 1 year ago

I changed the code to the following form, and the memory leak still seems to exist: `ctx.save_for_backward(*args, result, *ranges)` in the forward pass, and `args = ctx.saved_tensors[:-7]`, `result = ctx.saved_tensors[-7].detach()`, `ranges = ctx.saved_tensors[-6:]` in the backward pass.

I also found that when I stopped model training and restarted it, some memory was freed, but the usage was still higher than at the start of the initial training. After I delete 'keops 2.1.1' in the '.cache' folder and then load the model, the extra memory used is freed up.

bcharlier commented 1 year ago

Here is a minimal sample code that reproduces the leaks.

from pykeops.torch import Vi, Vj
import torch
import nvidia_smi

# Device string used for the test tensors and for the torch.cuda memory query.
device = "cuda:0"

def kernel(x1: torch.Tensor, x2: torch.Tensor):
    """Sum of squared differences between x1 and x2 as a symbolic KeOps reduction.

    Counterpart of the dense ``kernel_torch``: the pairwise squared distances
    are described symbolically with ``Vi`` / ``Vj``, reduced with ``sum(1)``,
    and only then evaluated on the actual tensors.
    """
    dim_i = x1.shape[-1]
    dim_j = x2.shape[-1]
    # Placeholders for argument 0 (indexed by i) and argument 1 (indexed by j).
    diff_ij = Vi(0, dim_i) - Vj(1, dim_j)
    sq_dist_ij = (diff_ij ** 2).sum(-1)
    reduction = sq_dist_ij.sum(1)
    return reduction(x1, x2)

def kernel_torch(x1: torch.Tensor, x2: torch.Tensor):
    """Dense PyTorch reference: for each row of x1, the sum over all rows of
    x2 of the squared Euclidean distance between the two rows.

    Builds the full [n, m, d] difference tensor, so memory grows with n * m.
    """
    pairwise_diff = x1[:, None, :] - x2[None, :, :]      # [n, m, d]
    sq_dist = torch.pow(pairwise_diff, 2).sum(dim=-1)    # [n, m]
    return sq_dist.sum(dim=1)                            # [n]

def get_free_mem():
    """Print a one-line memory report for the monitored GPU and return the
    number of free bytes reported by NVML.

    Reads the module-level ``handle`` (NVML device handle) and ``device``.
    """
    info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
    percent_free = 100 * info.free / info.total
    # NOTE(review): torch.cuda.mem_get_info is documented to return
    # (free, total), so index 1 looks like the device total as seen by CUDA
    # rather than an "occupied" amount — confirm the intended label.
    reported_occupied = torch.cuda.mem_get_info(device=device)[1]
    report = "Memory : ({:.2f}% free): {}(total), {} (free), {} (used), {} (occupied)".format(
        percent_free,
        info.total,
        info.free,
        info.used,
        reported_occupied,
    )
    print(report)
    return info.free

if __name__ == "__main__":
    # Leak-reproduction driver: call `kernel` many times on fresh random
    # inputs and print a memory report whenever the free memory reported by
    # NVML changes.
    nvidia_smi.nvmlInit()
    handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
    n_points = 1638 * 7
    n_iter = 100000

    # Impossible value, so the very first iteration always prints.
    current_free = -1

    for iter in range(n_iter):
        points1 = torch.rand((n_points, 1), device=device)
        points2 = torch.rand((n_points, 1), device=device)
        with torch.no_grad():
            output = kernel(points1, points2)
            # equivalent torch code that does not produce leaks

            # output_torch = kernel_torch(points1, points2)
            # print(torch.allclose(output, output_torch))

        info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
        if info.free == current_free:
            continue
        print(f"Iter {iter} : ", end=" ")
        current_free = get_free_mem()

that gives

[KeOps] Generating code for formula Sum_Reduction(Sum((Var(0,1,0)-Var(1,1,1))**2),0) ... OK
Iter 0 :  Memory : (84.29% free): 4294967296(total), 3620405248 (free), 674562048 (used), 4096196608 (occupied)
Iter 4095 :  Memory : (84.25% free): 4294967296(total), 3618308096 (free), 676659200 (used), 4096196608 (occupied)
Iter 8191 :  Memory : (84.20% free): 4294967296(total), 3616210944 (free), 678756352 (used), 4096196608 (occupied)
Iter 12287 :  Memory : (84.15% free): 4294967296(total), 3614113792 (free), 680853504 (used), 4096196608 (occupied)
Iter 16383 :  Memory : (84.10% free): 4294967296(total), 3612016640 (free), 682950656 (used), 4096196608 (occupied)
Iter 20479 :  Memory : (84.05% free): 4294967296(total), 3609919488 (free), 685047808 (used), 4096196608 (occupied)
Iter 24575 :  Memory : (84.00% free): 4294967296(total), 3607822336 (free), 687144960 (used), 4096196608 (occupied)
Iter 28671 :  Memory : (83.95% free): 4294967296(total), 3605725184 (free), 689242112 (used), 4096196608 (occupied)
...

So, if I understand it correctly, we have a leak of (3620405248 − 3618308096) ÷ 4096 = 512 bytes per call ...

NightWinkle commented 1 year ago

Hi @bcharlier,

The following MWE still produces 2MiB leak per iteration on PyKeops 2.1.2 :

import torch

# Global PyTorch configuration for this script: turn gradient tracking off
# and make float64 the default floating-point dtype.
torch.set_grad_enabled(False)
torch.set_default_dtype(torch.float64)

from pykeops.torch import LazyTensor
import pykeops
# Clean the PyKeOps build cache, then run the PyTorch-bindings self-test
# (presumably to start the measurement from freshly compiled formulas — confirm).
pykeops.clean_pykeops()
pykeops.test_torch_bindings()

import subprocess
import os
def get_memory():
    """Return the used memory of each selected GPU as reported by ``nvidia-smi``.

    The GPUs are taken from the ``CUDA_VISIBLE_DEVICES`` environment
    variable, whose comma-separated value is forwarded to ``--id`` as-is
    (so multi-digit indices such as ``"10"`` and lists such as ``"0,1"``
    stay intact). When the variable is unset or empty, ``--id`` is omitted
    and every GPU known to ``nvidia-smi`` is queried.

    :return: a list of ints, one per queried GPU, in the order printed by
        ``nvidia-smi`` (``memory.used`` with ``nounits``; MiB according to
        the nvidia-smi documentation).
    :raises subprocess.CalledProcessError: if ``nvidia-smi`` exits with an error.
    :raises FileNotFoundError: if ``nvidia-smi`` is not on the PATH.
    """
    command = ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]

    raw_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    device_ids = [part.strip() for part in raw_ids.split(",") if part.strip()]
    if device_ids:
        command.append("--id=" + ",".join(device_ids))

    report = subprocess.check_output(command).decode("utf-8")
    # One value per line; ignore blank lines such as the trailing one.
    return [int(line) for line in report.splitlines() if line.strip()]

# Reproduction script: block-sparse logsumexp reduction with `ranges`.
# NOTE: the tensors below are created without a `device` argument; the EDIT
# at the end of this comment attributes the observed growth to the tensors
# not being on the GPU.
N = 100000
M = 100000

cluster = pykeops.torch.cluster

# Two identical 1-D supports, stored as [N, 1] columns.
support_x = torch.linspace(0., 1., N)[:, None]
support_y = torch.linspace(0., 1., N)[:, None]

# Bin each support into grid cells of 1/1000 of its extent, then sort the
# points so that every cluster is contiguous.
voxel_x = (support_x.max() - support_x.min())/1000.
labelsx = cluster.grid_cluster(support_x, voxel_x)
support_x, labx = cluster.sort_clusters(support_x, labelsx)

voxel_y = (support_y.max() - support_y.min())/1000.
labelsy = cluster.grid_cluster(support_y, voxel_y)
support_y, laby = cluster.sort_clusters(support_y, labelsy)

ranges_i = cluster.cluster_ranges(labx)
ranges_j = cluster.cluster_ranges(laby)
centroids_x = cluster.cluster_centroids(support_x, labx)
centroids_y = cluster.cluster_centroids(support_y, laby)

# Keep only the cluster pairs whose centroids are close enough.
centroid_diff = centroids_y[None, :, :] - centroids_x[:, None, :]
dists = centroid_diff.square().sum(dim=-1)
keep = dists < 0.01
ranges = cluster.from_matrix(ranges_i, ranges_j, keep)

mem = get_memory()[0]
for _ in range(100):
    xi = LazyTensor(support_x[:, None, :])
    yj = LazyTensor(support_y[None, :, :])
    plan = (yj - xi).logsumexp(dim=1, ranges=ranges)
    # Report how much the used GPU memory changed during this iteration.
    mem_last, mem = mem, get_memory()[0]
    print(f"Allocated {mem-mem_last} MiB, Total {mem} MiB")

The version without ranges does not produce any leak.

EDIT: as noted in my pull request, this has to do with my tensors not being on GPU, so it is an error on my side.