pytorch / extension-cpp

C++ extensions in PyTorch

Runtime error: undefined symbol #23

Closed tczhangzhi closed 6 years ago

tczhangzhi commented 6 years ago

OS: Ubuntu 16.04
PyTorch version: 0.4.1
How you installed PyTorch (conda, pip, source): conda
Python version: 3.6
CUDA/cuDNN version: 9.0
GPU models and configuration: Tesla K80

I defined my own custom op:

// sigmoid_cuda_kernel.cu
namespace {
  template <typename scalar_t>
  __device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
    return 1.0 / (1.0 + exp(-z));
  }

  // Derivative of the sigmoid expressed in terms of its output:
  // if z = sigmoid(x), then d/dx sigmoid(x) = z * (1 - z).
  template <typename scalar_t>
  __device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {
    return (1.0 - z) * z;
  }

  template <typename scalar_t>
  __global__ void sigmoid_cuda_forward_kernel(
      const scalar_t* __restrict__ input,
      scalar_t* __restrict__ output) {
    // The grid is (size(0), size(1)) with one thread per block,
    // so flatten the 2-D block index into a linear element index.
    const int index = blockIdx.x * gridDim.y + blockIdx.y;
    output[index] = sigmoid(input[index]);
  }

  template <typename scalar_t>
  __global__ void sigmoid_cuda_backward_kernel(
      const scalar_t* __restrict__ grad_output,
      const scalar_t* __restrict__ output,
      scalar_t* __restrict__ new_grad_output) {
    const int index = blockIdx.x * gridDim.y + blockIdx.y;
    // Chain rule: grad_input = d_sigmoid(output) * grad_output.
    new_grad_output[index] = d_sigmoid(output[index]) * grad_output[index];
  }
} // namespace

at::Tensor sigmoid_cuda_forward(
    at::Tensor input) {
  auto output = at::zeros_like(input);
  const dim3 blocks(input.size(0), input.size(1));
  const int threads = 1;

  AT_DISPATCH_FLOATING_TYPES(input.type(), "sigmoid_forward_cuda", ([&] {
    sigmoid_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
      input.data<scalar_t>(),
      output.data<scalar_t>());
  }));

  return output;
}

at::Tensor sigmoid_cuda_backward(
    at::Tensor grad_output,
    at::Tensor output) {
  auto new_grad_output = at::zeros_like(grad_output);
  const dim3 blocks(grad_output.size(0), grad_output.size(1));
  const int threads = 1;

  AT_DISPATCH_FLOATING_TYPES(grad_output.type(), "sigmoid_backward_cuda", ([&] {
    sigmoid_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
      grad_output.data<scalar_t>(),
      output.data<scalar_t>(),
      new_grad_output.data<scalar_t>());
  }));

  return new_grad_output;
}

And the cpp wrapper is as follows:

// sigmoid_cuda.cpp
at::Tensor sigmoid_cuda_forward(
    const at::Tensor& input);

at::Tensor sigmoid_cuda_backward(
    const at::Tensor& grad_output,
    const at::Tensor& output);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

at::Tensor sigmoid_forward(
    const at::Tensor& input) {
  CHECK_INPUT(input);
  return sigmoid_cuda_forward(input);
}

at::Tensor sigmoid_backward(
    const at::Tensor& grad_output,
    const at::Tensor& output) {
  CHECK_INPUT(grad_output);
  CHECK_INPUT(output);
  return sigmoid_cuda_backward(
    grad_output,
    output);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &sigmoid_forward, "sigmoid forward (CUDA)");
  m.def("backward", &sigmoid_backward, "sigmoid backward (CUDA)");
}
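
For reference, an extension like this is built with a setup.py along these lines (a minimal sketch; the actual file is not shown in the issue, the source file names are assumed from the snippets above, and the package name is taken from the egg name in the log below):

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='sigmoid_cuda_linear_cpp',
    ext_modules=[
        # Compiles the .cpp with the host compiler and the .cu with nvcc,
        # then links them into a single Python extension module.
        CUDAExtension('sigmoid_cuda', [
            'sigmoid_cuda.cpp',
            'sigmoid_cuda_kernel.cu',
        ]),
    ],
    cmdclass={'build_ext': BuildExtension})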

The compilation and installation succeed:

running install
running bdist_egg
running egg_info
writing sigmoid_cuda_linear_cpp.egg-info/PKG-INFO
writing dependency_links to sigmoid_cuda_linear_cpp.egg-info/dependency_links.txt
writing top-level names to sigmoid_cuda_linear_cpp.egg-info/top_level.txt
reading manifest file 'sigmoid_cuda_linear_cpp.egg-info/SOURCES.txt'
writing manifest file 'sigmoid_cuda_linear_cpp.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
creating build/bdist.linux-x86_64/egg
copying build/lib.linux-x86_64-3.6/linear_cpp.cpython-36m-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg
copying build/lib.linux-x86_64-3.6/sigmoid_cuda.cpython-36m-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg
creating stub loader for sigmoid_cuda.cpython-36m-x86_64-linux-gnu.so
byte-compiling build/bdist.linux-x86_64/egg/sigmoid_cuda.py to sigmoid_cuda.cpython-36.pyc
creating build/bdist.linux-x86_64/egg/EGG-INFO
copying sigmoid_cuda_linear_cpp.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
copying sigmoid_cuda_linear_cpp.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying sigmoid_cuda_linear_cpp.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying sigmoid_cuda_linear_cpp.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt
zip_safe flag not set; analyzing archive contents...
__pycache__.sigmoid_cuda.cpython-36: module references __file__
creating 'dist/sigmoid_cuda_linear_cpp-0.0.0-py3.6-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it
removing 'build/bdist.linux-x86_64/egg' (and everything under it)
Processing sigmoid_cuda_linear_cpp-0.0.0-py3.6-linux-x86_64.egg
removing '/home/zhangzhi/anaconda3/lib/python3.6/site-packages/sigmoid_cuda_linear_cpp-0.0.0-py3.6-linux-x86_64.egg' (and everything under it)
creating /home/zhangzhi/anaconda3/lib/python3.6/site-packages/sigmoid_cuda_linear_cpp-0.0.0-py3.6-linux-x86_64.egg
Extracting sigmoid_cuda_linear_cpp-0.0.0-py3.6-linux-x86_64.egg to /home/zhangzhi/anaconda3/lib/python3.6/site-packages
sigmoid-cuda-linear-cpp 0.0.0 is already the active version in easy-install.pth

Installed /home/zhangzhi/anaconda3/lib/python3.6/site-packages/sigmoid_cuda_linear_cpp-0.0.0-py3.6-linux-x86_64.egg
Processing dependencies for sigmoid-cuda-linear-cpp==0.0.0
Finished processing dependencies for sigmoid-cuda-linear-cpp==0.0.0

But when I import the installed module, it fails:

ImportError: /home/.../anaconda3/lib/python3.6/site-packages/sigmoid_cuda_linear_cpp-0.0.0-py3.6-linux-x86_64.egg/sigmoid_cuda.cpython-36m-x86_64-linux-gnu.so: undefined symbol: _Z20sigmoid_cuda_forwardRKN2at6TensorE
tczhangzhi commented 6 years ago

I have solved the problem. The cause is a signature mismatch between the declaration and the definition: sigmoid_cuda.cpp declares sigmoid_cuda_forward as taking const at::Tensor&, but the definition in the .cu file takes at::Tensor by value. C++ name mangling encodes parameter types, so the by-reference declaration refers to a symbol that is never defined anywhere, and the dynamic loader cannot resolve it at import time. The fix is to make the declarations match the definitions and pass the tensors by value.
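
Demangling the missing symbol with c++filt (part of binutils) confirms it is the by-reference overload:

$ c++filt _Z20sigmoid_cuda_forwardRKN2at6TensorE
sigmoid_cuda_forward(at::Tensor const&)

That signature matches the declaration in the .cpp file, but no definition with that signature exists in the compiled .cu object.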

// before fix
at::Tensor sigmoid_cuda_forward(
    const at::Tensor& input);
// after fix
at::Tensor sigmoid_cuda_forward(
    at::Tensor input);
// other code ...
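
After rebuilding with the matching by-value signatures, the module imports cleanly. A quick smoke test (a sketch; the module and function names follow the pybind11 bindings above, and the kernels assume a contiguous 2-D CUDA tensor):

import torch
import sigmoid_cuda

# 2-D CUDA input, matching the (size(0), size(1)) grid used by the kernels.
x = torch.randn(4, 8, device='cuda')
y = sigmoid_cuda.forward(x)
print(torch.allclose(y, torch.sigmoid(x)))  # expect True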