pytorch / extension-cpp

C++ extensions in PyTorch

`TORCH_LIBRARY` and `m.def` Not Working as Documented #92

Open andylizf opened 4 months ago

andylizf commented 4 months ago

I encountered an issue where using `TORCH_LIBRARY` alone, without the dispatcher API, does not work as expected. According to the PyTorch documentation, the `TORCH_LIBRARY` macro should create a function that registers custom operators. However, when I follow this approach, I get the following error at runtime:

```
$ python test/benchmark.py cuda
Traceback (most recent call last):
  File "/home/lizhifei/extension-cpp/test/benchmark.py", line 48, in <module>
    new_h, new_C = LLTM(X, W, b, h, C)
                   ^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/extension_cpp/ops.py", line 11, in lltm
    return LLTMFunction.apply(input, weights, bias, old_h, old_cell)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/extension_cpp/ops.py", line 17, in forward
    outputs = torch.ops.extension_cpp.lltm_forward.default(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lizhifei/miniconda3/envs/extension-cpp/lib/python3.12/site-packages/torch/_ops.py", line 921, in __getattr__
    raise AttributeError(
AttributeError: '_OpNamespace' object has no attribute 'lltm_forward'
```

Here is a link to my modified repository where this issue can be reproduced: andylizf/extension-cpp.

Could you please help me understand why this is happening and how to resolve it? Thank you.

Environment Information

- OS: Windows 11 23H2 22631.3527
- PyTorch version: 2.3.0
- How you installed PyTorch: conda
- Python version: 3.12.3
- CUDA/cuDNN version: CUDA 12.1, cuDNN 8.9.2
- GPU models and configuration: NVIDIA GeForce RTX 3090
- Conda Env:

```
# packages in environment at /home/lizhifei/miniconda3/envs/extension-cpp:
#
# Name                   Version     Build                          Channel
_libgcc_mutex            0.1         conda_forge                    conda-forge
_openmp_mutex            4.5         2_gnu                          conda-forge
blas                     1.0         mkl                            conda-forge
brotli-python            1.1.0       py312h30efb56_1                conda-forge
bzip2                    1.0.8       hd590300_5                     conda-forge
ca-certificates          2024.2.2    hbcca054_0                     conda-forge
certifi                  2024.2.2    pyhd8ed1ab_0                   conda-forge
charset-normalizer       3.3.2       pyhd8ed1ab_0                   conda-forge
cuda                     12.1.0      0                              nvidia
cuda-cccl                12.1.109    0                              nvidia/label/cuda-12.1.1
cuda-command-line-tools  12.1.1      0                              nvidia/label/cuda-12.1.1
cuda-compiler            12.1.1      0                              nvidia/label/cuda-12.1.1
cuda-cudart              12.1.105    0                              nvidia
cuda-cudart-dev          12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-cudart-static       12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-cuobjdump           12.1.111    0                              nvidia/label/cuda-12.1.1
cuda-cupti               12.1.105    0                              nvidia
cuda-cupti-static        12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-cuxxfilt            12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-demo-suite          12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-documentation       12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-driver-dev          12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-gdb                 12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-libraries           12.1.0      0                              nvidia
cuda-libraries-dev       12.1.0      0                              nvidia
cuda-libraries-static    12.1.1      0                              nvidia/label/cuda-12.1.1
cuda-nsight              12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-nsight-compute      12.1.1      0                              nvidia/label/cuda-12.1.1
cuda-nvcc                12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-nvdisasm            12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-nvml-dev            12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-nvprof              12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-nvprune             12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-nvrtc               12.1.105    0                              nvidia
cuda-nvrtc-dev           12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-nvrtc-static        12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-nvtx                12.1.105    0                              nvidia
cuda-nvvp                12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-opencl              12.4.127    0                              nvidia
cuda-opencl-dev          12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-profiler-api        12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-runtime             12.1.0      0                              nvidia
cuda-sanitizer-api       12.1.105    0                              nvidia/label/cuda-12.1.1
cuda-toolkit             12.1.0      0                              nvidia
cuda-tools               12.1.0      0                              nvidia
cuda-version             12.4        h3060b56_3                     conda-forge
cuda-visual-tools        12.1.0      0                              nvidia
extension-cpp            0.0.1       pypi_0                         pypi
ffmpeg                   4.3         hf484d3e_0                     pytorch
filelock                 3.14.0      pyhd8ed1ab_0                   conda-forge
freetype                 2.12.1      h267a509_2                     conda-forge
fsspec                   2024.3.1    pypi_0                         pypi
gds-tools                1.6.1.9     0                              nvidia/label/cuda-12.1.1
gmp                      6.3.0       h59595ed_1                     conda-forge
gnutls                   3.6.13      h85f3911_1                     conda-forge
icu                      73.2        h59595ed_0                     conda-forge
idna                     3.7         pyhd8ed1ab_0                   conda-forge
intel-openmp             2023.1.0    hdb19cb5_46306
jinja2                   3.1.4       pyhd8ed1ab_0                   conda-forge
jpeg                     9e          h166bdaf_2                     conda-forge
lame                     3.100       h166bdaf_1003                  conda-forge
lcms2                    2.15        hfd0df8a_0                     conda-forge
ld_impl_linux-64         2.40        h55db66e_0                     conda-forge
lerc                     4.0.0       h27087fc_0                     conda-forge
libblas                  3.9.0       1_h86c2bf4_netlib              conda-forge
libcblas                 3.9.0       5_h92ddd45_netlib              conda-forge
libcublas                12.1.0.26   0                              nvidia
libcublas-dev            12.1.0.26   0                              nvidia
libcublas-static         12.4.5.8    hd3aeb46_1                     conda-forge
libcufft                 11.0.2.4    0                              nvidia
libcufft-dev             11.0.2.4    0                              nvidia
libcufft-static          11.2.1.3    hd3aeb46_1                     conda-forge
libcufile                1.9.1.3     0                              nvidia
libcufile-dev            1.6.1.9     0                              nvidia/label/cuda-12.1.1
libcufile-static         1.6.1.9     0                              nvidia/label/cuda-12.1.1
libcurand                10.3.5.147  0                              nvidia
libcurand-dev            10.3.2.106  0                              nvidia/label/cuda-12.1.1
libcurand-static         10.3.2.106  0                              nvidia/label/cuda-12.1.1
libcusolver              11.4.4.55   0                              nvidia
libcusolver-dev          11.4.4.55   0                              nvidia
libcusolver-static       11.6.1.9    hd3aeb46_1                     conda-forge
libcusparse              12.0.2.55   0                              nvidia
libcusparse-dev          12.0.2.55   0                              nvidia
libcusparse-static       12.3.1.170  hd3aeb46_1                     conda-forge
libdeflate               1.17        h0b41bf4_0                     conda-forge
libexpat                 2.6.2       h59595ed_0                     conda-forge
libffi                   3.4.2       h7f98852_5                     conda-forge
libgcc-ng                13.2.0      h77fa898_7                     conda-forge
libgfortran-ng           13.2.0      h69a702a_7                     conda-forge
libgfortran5             13.2.0      hca663fb_7                     conda-forge
libgomp                  13.2.0      h77fa898_7                     conda-forge
libhwloc                 2.10.0      default_h2fb2949_1000          conda-forge
libiconv                 1.17        hd590300_2                     conda-forge
libjpeg-turbo            2.0.0       h9bf148f_0                     pytorch
liblapack                3.9.0       5_h92ddd45_netlib              conda-forge
libnpp                   12.0.2.50   0                              nvidia
libnpp-dev               12.0.2.50   0                              nvidia
libnpp-static            12.2.5.30   hd3aeb46_1                     conda-forge
libnsl                   2.0.1       hd590300_0                     conda-forge
libnvjitlink             12.1.105    0                              nvidia
libnvjitlink-dev         12.1.105    0                              nvidia/label/cuda-12.1.1
libnvjitlink-static      12.4.127    hd3aeb46_1                     conda-forge
libnvjpeg                12.1.1.14   0                              nvidia
libnvjpeg-dev            12.1.1.14   0                              nvidia
libnvjpeg-static         12.3.1.117  ha770c72_1                     conda-forge
libnvvm-samples          12.1.105    0                              nvidia/label/cuda-12.1.1
libpng                   1.6.43      h2797004_0                     conda-forge
libsqlite                3.45.3      h2797004_0                     conda-forge
libstdcxx-ng             13.2.0      hc0a3c3a_7                     conda-forge
libtiff                  4.5.0       h6adf6a1_2                     conda-forge
libuuid                  2.38.1      h0b41bf4_0                     conda-forge
libwebp-base             1.4.0       hd590300_0                     conda-forge
libxcrypt                4.4.36      hd590300_1                     conda-forge
libxml2                  2.12.6      h232c23b_2                     conda-forge
libzlib                  1.2.13      hd590300_5                     conda-forge
llvm-openmp              15.0.7      h0cdce71_0                     conda-forge
markupsafe               2.1.5       py312h98912ed_0                conda-forge
mkl                      2023.1.0    h213fc3f_46344
mpmath                   1.3.0       pyhd8ed1ab_0                   conda-forge
ncurses                  6.5         h59595ed_0                     conda-forge
nettle                   3.6         he412f7d_0                     conda-forge
networkx                 3.3         pyhd8ed1ab_1                   conda-forge
ninja                    1.11.1.1    pypi_0                         pypi
nsight-compute           2023.1.1.4  0                              nvidia/label/cuda-12.1.1
numpy                    1.26.4      py312heda63a1_0                conda-forge
openh264                 2.1.1       h780b84a_0                     conda-forge
openjpeg                 2.5.0       hfec8fc6_2                     conda-forge
openssl                  3.3.0       hd590300_0                     conda-forge
pillow                   10.3.0      py312h5eee18b_0
pip                      24.0        pyhd8ed1ab_0                   conda-forge
pysocks                  1.7.1       pyha2e5f31_6                   conda-forge
python                   3.12.3      hab00c5b_0_cpython             conda-forge
python_abi               3.12        4_cp312                        conda-forge
pytorch                  2.3.0       py3.12_cuda12.1_cudnn8.9.2_0   pytorch
pytorch-cuda             12.1        ha16c6d3_5                     pytorch
pytorch-mutex            1.0         cuda                           pytorch
pyyaml                   6.0.1       py312h98912ed_1                conda-forge
readline                 8.2         h8228510_1                     conda-forge
requests                 2.31.0      pyhd8ed1ab_0                   conda-forge
setuptools               69.5.1      pyhd8ed1ab_0                   conda-forge
sympy                    1.12        pyh04b8f61_3                   conda-forge
tbb                      2021.12.0   h00ab1b0_0                     conda-forge
tk                       8.6.13      noxft_h4845f30_101             conda-forge
torchaudio               2.3.0       py312_cu121                    pytorch
torchvision              0.18.0      py312_cu121                    pytorch
typing_extensions        4.11.0      pyha770c72_0                   conda-forge
tzdata                   2024a       h0c530f3_0                     conda-forge
urllib3                  2.2.1       pyhd8ed1ab_0                   conda-forge
wheel                    0.43.0      pyhd8ed1ab_1                   conda-forge
xz                       5.2.6       h166bdaf_0                     conda-forge
yaml                     0.2.5       h7f98852_2                     conda-forge
zlib                     1.2.13      hd590300_5                     conda-forge
zstd                     1.5.6       ha6fb4c9_0                     conda-forge
```

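The namespace behaviour of `TORCH_LIBRARY(ns, m) { m.def(...); }` can be sketched from Python with `torch.library.Library`, which offers roughly the same define/impl steps. This is a swapped-in Python API used only for illustration, not the code in the fork; the namespace `extension_cpp_demo`, the operator `scale`, its CPU kernel, and the namespace `some_other_ns` are all hypothetical. The point it shows is that an operator defined in namespace `ns` is reachable only as `torch.ops.ns.<name>`, and a lookup under any other namespace raises `AttributeError`, as in the traceback above.

```python
import torch

# Hypothetical namespace for this demo. Keep a module-level reference:
# the registrations are removed if the Library object is garbage-collected.
_lib = torch.library.Library("extension_cpp_demo", "DEF")

# Roughly analogous to m.def("scale(Tensor x, float alpha) -> Tensor")
# inside TORCH_LIBRARY(extension_cpp_demo, m).
_lib.define("scale(Tensor x, float alpha) -> Tensor")


def _scale_cpu(x, alpha):
    # CPU kernel for the hypothetical demo operator.
    return x * alpha


# Roughly analogous to m.impl("scale", ...) inside
# TORCH_LIBRARY_IMPL(extension_cpp_demo, CPU, m).
_lib.impl("scale", _scale_cpu, "CPU")

# The op is found under the namespace it was defined in...
out = torch.ops.extension_cpp_demo.scale.default(torch.ones(2), 3.0)
print(out)  # tensor([3., 3.])
print(hasattr(torch.ops.extension_cpp_demo, "scale"))  # True

# ...but not under a different namespace: the lookup raises AttributeError,
# which hasattr() turns into False.
print(hasattr(torch.ops.some_other_ns, "scale"))  # False
```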
crazyboy9103 commented 3 weeks ago

https://github.com/andylizf/extension-cpp/blob/2d49e184f82ab6e1b61e8dc3abb6c7ede65ca37b/extension_cpp/csrc/lltm.cpp#L9

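I think it should be `TORCH_LIBRARY(extension_cpp, m)`, not `TORCH_LIBRARY(TORCH_EXTENSION_NAME, m)`. `TORCH_EXTENSION_NAME` expands to the name of the compiled extension module, not to the `extension_cpp` namespace that `ops.py` looks up via `torch.ops.extension_cpp`.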
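Where the mismatched name comes from can be sketched without building anything. When an extension is built with `torch.utils.cpp_extension.BuildExtension`, `TORCH_EXTENSION_NAME` is defined as the last dotted component of the extension's `name`. Assuming the fork's `setup.py` declares the extension as `extension_cpp._C` (an assumption; the fork's build file is not shown here), the macro expands to `_C`, so the operators land in the `_C` namespace. The helpers below (`torch_extension_name`, `register`, `lookup`) are hypothetical: a toy model of that naming rule and of `torch.ops` lookup, not PyTorch's implementation.

```python
def torch_extension_name(extension_name: str) -> str:
    # Hypothetical helper mirroring the rule BuildExtension applies for
    # -DTORCH_EXTENSION_NAME: keep only the last component of the dotted name.
    return extension_name.split(".")[-1]


# Toy stand-in for the dispatcher's operator table: namespace -> op names.
registry: dict[str, set[str]] = {}


def register(namespace: str, *ops: str) -> None:
    # Models TORCH_LIBRARY(namespace, m) { m.def(...); }.
    registry.setdefault(namespace, set()).update(ops)


def lookup(namespace: str, op: str) -> str:
    # Models torch.ops.<namespace>.<op>: fails unless op was registered there.
    if op not in registry.get(namespace, set()):
        raise AttributeError(f"'_OpNamespace' object has no attribute '{op}'")
    return f"{namespace}::{op}"


ext_name = "extension_cpp._C"              # assumed setup.py extension name
macro_ns = torch_extension_name(ext_name)  # what TORCH_EXTENSION_NAME expands to
print(macro_ns)  # _C

# TORCH_LIBRARY(TORCH_EXTENSION_NAME, m): the op is registered under "_C"...
register(macro_ns, "lltm_forward")
try:
    # ...but ops.py looks it up under "extension_cpp".
    lookup("extension_cpp", "lltm_forward")
except AttributeError as err:
    print(err)  # '_OpNamespace' object has no attribute 'lltm_forward'

# TORCH_LIBRARY(extension_cpp, m): the namespaces now agree.
register("extension_cpp", "lltm_forward")
print(lookup("extension_cpp", "lltm_forward"))  # extension_cpp::lltm_forward
```

Hard-coding the namespace in `TORCH_LIBRARY` keeps it independent of how the Python module is named, which is why the suggested fix does not use the macro.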