ShichenLiu / SoftRas

Project page of paper "Soft Rasterizer: A Differentiable Renderer for Image-based 3D Reasoning"
MIT License

CUDA11.1+Torch 1.8 #84

Open yuhanyuhang opened 3 years ago

yuhanyuhang commented 3 years ago

I have successfully built Detectron2 in this environment, but when I build SoftRas I run into the following error. Can you give me some advice? Thanks so much.

```
1 error detected in the compilation of "D:/ApolloCar3D/各种pip包/SoftRas-master/soft_renderer/cuda/soft_rasterize_cuda_kernel.cu".
soft_rasterize_cuda_kernel.cu
ninja: build stopped: subcommand failed.
Traceback (most recent call last):
  File "C:\Users\One\anaconda3\envs\detectron2\lib\site-packages\torch\utils\cpp_extension.py", line 1673, in _run_ninja_build
    env=env)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\subprocess.py", line 438, in run
    output=stdout, stderr=stderr)
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "setup.py", line 38, in <module>
    cmdclass={'build_ext': BuildExtension}
  File "C:\Users\One\anaconda3\envs\detectron2\lib\site-packages\setuptools\__init__.py", line 153, in setup
    return distutils.core.setup(**attrs)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\distutils\core.py", line 148, in setup
    dist.run_commands()
  File "C:\Users\One\anaconda3\envs\detectron2\lib\distutils\dist.py", line 955, in run_commands
    self.run_command(cmd)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\distutils\dist.py", line 974, in run_command
    cmd_obj.run()
  File "C:\Users\One\anaconda3\envs\detectron2\lib\site-packages\setuptools\command\install.py", line 67, in run
    self.do_egg_install()
  File "C:\Users\One\anaconda3\envs\detectron2\lib\site-packages\setuptools\command\install.py", line 109, in do_egg_install
    self.run_command('bdist_egg')
  File "C:\Users\One\anaconda3\envs\detectron2\lib\distutils\cmd.py", line 313, in run_command
    self.distribution.run_command(command)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\distutils\dist.py", line 974, in run_command
    cmd_obj.run()
  File "C:\Users\One\anaconda3\envs\detectron2\lib\site-packages\setuptools\command\bdist_egg.py", line 164, in run
    cmd = self.call_command('install_lib', warn_dir=0)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\site-packages\setuptools\command\bdist_egg.py", line 150, in call_command
    self.run_command(cmdname)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\distutils\cmd.py", line 313, in run_command
    self.distribution.run_command(command)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\distutils\dist.py", line 974, in run_command
    cmd_obj.run()
  File "C:\Users\One\anaconda3\envs\detectron2\lib\site-packages\setuptools\command\install_lib.py", line 11, in run
    self.build()
  File "C:\Users\One\anaconda3\envs\detectron2\lib\distutils\command\install_lib.py", line 107, in build
    self.run_command('build_ext')
  File "C:\Users\One\anaconda3\envs\detectron2\lib\distutils\cmd.py", line 313, in run_command
    self.distribution.run_command(command)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\distutils\dist.py", line 974, in run_command
    cmd_obj.run()
  File "C:\Users\One\anaconda3\envs\detectron2\lib\site-packages\setuptools\command\build_ext.py", line 79, in run
    _build_ext.run(self)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\site-packages\Cython\Distutils\old_build_ext.py", line 186, in run
    _build_ext.build_ext.run(self)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\distutils\command\build_ext.py", line 339, in run
    self.build_extensions()
  File "C:\Users\One\anaconda3\envs\detectron2\lib\site-packages\torch\utils\cpp_extension.py", line 708, in build_extensions
    build_ext.build_extensions(self)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\site-packages\Cython\Distutils\old_build_ext.py", line 195, in build_extensions
    _build_ext.build_ext.build_extensions(self)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\distutils\command\build_ext.py", line 448, in build_extensions
    self._build_extensions_serial()
  File "C:\Users\One\anaconda3\envs\detectron2\lib\distutils\command\build_ext.py", line 473, in _build_extensions_serial
    self.build_extension(ext)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\site-packages\setuptools\command\build_ext.py", line 196, in build_extension
    _build_ext.build_extension(self, ext)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\distutils\command\build_ext.py", line 533, in build_extension
    depends=ext.depends)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\site-packages\torch\utils\cpp_extension.py", line 690, in win_wrap_ninja_compile
    with_cuda=with_cuda)
  File "C:\Users\One\anaconda3\envs\detectron2\lib\site-packages\torch\utils\cpp_extension.py", line 1359, in _write_ninja_file_and_compile_objects
    error_prefix='Error compiling objects for extension')
  File "C:\Users\One\anaconda3\envs\detectron2\lib\site-packages\torch\utils\cpp_extension.py", line 1683, in _run_ninja_build
    raise RuntimeError(message) from e
RuntimeError: Error compiling objects for extension
```

ShichenLiu commented 3 years ago

Hi there, sorry, I don't have a CUDA 11 environment, so I really can't test the code for you. Would you mind checking this page: https://pytorch.org/docs/stable/cpp_extension.html and this page: https://pytorch.org/tutorials/advanced/cpp_extension.html?highlight=cuda%20extension to see whether you can slightly modify the code to make it compile? I'm pretty sure it just needs a minor fix.

yuhanyuhang commented 3 years ago

Thank you so much. The page https://pytorch.org/tutorials/advanced/cpp_extension.html?highlight=cuda%20extension was very useful to me. I found that with CUDA 11.1 + Torch 1.8, the errors always occur in the `SoftRas-master\soft_renderer\cuda\soft_rasterize_cuda_kernel.cu` and `voxelization_cuda_kernel.cu` files. The error is in the `atomicAdd` function ('atomicAdd has been denied'), so I changed:

```cuda
#if __CUDA_ARCH__ < 600 and defined(__CUDA_ARCH__)
static __inline__ __device__ double atomicAdd(double* address, double val) {
    unsigned long long int* address_as_ull = (unsigned long long int*)address;
    unsigned long long int old = *address_as_ull, assumed;
    do {
        assumed = old;
        old = atomicCAS(address_as_ull, assumed,
                        __double_as_longlong(val + __longlong_as_double(assumed)));
        // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
    } while (assumed != old);
    return __longlong_as_double(old);
}
#endif
```

to:

```cuda
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600

#else
static __inline__ __device__ double atomicAdd(double *address, double val) {
    unsigned long long int* address_as_ull = (unsigned long long int*)address;
    unsigned long long int old = *address_as_ull, assumed;
    if (val == 0.0)
        return __longlong_as_double(old);
    do {
        assumed = old;
        old = atomicCAS(address_as_ull, assumed,
                        __double_as_longlong(val + __longlong_as_double(assumed)));
    } while (assumed != old);
    return __longlong_as_double(old);
}
#endif
```

This fixes the compilation error in the .cu files. Although I think the two versions mean the same thing, the only difference in the function body is the early return `if (val == 0.0) return __longlong_as_double(old);`.

Then install the dependencies and compile as usual.
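The fallback `atomicAdd` builds a double-precision atomic add out of a 64-bit integer compare-and-swap: read the current bits, compute the new sum, and retry if another thread changed the value in between. The sketch below is a host-side C++ analogue for illustration only, not SoftRas code. It assumes `std::atomic<std::uint64_t>::compare_exchange_strong` as a stand-in for `atomicCAS` and `std::memcpy` bit copies as stand-ins for `__double_as_longlong`/`__longlong_as_double`; `atomic_add_double` and `parallel_sum` are hypothetical names. Like the patched version, it returns early when `val == 0.0`, skipping the compare-and-swap when there is nothing to add.

```cpp
// Host-side C++ analogue (illustration only, not SoftRas code) of the
// compare-and-swap loop in the fallback atomicAdd(double*, double).
#include <atomic>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <thread>
#include <vector>

// Stand-in for __double_as_longlong: copy the bits of a double into an integer.
std::uint64_t double_as_bits(double d) {
    std::uint64_t u;
    std::memcpy(&u, &d, sizeof u);
    return u;
}

// Stand-in for __longlong_as_double: copy the bits of an integer into a double.
double bits_as_double(std::uint64_t u) {
    double d;
    std::memcpy(&d, &u, sizeof d);
    return d;
}

// Hypothetical name. Atomically adds `val` to the double stored as raw bits
// in *address and returns the previous value, as CUDA's atomicAdd does.
double atomic_add_double(std::atomic<std::uint64_t>* address, double val) {
    std::uint64_t old = address->load();
    if (val == 0.0)  // early return from the patched version: nothing to add
        return bits_as_double(old);
    std::uint64_t assumed;
    do {
        assumed = old;
        const std::uint64_t desired = double_as_bits(val + bits_as_double(assumed));
        // Like atomicCAS: store `desired` only if memory still holds `assumed`.
        // On failure, compare_exchange_strong copies the current bits into `old`.
        address->compare_exchange_strong(old, desired);
    } while (assumed != old);  // integer comparison, so a stored NaN cannot hang
    return bits_as_double(old);
}

// Hypothetical driver: `threads` workers each add `step` to one shared total
// `iters` times; returns the final total.
double parallel_sum(int threads, int iters, double step) {
    std::atomic<std::uint64_t> total{double_as_bits(0.0)};
    std::vector<std::thread> pool;
    for (int t = 0; t < threads; ++t) {
        pool.emplace_back([&] {
            for (int i = 0; i < iters; ++i)
                atomic_add_double(&total, step);
        });
    }
    for (auto& worker : pool)
        worker.join();
    return bits_as_double(total.load());
}
```

With `step = 0.5` every partial sum is exactly representable, so `parallel_sum(4, 1000, 0.5)` should return exactly 2000.0 regardless of how the threads interleave. Comparing raw 64-bit patterns rather than doubles is what the NaN comment in the original kernel refers to: if the stored value were NaN, `assumed != old` evaluated on doubles would always be true and the loop would never terminate.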

ShichenLiu commented 3 years ago

Glad to know that you solved the problem! Would you mind creating a pull request? Of course, I can also merge the fix myself; it's totally up to you.

birdortyedi commented 2 years ago

yuhanyuhang's solution is great! It solved the issue. Please consider merging this change.

StevenLzq commented 1 year ago

CUDA 11.1 + Torch 1.8: it doesn't work

```cuda
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600

#else
static __inline__ __device__ double atomicAdd(double *address, double val) {
    unsigned long long int* address_as_ull = (unsigned long long int*)address;
    unsigned long long int old = *address_as_ull, assumed;
    if (val == 0.0)
        return __longlong_as_double(old);
    do {
        assumed = old;
        old = atomicCAS(address_as_ull, assumed,
                        __double_as_longlong(val + __longlong_as_double(assumed)));
    } while (assumed != old);
    return __longlong_as_double(old);
}
#endif
```
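Both the original guard and the rewritten one compile the hand-written `atomicAdd` only in a device pass for an architecture below sm_60. In the host pass `__CUDA_ARCH__` is undefined, and from compute capability 6.x onward CUDA already provides `atomicAdd(double*, double)`. The sketch below is plain host C++ (also valid in a `.cu` file) that evaluates the rewritten guard with an ordinary compiler and uses hypothetical helper functions to check that it agrees with the original `and` version. It does not establish why only the rewritten guard compiled on this Windows setup; one guess, which is an assumption not confirmed anywhere in this thread, is that the MSVC-based toolchain does not treat `and` as an operator inside `#if`.

```cpp
// Host-side C++ sketch (not SoftRas code) of how the rewritten guard selects
// a branch. nvcc defines __CUDA_ARCH__ only while compiling device code; an
// ordinary host compiler, like nvcc's host pass, leaves it undefined.
#include <cassert>

#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600
// Host pass, or device pass for sm_60 and newer: nothing is declared here,
// so device code on sm_60+ uses CUDA's built-in atomicAdd(double*, double).
constexpr bool kDefinesFallbackAtomicAdd = false;
#else
// Device pass for an architecture older than sm_60: this is where the
// CAS-based atomicAdd(double*, double) fallback would be compiled.
constexpr bool kDefinesFallbackAtomicAdd = true;
#endif

// Hypothetical helpers restating both guards as functions of the architecture
// number; arch == 0 stands for "__CUDA_ARCH__ is not defined" (an undefined
// identifier also evaluates to 0 inside #if).
// ASSUMPTION (unconfirmed): the original spelling with `and` may be what the
// Windows toolchain rejected; logically the two guards are identical.
constexpr bool fallback_by_original_guard(int arch) {
    return arch < 600 && arch != 0;  // __CUDA_ARCH__ < 600 and defined(__CUDA_ARCH__)
}
constexpr bool fallback_by_rewritten_guard(int arch) {
    return !(arch == 0 || arch >= 600);  // the #else branch of the rewritten guard
}

// The two conditions agree for the architecture numbers checked here.
static_assert(fallback_by_original_guard(0) == fallback_by_rewritten_guard(0), "host pass");
static_assert(fallback_by_original_guard(520) == fallback_by_rewritten_guard(520), "sm_52");
static_assert(fallback_by_original_guard(600) == fallback_by_rewritten_guard(600), "sm_60");
static_assert(fallback_by_original_guard(860) == fallback_by_rewritten_guard(860), "sm_86");
```

Compiled with a regular C++ compiler, `kDefinesFallbackAtomicAdd` is `false`, matching the host pass, where no custom overload is emitted.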

Just5D commented 1 year ago


It works! Thanks.