teddykoker / torchsort

Fast, differentiable sorting and ranking in PyTorch
https://pypi.org/project/torchsort/
Apache License 2.0
765 stars · 33 forks

Incorrect results when running on non-default cuda device #48

Closed JustinSzeto closed 2 years ago

JustinSzeto commented 2 years ago

When running torchsort.soft_rank or torchsort.soft_sort on a tensor that's not on the default cuda device (usually cuda:0), the results are incorrect.

import torch
import torchsort

x = torch.tensor([[9,8]], device="cuda:1")

print(torchsort.soft_rank(x))
# tensor([[9., 8.]], device='cuda:1')

print(torchsort.soft_sort(x))
# tensor([[-2., -1.]], device='cuda:1')

Based on the GPU memory usage, torchsort appears to do its work on the default cuda device (cuda:0) rather than on the device the input tensor is on. As a workaround, you can either change the default cuda device with torch.cuda.set_device or wrap the call in the torch.cuda.device context manager.

import torch
import torchsort

x = torch.tensor([[9,8]], device="cuda:1")

with torch.cuda.device(x.device):
    print(torchsort.soft_rank(x))
    # tensor([[2., 1.]], device='cuda:1')

    print(torchsort.soft_sort(x))
    # tensor([[8., 9.]], device='cuda:1')
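For completeness, the torch.cuda.set_device form of the workaround looks roughly like this (a sketch; it needs a machine with at least two GPUs to reproduce the issue, so the torchsort call is guarded):

```python
import torch

# Reproducing the issue requires a second GPU, so guard the call.
if torch.cuda.device_count() > 1:
    import torchsort

    x = torch.tensor([[9.0, 8.0]], device="cuda:1")
    # Make cuda:1 the default device process-wide, instead of
    # scoping it with the torch.cuda.device context manager.
    torch.cuda.set_device(x.device)
    print(torchsort.soft_rank(x))  # tensor([[2., 1.]], device='cuda:1')
```

Note that set_device changes the default for the whole process, while the context manager restores the previous default on exit.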
teddykoker commented 2 years ago

Hi Justin, thanks for finding this.

I think it might be necessary to add a device guard to the CUDA kernels. This should be accomplished by adding the header include:

#include <c10/cuda/CUDAGuard.h>

and inserting:

const at::cuda::OptionalCUDAGuard device_guard(device_of(y));

at lines 309 and 335, and

const at::cuda::OptionalCUDAGuard device_guard(device_of(s));

at lines 360 and 382 of isotonic_cuda.cu.
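As a sketch of where the guard goes (the function name and body below are illustrative, not the actual isotonic_cuda.cu source), the guard sits at the top of each C++ entry point, before any kernel launch:

```cpp
// Illustrative sketch only -- not the real isotonic_cuda.cu.
// OptionalCUDAGuard switches the current CUDA device to the input
// tensor's device for the lifetime of the function, so the kernel
// launches on the GPU that actually holds the data.
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

torch::Tensor isotonic_l2(torch::Tensor y) {  // hypothetical entry point
    // Guard first: current device follows y (no-op for CPU tensors).
    const at::cuda::OptionalCUDAGuard device_guard(device_of(y));
    auto sol = torch::zeros_like(y);
    // ... kernel launch, which now runs on y.device() ...
    return sol;
}
```

Without the guard, the launch goes to whatever the process's current device happens to be, which explains the cuda:0 memory usage observed above.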

I'll try to push these changes to a new branch this evening. If you could run it then, it would be great to see whether this fixes the issue.

teddykoker commented 2 years ago

Hi @JustinSzeto, at your convenience could you try the following?

git clone git@github.com:teddykoker/torchsort.git
cd torchsort
git checkout non-default-cuda
pip install -e '.[testing]'
pytest tests

I have added an additional test case that should ensure correct results when run on cuda:0, cuda:1, etc. If it fails, please let me know and I will try to debug more!
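Such a test might look roughly like this (a hypothetical sketch; the parametrization and tolerances in the actual torchsort test suite may differ):

```python
# Hypothetical sketch of a multi-device test, not the real suite.
import pytest
import torch

# Enumerate every visible CUDA device; an empty list means the
# parametrized tests are simply not generated on CPU-only machines.
DEVICES = [f"cuda:{i}" for i in range(torch.cuda.device_count())]

@pytest.mark.parametrize("device", DEVICES)
def test_soft_rank_non_default_device(device):
    import torchsort  # imported lazily so collection works without it

    x = torch.tensor([[9.0, 8.0]], device=device)
    # With well-separated inputs, soft_rank returns the true ranks,
    # matching the expected output shown earlier in the thread.
    expected = torch.tensor([[2.0, 1.0]])
    assert torch.allclose(torchsort.soft_rank(x).cpu(), expected)
```

Running it on a multi-GPU box exercises every device, including the non-default ones that triggered the original bug.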

JustinSzeto commented 2 years ago

Hi @teddykoker, it seems to be working now! All the tests have passed and my basic example from before is producing the correct results.

teddykoker commented 2 years ago

Great! I just released torchsort==0.1.9, which includes this fix.