vacancy / PreciseRoIPooling

Precise RoI Pooling with coordinate gradient support, proposed in the paper "Acquisition of Localization Confidence for Accurate Object Detection" (https://arxiv.org/abs/1807.11590).
MIT License

prroi_pooling_gpu.c does not seem to support PyTorch 1.11 and PyTorch 1.12 #75

Open hudfdfdf opened 2 years ago

hudfdfdf commented 2 years ago

Hi, prroi_pooling_gpu.c does not seem to support PyTorch 1.11 and PyTorch 1.12. Building the extension fails with the following error:

RuntimeError: Error building extension '_prroi_pooling':
[1/3] /usr/local/cuda/bin/nvcc -DTORCH_EXTENSION_NAME=_prroi_pooling -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /usr/local/lib/python3.7/dist-packages/torch/include -isystem /usr/local/lib/python3.7/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.7/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.7/dist-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /usr/include/python3.7m -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_60,code=compute_60 -gencode=arch=compute_60,code=sm_60 --compiler-options '-fPIC' -std=c++14 -c /content/dissect/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu_impl.cu -o prroi_pooling_gpu_impl.cuda.o
[2/3] c++ -MMD -MF prroi_pooling_gpu.o.d -DTORCH_EXTENSION_NAME=_prroi_pooling -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /usr/local/lib/python3.7/dist-packages/torch/include -isystem /usr/local/lib/python3.7/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.7/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.7/dist-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /usr/include/python3.7m -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++14 -c /content/dissect/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu.c -o prroi_pooling_gpu.o
FAILED: prroi_pooling_gpu.o
c++ -MMD -MF prroi_pooling_gpu.o.d -DTORCH_EXTENSION_NAME=_prroi_pooling -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /usr/local/lib/python3.7/dist-packages/torch/include -isystem /usr/local/lib/python3.7/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/lib/python3.7/dist-packages/torch/include/TH -isystem /usr/local/lib/python3.7/dist-packages/torch/include/THC -isystem /usr/local/cuda/include -isystem /usr/include/python3.7m -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++14 -c /content/dissect/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu.c -o prroi_pooling_gpu.o
/content/dissect/netdissect/upsegmodel/prroi_pool/src/prroi_pooling_gpu.c:17:10: fatal error: THC/THC.h: No such file or directory
 #include <THC/THC.h>
          ^~~
compilation terminated.
ninja: build stopped: subcommand failed.
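The decisive line in that log is the fatal error: none of the include directories on the c++ command line actually contains THC/THC.h, even though -isystem .../torch/include/THC is still passed. To confirm this on your own install before editing any source, you can search those directories for the header. This is a minimal sketch; `find_header` is a hypothetical helper, not part of PyTorch or this repository, and it simplifies the compiler's lookup by only checking the directories you pass, in order.

```python
import os
from typing import Iterable, Optional


def find_header(include_dirs: Iterable[str], header: str) -> Optional[str]:
    """Return the first existing path for `header` under `include_dirs`, else None.

    Simplified model of how `#include <header>` is resolved against -I/-isystem
    directories: compiler built-in search paths are not considered.
    """
    # Headers are written with forward slashes; convert to the OS separator.
    rel = os.path.join(*header.split("/"))
    for d in include_dirs:
        candidate = os.path.join(d, rel)
        if os.path.isfile(candidate):
            return candidate
    return None
```

With PyTorch installed, `find_header(torch.utils.cpp_extension.include_paths(), "THC/THC.h")` returning `None` would be consistent with the error above.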

Do you have any suggestions for fixing it, or could you find time to update it? Thanks a lot for this great project and for open-sourcing it.

DocNotVeryStrange commented 1 year ago

Did you figure it out? I am trying to build it with PyTorch 1.11, Python 3.9, and CUDA 11.3.

rasaford commented 1 year ago

I fixed it by removing the dependency on the <THC/THC.h> header. See https://github.com/vacancy/PreciseRoIPooling/pull/80 for a full explanation.
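For anyone patching a local copy, one common way to drop THC from a pre-1.11 extension source is to swap the THC header for the ATen/c10 CUDA headers and replace the THC stream and error-check calls with their ATen/c10 counterparts (at::cuda::getCurrentCUDAStream() and C10_CUDA_CHECK). The sketch below does this as a plain text rewrite. It is an assumption that prroi_pooling_gpu.c uses exactly these idioms (THCudaCheck, THCState_getCurrentStream(state), extern THCState *state;), and the helpers `patch_thc_source` and `remaining_thc_refs` are hypothetical; they are not taken from pull #80, so compare against that PR before relying on them.

```python
import re
from typing import List

# Replacement for the THC include: ATen's CUDAContext.h exposes
# at::cuda::getCurrentCUDAStream(), and c10's CUDAException.h defines C10_CUDA_CHECK.
_NEW_INCLUDES = "#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAException.h>"

# (pattern, replacement) pairs for THC idioms found in older extensions.
# Assumption: the target file uses some subset of these; other THC calls are left untouched.
_RULES = [
    # The header that the build can no longer find.
    (re.compile(r"^[ \t]*#include\s*<THC/THC\.h>[ \t]*$", re.M), _NEW_INCLUDES),
    # The global THC state handle, which the ATen stream API does not need.
    (re.compile(r"^[ \t]*extern\s+THCState\s*\*\s*state\s*;[ \t]*\n?", re.M), ""),
    # Current-stream lookup.
    (re.compile(r"\bTHCState_getCurrentStream\s*\(\s*state\s*\)"), "at::cuda::getCurrentCUDAStream()"),
    # CUDA error checking.
    (re.compile(r"\bTHCudaCheck\s*\("), "C10_CUDA_CHECK("),
]


def patch_thc_source(src: str) -> str:
    """Return `src` with the THC header and common THC calls replaced by ATen/c10 equivalents."""
    for pattern, repl in _RULES:
        src = pattern.sub(repl, src)
    return src


def remaining_thc_refs(src: str) -> List[str]:
    """List identifiers still starting with THC, to flag anything the rules above missed."""
    return sorted(set(re.findall(r"\bTHC\w*", src)))
```

To apply it, read the .c file, pass its text through `patch_thc_source`, write the result back, and rebuild; a non-empty `remaining_thc_refs` result means some THC usage still needs a manual replacement.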