pytorch / extension-cpp

C++ extensions in PyTorch

Type error #21

Closed: maxjiang93 closed this issue 6 years ago

maxjiang93 commented 6 years ago

I failed to compile the CUDA code with python setup.py install, and I'm rather surprised that this issue has not been brought up before. Here's the error message:

/usr/local/cuda/bin/nvcc -I/home/maxjiang/software/anaconda3/lib/python3.6/site-packages/torch/lib/include -I/home/maxjiang/software/anaconda3/lib/python3.6/site-packages/torch/lib/include/TH -I/home/maxjiang/software/anaconda3/lib/python3.6/site-packages/torch/lib/include/THC -I/usr/local/cuda/include -I/home/maxjiang/software/anaconda3/include/python3.6m -c lltm_cuda_kernel.cu -o build/temp.linux-x86_64-3.6/lltm_cuda_kernel.o -DTORCH_EXTENSION_NAME=lltm_cuda -D_GLIBCXX_USE_CXX11_ABI=0 --compiler-options '-fPIC' -std=c++11
lltm_cuda_kernel.cu(54): error: calling a __host__ function("std::fmax<double, float> ") from a __global__ function("_NV_ANON_NAMESPACE::lltm_cuda_forward_kernel<float> ") is not allowed

lltm_cuda_kernel.cu(54): error: identifier "std::fmax<double, float> " is undefined in device code

2 errors detected in the compilation of "/tmp/tmpxft_00003be3_00000000-6_lltm_cuda_kernel.cpp1.ii".

Most of my environment details are probably irrelevant except for the gcc version.

Here's my hacky fix that worked, which simply wraps the double literals in scalar_t casts. I'm not sure this is the most elegant solution.

lltm_cuda_kernel.cu lines 26-29:

template <typename scalar_t>
__device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) {
  return fmax(scalar_t(0.0), z) + fmin(scalar_t(0.0), alpha * (exp(z) - scalar_t(1.0)));
}
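
For context (my reading of the error, not something spelled out in the repo): the unmodified elu mixes double literals like 0.0 with a scalar_t argument, so when scalar_t is float the call resolves to the mixed-type overload std::fmax<double, float> from <cmath>, which is host-only and therefore cannot be called from device code. A minimal sketch of what the original lines presumably looked like before my change (reconstructed from the fix above, not copied from the repo):

template <typename scalar_t>
__device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) {
  // Raw double literals: with scalar_t = float, fmax(0.0, z) picks
  // std::fmax<double, float>, a host-only overload, hence the nvcc error.
  return fmax(0.0, z) + fmin(0.0, alpha * (exp(z) - 1.0));
}

Casting the literals to scalar_t keeps both arguments the same type, so fmax/fmin resolve to their device-capable single-type overloads instead.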
goldsborough commented 6 years ago

It has been mentioned in https://github.com/pytorch/extension-cpp/issues/14. Let me know if those fixes work for you.

maxjiang93 commented 6 years ago

Oh sorry, I didn't notice #14, though I'm not sure what the suggested fix in #14 actually is. It seems the OP's suggestion was to recompile PyTorch from scratch? I didn't go through recompiling PyTorch from scratch; my fix above of explicitly casting the literals to scalar_t seems easier :)