Open · raimis opened this issue 2 years ago
Thanks for reporting. I need to look into this. Any help or reference is highly appreciated as I'm pretty unfamiliar with CUDA graphs.
In order to "compile" a piece of code into a CUDA graph, one must first run it under what is called "capture mode": essentially a dry run in which CUDA registers the kernel launches and their arguments instead of executing them. Once the graph is captured, it can be replayed much more efficiently. A replay always operates on the same memory addresses, so the idea is that one overwrites the inputs in place before each replay.
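For reference, here is a minimal sketch of capture and replay with the PyTorch API (the tensor names are illustrative):

```python
import torch

static_in = torch.zeros(8, device="cuda")

# Warm-up on a side stream, as the PyTorch docs recommend before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out = static_in * 2
torch.cuda.current_stream().wait_stream(s)

# Capture: kernel launches are recorded, not executed.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = static_in * 2

# Replay: overwrite the static input in place, then relaunch the whole graph.
static_in.copy_(torch.arange(8., device="cuda"))
g.replay()
print(static_out)  # tensor([0., 2., 4., ..., 14.])
```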
This piece of code must comply with a series of restrictions in order to be CUDA-graph compatible. Broadly speaking:

- No CPU-GPU synchronization: no `.item()`, `.cpu()`, `int(tensor)`, or anything else that reads a device value back on the host.
- Static control flow: no branches or loops whose condition depends on the contents of a tensor.
- Static shapes and addresses: every replay works on the same memory, so tensor sizes cannot change between replays.
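As a quick illustration of the first restriction, here is a sketch of what happens when a sync sneaks into capture (the exact error text may vary across PyTorch versions):

```python
import torch

x = torch.ones(4, device="cuda")
torch.cuda.synchronize()

g = torch.cuda.CUDAGraph()
try:
    with torch.cuda.graph(g):
        # int(...) forces a device-to-host copy, i.e. a sync,
        # which is forbidden while the stream is capturing.
        n = int(x.max()) + 1
except RuntimeError as e:
    print("capture failed:", e)
```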
In the particular case of the radius kernel, there are several things preventing the function from being captured. For instance, this code suffers from a CPU-GPU sync: https://github.com/rusty1s/pytorch_cluster/blob/82e9df944118d7916265ade4fa4f5e1062b1bf48/torch_cluster/radius.py#L55-L61. You could rewrite it as:
```python
# Keep batch_size as a 0-dim tensor so that .max() never has to be
# copied back to the host (int(tensor.max()) would force a sync).
batch_size = torch.tensor(1)
if batch_x is not None:
    assert x.size(0) == batch_x.numel()
    batch_size = batch_x.max().to(dtype=torch.int) + 1
if batch_y is not None:
    assert y.size(0) == batch_y.numel()
    batch_size = torch.max(batch_size, batch_y.max().to(dtype=torch.int) + 1)
```
Additionally, this line is not static control flow: https://github.com/rusty1s/pytorch_cluster/blob/82e9df944118d7916265ade4fa4f5e1062b1bf48/torch_cluster/radius.py#L65, since `batch_size` depends on the contents of the input tensors.
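To illustrate the distinction (a sketch, not the actual fix for that line): a data-dependent branch can often be rewritten as a data-parallel select, so the condition never has to be read back on the host:

```python
import torch

x = torch.randn(8, device="cuda")

# Dynamic control flow: converting the condition to a Python bool
# forces a device-to-host sync, so this cannot be captured.
# if (x > 0).any():
#     y = x * 2
# else:
#     y = x

# Static alternative: compute both branches and select on the device.
cond = (x > 0).any()            # stays on the GPU as a 0-dim bool tensor
y = torch.where(cond, x * 2, x)
```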
@raimis' example does not exercise this code path, since batch_x/batch_y are not passed. For his example, I believe this line is the culprit: https://github.com/rusty1s/pytorch_cluster/blob/82e9df944118d7916265ade4fa4f5e1062b1bf48/csrc/cuda/radius_cuda.cu#L93. `masked_select` requires synchronization because its output size depends on the contents of the mask, which the host can only learn by waiting for the GPU.
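A minimal sketch of a capture-friendly workaround (the names here are hypothetical, and this is not necessarily the right fix for the kernel itself): keep the output at a fixed size and flag invalid entries with a sentinel, deferring compaction to outside the graph:

```python
import torch

vals = torch.randn(16, device="cuda")
mask = vals > 0

# Not capturable: the output size depends on how many mask entries are
# true, so masked_select must sync with the host to allocate the result.
# out = torch.masked_select(vals, mask)

# Capture-friendly sketch: fixed-size output, sentinel for invalid slots.
out = torch.where(mask, vals, torch.full_like(vals, float("nan")))
```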
Thanks for this insightful issue. Would you be interested in fixing the aforementioned problems? Would that be straightforward to integrate?
We use this functionality in several places, so I am eager to help. I am really new to torch, though, so my torch-fu is not very good -.- I am trying to make it CUDA-graph compatible without modifying the current interface, so if I get somewhere I will open a PR. I wanted to share some insight in case you saw this as an easy modification.
`torch_cluster.radius` is not compatible with CUDA graphs (https://pytorch.org/docs/stable/notes/cuda.html#cuda-graphs).
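For reference, a minimal sketch of the failure mode (assuming the usual `radius(x, y, r, ...)` signature; this is not the exact reproducer from the report):

```python
import torch
from torch_cluster import radius

x = torch.randn(100, 3, device="cuda")
y = torch.randn(50, 3, device="cuda")

# Warm-up outside the graph, as recommended before any capture.
radius(x, y, 0.5)
torch.cuda.synchronize()

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Fails: radius synchronizes with the host, which is
    # not allowed while the stream is being captured.
    edge_index = radius(x, y, 0.5)
```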