pytorch / extension-cpp

C++ extensions in PyTorch

The same code gets different results on different GPU devices #29

Closed. wujiaju closed this issue 5 years ago.

wujiaju commented 5 years ago

Hello.

I wrote my test code as follows:

test.py

import torch
import test_cuda  # the compiled CUDA extension module

a = torch.zeros(1, dtype=torch.int)
a = a.cuda(0)
x = test_cuda.func(a)
print(x)

cuda.cpp

#include <torch/torch.h>

void func_wrapper(int* a);  // defined in cuda_kernel.cu

at::Tensor func(at::Tensor a)
{
    func_wrapper(a.data<int>());  // pass the raw CUDA device pointer
    return a;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("func", &func, "func");
}

cuda_kernel.cu

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>

__global__ void func_kernel(int* __restrict__ a)
{
    a[0] = 4;  // write a sentinel value into the first element
}

void func_wrapper(int* a)
{
    // Launches with no device or stream specified, i.e. on the
    // current device's default stream.
    func_kernel<<<1, 1>>>(a);
}

When I used a = a.cuda(0) in test.py, I got the expected result:

tensor([4], device='cuda:0', dtype=torch.int32)

But when I used a = a.cuda(3) (I have multiple GPUs), I got:

tensor([0], device='cuda:3', dtype=torch.int32)

The result tensor was tensor([0]). Why?

Thanks a lot.


soumith commented 5 years ago

The func_kernel<<<1,1>>> launch needs to take the current CUDA stream as an extra launch argument (the stream is the fourth parameter of the <<<...>>> configuration, after the dynamic shared-memory size, so the launch becomes <<<1, 1, 0, stream>>>). Also, you need to switch the current device to device 3 in your CUDA code. Otherwise the kernel can launch before the tensor has finished copying to GPU-3, and you'll launch on the wrong GPU.

Look at at::DeviceGuard and at::cuda::getCurrentCUDAStream();
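For reference, a minimal sketch of what the fix might look like. The two-argument func_wrapper signature here is an assumption for illustration, not code from this repo; the guard and stream calls are the ATen APIs named above.

cuda.cpp

#include <torch/torch.h>
#include <ATen/DeviceGuard.h>       // at::DeviceGuard
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream

// Assumed extended signature: the launcher now receives the stream.
void func_wrapper(int* a, cudaStream_t stream);

at::Tensor func(at::Tensor a)
{
    // Make the tensor's device the current device for this scope, so the
    // kernel launches on the GPU that actually holds the data.
    at::DeviceGuard guard(a.device());
    func_wrapper(a.data<int>(), at::cuda::getCurrentCUDAStream());
    return a;
}

cuda_kernel.cu

void func_wrapper(int* a, cudaStream_t stream)
{
    // Third launch argument: dynamic shared-memory size (0 here).
    // Fourth launch argument: the stream to launch on.
    func_kernel<<<1, 1, 0, stream>>>(a);
}

With the guard in place, at::cuda::getCurrentCUDAStream() returns the stream for the tensor's own device, so the launch runs on the right GPU and is ordered after PyTorch's pending work on that stream (including the copy from a.cuda(3)).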