NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[BUG] Int8 multiplication with PyTorch extension: namespace "torch" has no member "I8" #1573

Open MuhammedHasan opened 3 months ago

MuhammedHasan commented 3 months ago

Describe the bug

When compiling CUTLASS code into a PyTorch extension, the generated kernel source uses torch::I8, which does not exist. In libtorch the int8 dtype is named torch::kI8 (an alias of torch::kInt8).
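A minimal sketch of what a fix could look like, assuming the emitter keeps a lookup table from CUTLASS element types to the libtorch dtype constants it pastes into the generated C++/CUDA source. The table `TORCH_CPP_DTYPE` and the helper `emit_new_empty` are hypothetical names, not CUTLASS API; the right-hand sides assume libtorch's short dtype aliases (`torch::kI8`, `torch::kI32`, ...), all of which carry the leading `k` that `torch::I8` is missing.

```python
# Hypothetical lookup table: CUTLASS element-type name -> libtorch dtype
# constant emitted into the generated source. The reported bug corresponds
# to emitting "torch::I8" (no leading "k") for int8.
TORCH_CPP_DTYPE = {
    "s8": "torch::kI8",    # alias of torch::kInt8
    "s32": "torch::kI32",  # alias of torch::kInt32
    "f16": "torch::kF16",
    "f32": "torch::kF32",
    "f64": "torch::kF64",
}


def emit_new_empty(element: str, ref_tensor: str = "B") -> str:
    """Return the C++ line that allocates the output tensor D.

    `element` is the CUTLASS name of the output element type (e.g. "s8").
    """
    try:
        dtype = TORCH_CPP_DTYPE[element]
    except KeyError:
        raise ValueError(f"no libtorch dtype known for CUTLASS type {element!r}") from None
    # Doubled braces produce the literal C++ initializer list {M, N}.
    return f"at::Tensor D = {ref_tensor}.new_empty({{M, N}}, {dtype});"


print(emit_new_empty("s8"))
# at::Tensor D = B.new_empty({M, N}, torch::kI8);
```

Keeping the mapping in one table means a misspelled constant fails in exactly one place, and an unsupported type raises in Python instead of surfacing later as an nvcc error.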

Steps/Code to reproduce bug

import torch
import cutlass

plan = cutlass.op.Gemm(
    element=cutlass.DataType.s8,
    element_accumulator=cutlass.DataType.s32,
    element_D=cutlass.DataType.s32,
    layout=cutlass.LayoutType.RowMajor)
op = plan.construct()

mod = cutlass.emit.pytorch(op, name='gemm_mod', cc=plan.cc, jit=True)

a = torch.empty((2048, 4096), dtype=torch.int8, device=0)
b = torch.empty((4096, 8192), dtype=torch.int8, device=0)

D_ref = a.float() @ b.float()
D = mod.run(a, b)

print(D)
print(D_ref)
Traceback (most recent call last):
  File "/home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 2107, in _run_ninja_build
    subprocess.run(
  File "/home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/subprocess.py", line 516, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "benchmark/main.py", line 22, in <module>

  File "/home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/cutlass/emit/pytorch.py", line 927, in pytorch
    return _pytorch_gemm(device_op, name, cc, jit, sourcedir)
  File "/home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/cutlass/emit/pytorch.py", line 775, in _pytorch_gemm
    return _jit(name, cc, cpp_file, cuda_file)
  File "/home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/cutlass/emit/pytorch.py", line 697, in _jit
    jitmodule = load(
  File "/home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1309, in load
    return _jit_compile(
  File "/home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1719, in _jit_compile
    _write_ninja_file_and_build_library(
  File "/home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1832, in _write_ninja_file_and_build_library
    _run_ninja_build(
  File "/home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 2123, in _run_ninja_build
    raise RuntimeError(message) from e
RuntimeError: Error building extension 'gemm_mod': [1/3] c++ -MMD -MF gemm_mod.o.d -DTORCH_EXTENSION_NAME=gemm_mod -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/cutlass_library/source/include -I/home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/cutlass_library/source/tools/util/include -isystem /home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/include -isystem /home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/include/TH -isystem /home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/include/THC -isystem /home/mcelik/anaconda3/envs/cublas8bit/include -isystem /home/mcelik/anaconda3/envs/cublas8bit/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/mcelik/Projects/torch_cublas_matmul_int8/gemm_mod.cpp -o gemm_mod.o 
[2/3] /home/mcelik/anaconda3/envs/cublas8bit/bin/nvcc --generate-dependencies-with-compile --dependency-output gemm_mod_kernel.cuda.o.d -DTORCH_EXTENSION_NAME=gemm_mod -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/cutlass_library/source/include -I/home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/cutlass_library/source/tools/util/include -isystem /home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/include -isystem /home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/include/TH -isystem /home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/include/THC -isystem /home/mcelik/anaconda3/envs/cublas8bit/include -isystem /home/mcelik/anaconda3/envs/cublas8bit/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_89,code=sm_89 --compiler-options '-fPIC' -std=c++17 -c /home/mcelik/Projects/torch_cublas_matmul_int8/gemm_mod_kernel.cu -o gemm_mod_kernel.cuda.o 
FAILED: gemm_mod_kernel.cuda.o 
/home/mcelik/anaconda3/envs/cublas8bit/bin/nvcc --generate-dependencies-with-compile --dependency-output gemm_mod_kernel.cuda.o.d -DTORCH_EXTENSION_NAME=gemm_mod -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/cutlass_library/source/include -I/home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/cutlass_library/source/tools/util/include -isystem /home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/include -isystem /home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/include/TH -isystem /home/mcelik/anaconda3/envs/cublas8bit/lib/python3.8/site-packages/torch/include/THC -isystem /home/mcelik/anaconda3/envs/cublas8bit/include -isystem /home/mcelik/anaconda3/envs/cublas8bit/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_89,code=sm_89 --compiler-options '-fPIC' -std=c++17 -c /home/mcelik/Projects/torch_cublas_matmul_int8/gemm_mod_kernel.cu -o gemm_mod_kernel.cuda.o 
/home/mcelik/Projects/torch_cublas_matmul_int8/gemm_mod_kernel.cu(106): error: namespace "torch" has no member "I8"
      at::Tensor D = B.new_empty({M, N}, torch::I8);
                                                ^

1 error detected in the compilation of "/home/mcelik/Projects/torch_cublas_matmul_int8/gemm_mod_kernel.cu".
ninja: build stopped: subcommand failed.
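Until the emitter itself is fixed, one possible workaround is to patch the generated kernel source before building it. The sketch below only performs the text rewrite; it assumes every misspelled constant has the form `torch::I<bits>` (as `torch::I8` does in the error above) and rewrites it to the `torch::kI<bits>` alias. The helpers `patch_torch_dtypes` and `patch_file` are hypothetical; in the failing build above, the file to patch would be gemm_mod_kernel.cu.

```python
import re
from pathlib import Path

# Matches e.g. "torch::I8" or "torch::I32" but not the correct "torch::kI8".
# Assumption: every misspelled dtype in the generated file has this shape.
_BAD_DTYPE = re.compile(r"\btorch::I(8|16|32|64)\b")


def patch_torch_dtypes(source: str) -> str:
    """Rewrite torch::I<bits> to libtorch's torch::kI<bits> alias."""
    return _BAD_DTYPE.sub(r"torch::kI\1", source)


def patch_file(path: str) -> int:
    """Patch a generated source file in place; return the number of fixes."""
    p = Path(path)
    text = p.read_text()
    patched, count = _BAD_DTYPE.subn(r"torch::kI\1", text)
    if count:
        p.write_text(patched)
    return count


line = "at::Tensor D = B.new_empty({M, N}, torch::I8);"
print(patch_torch_dtypes(line))
# at::Tensor D = B.new_empty({M, N}, torch::kI8);
```

The patched sources would then still need to be compiled, for example with torch.utils.cpp_extension.load, passing the same CUTLASS include directories that appear in the nvcc command above.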

Expected behavior

The extension compiles and the multiplication runs as expected.
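One caveat when checking the result: the repro compares the int32 output `D` against a float32 reference `D_ref = a.float() @ b.float()`. float32 has a 24-bit significand, so it represents every integer exactly only up to 2^24 = 16,777,216, while an int8 dot product of length K = 4096 can reach 4096 * 128 * 128 = 2^26 in magnitude. For inputs with large values, `D_ref` may therefore differ slightly from the exact int32 result even when the kernel is correct, so an exact integer reference is the safer comparison. The pure-Python sketch below checks that bound and shows float32 rounding with the standard `struct` module; the helper names are illustrative, not part of CUTLASS or PyTorch.

```python
import struct

FLOAT32_EXACT_INT_LIMIT = 2 ** 24  # float32 has a 24-bit significand


def max_abs_int8_dot(k: int) -> int:
    """Largest |sum_i a_i * b_i| for int8 vectors of length k."""
    return k * 128 * 128  # (-128) * (-128) is the largest product


def to_float32(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


K = 4096  # inner dimension used in the repro
print(max_abs_int8_dot(K))                            # 67108864 == 2**26
print(max_abs_int8_dot(K) > FLOAT32_EXACT_INT_LIMIT)  # True
print(to_float32(2 ** 24 + 1))                        # 16777216.0
```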

Environment details: NVIDIA A6000 GPU (CUDA), PyTorch 2.3.1

github-actions[bot] commented 2 months ago

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.