rusty1s / pytorch_cluster

PyTorch Extension Library of Optimized Graph Cluster Algorithms
MIT License

`torch_cluster.radius` is not compatible with the CUDA graph #123

Open raimis opened 2 years ago

raimis commented 2 years ago

torch_cluster.radius is not compatible with the CUDA graph (https://pytorch.org/docs/stable/notes/cuda.html#cuda-graphs)

import torch as pt
from torch_cluster import radius

device = pt.device('cuda')
x = pt.tensor([0.0], device=device)
y = pt.tensor([1.0], device=device)

# Capturing the call into a CUDA graph triggers the error below
graph = pt.cuda.CUDAGraph()
with pt.cuda.graph(graph):
    radius(x, y, r=2)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [26], in <module>
      8 with pt.cuda.graph(graph):
----> 9     radius(x, y, r=2)

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/shared2/raimis/opt/miniconda/envs/torchmd-net/lib/python3.9/site-packages/torch_cluster/radius.py", line 72, in radius
        torch.cumsum(deg, 0, out=ptr_y[1:])

    return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                                          max_num_neighbors, num_workers)
RuntimeError: CUDA error: operation not permitted when stream is capturing
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

rusty1s commented 2 years ago

Thanks for reporting. I need to look into this. Any help or reference is highly appreciated as I'm pretty unfamiliar with CUDA graphs.

RaulPPelaez commented 1 year ago

To "compile" a piece of code into a CUDA graph, one must first run it under what is called "capture mode": basically a dry run in which CUDA just records the kernel launches and their arguments. Once the graph is captured, it can be replayed more efficiently. A replay always works on the same memory addresses, so the idea is that one overwrites the contents of the input tensors before each replay.
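As a minimal sketch of the capture/replay cycle described above (not taken from torch_cluster; the function name `capture_and_replay` and the toy computation `x * 2 + 1` are made up for illustration), something like this should work with the `torch.cuda.CUDAGraph` API. It returns `None` on machines without a CUDA device.

```python
import torch


def capture_and_replay():
    """Capture a toy computation once, then replay it with new input contents."""
    if not torch.cuda.is_available():
        return None  # CUDA graphs need a GPU

    # Static input buffer: replays always read from this exact address.
    static_in = torch.zeros(4, device="cuda")

    # Warm up on a side stream before capturing, as the PyTorch notes recommend.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        _ = static_in * 2 + 1
    torch.cuda.current_stream().wait_stream(side)

    # Capture: kernels are recorded, not executed.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = static_in * 2 + 1

    results = []
    for value in (1.0, 5.0):
        static_in.fill_(value)  # overwrite the input in place
        graph.replay()          # rerun the recorded kernels
        results.append(static_out.tolist())
    return results


out = capture_and_replay()
```

Note that the input is refreshed with an in-place write rather than by rebinding `static_in` to a new tensor, since a new tensor would live at a different address than the one the graph recorded.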

This piece of code must comply with a series of restrictions in order to be CUDA-graph compatible, broadly speaking:

  1. Static shapes (and more restrictively, kernels must work on the same addresses every time)
  2. Static control flow
  3. No operations that result in CPU-GPU sync

More information is available in the link provided by @raimis.

In the particular case of the radius kernel, several things prevent the function from being captured. For instance, this suffers from a CPU-GPU sync: https://github.com/rusty1s/pytorch_cluster/blob/82e9df944118d7916265ade4fa4f5e1062b1bf48/torch_cluster/radius.py#L55-L61

You could rewrite it as:

# Keep batch_size as a tensor so its value never has to be read back on the host
batch_size = torch.tensor(1)
if batch_x is not None:
    assert x.size(0) == batch_x.numel()
    batch_size = batch_x.max().to(dtype=torch.int) + 1
if batch_y is not None:
    assert y.size(0) == batch_y.numel()
    batch_size = torch.max(batch_size, batch_y.max().to(dtype=torch.int) + 1)

Additionally, this line does not have static control flow, since batch_size depends on the contents of the input tensors: https://github.com/rusty1s/pytorch_cluster/blob/82e9df944118d7916265ade4fa4f5e1062b1bf48/torch_cluster/radius.py#L65
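To illustrate the control-flow point in isolation, here is a small sketch that is unrelated to the actual logic in radius.py: the function names and the toy "double the input" computation are hypothetical. Using a tensor in a Python `if` converts it to a `bool` on the host, which on a CUDA tensor means waiting for the GPU; `torch.where` keeps the decision on the device and runs the same kernels regardless of the value.

```python
import torch


def scale_branching(x, n):
    # `if n > 1` calls bool() on a tensor: the host must read the value,
    # which forces a CPU-GPU sync when n lives on the GPU.
    if n > 1:
        return x * 2
    return x


def scale_branch_free(x, n):
    # Both candidates are computed; the selection happens on the device,
    # so the sequence of launched kernels does not depend on the data.
    return torch.where(n > 1, x * 2, x)


x = torch.tensor([1.0, 2.0])
small, large = torch.tensor(1), torch.tensor(3)
same_small = torch.equal(scale_branching(x, small), scale_branch_free(x, small))
same_large = torch.equal(scale_branching(x, large), scale_branch_free(x, large))
```

The branch-free form does more arithmetic, and it only helps when both branches produce tensors of the same shape, which may not hold for the code path discussed here.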

@raimis's example does not exercise this code path, since batch_x/batch_y are not passed. For his example, I believe this line is the culprit: https://github.com/rusty1s/pytorch_cluster/blob/82e9df944118d7916265ade4fa4f5e1062b1bf48/csrc/cuda/radius_cuda.cu#L93

masked_select requires synchronization, because the size of its output depends on the contents of the mask.
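A rough sketch of why `masked_select` clashes with the static-shape requirement, next to one possible fixed-shape alternative. The helper names and the `-1` sentinel are assumptions for illustration; this is not how torch_cluster handles it, and a padded output like this would change what the caller receives.

```python
import torch


def select_dynamic(values, mask):
    # Output length equals the number of True entries in the mask,
    # so it is only known after inspecting the data.
    return torch.masked_select(values, mask)


def select_static(values, mask, fill=-1):
    # Output always has the same shape as `values`;
    # rejected entries are replaced by a sentinel instead of being dropped.
    return torch.where(mask, values, torch.full_like(values, fill))


values = torch.tensor([10, 20, 30, 40])
mask = torch.tensor([True, False, True, False])

dynamic = select_dynamic(values, mask)  # shape depends on the mask contents
static = select_static(values, mask)    # shape is always (4,)
```

The trade-off is that the caller has to filter out the sentinel entries later, outside the captured region.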

rusty1s commented 1 year ago

Thanks for this insightful comment. Would you be interested in fixing the aforementioned problems? Would that be straightforward to integrate?

RaulPPelaez commented 1 year ago

We use this functionality in several places, so I am eager to help. I am really new to torch though, so my torch-fu is not very good -.- I am trying to make it CUDA-graph compatible without modifying the current interface, so if I get somewhere I will open a PR. I wanted to share some insight in case you saw this as an easy modification.