NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[BUG] Possible Bug when exporting Kernel as PyTorch extension #1655

Open sycz00 opened 1 month ago

sycz00 commented 1 month ago

I have the following script:

```python
import random
import cutlass
import torch

print_module = False

batch = 256
feature_dim_in = 4098
feature_dim_out = 10

type_A = torch.int8
type_B = torch.int8
type_C = torch.int32
type_D = torch.int32

torch.random.manual_seed(12)
random.seed(22)

input = torch.randint(-3, 3, size=(batch,feature_dim_in), device='cuda').to(type_A)
weight = torch.randint(-3, 3, size=(feature_dim_out,feature_dim_in), device='cuda').to(type_B)
weight_col_major = weight.t().contiguous()#.t()
bias = torch.randint(-3, 3, size=(batch,feature_dim_out), device='cuda').to(type_C)

result_tensor = torch.zeros_like(bias)

# Forward pass:

plan = cutlass.op.Gemm(
    alpha=1, beta=1,
    element_A=type_A, element_B=type_B, element_C=type_C, element_D=type_D,
    layout_A=cutlass.LayoutType.RowMajor,
    layout_B=cutlass.LayoutType.RowMajor,
    layout_C=cutlass.LayoutType.RowMajor
)
output_torch = (torch.matmul(input.float(),weight.t().float()) + bias.float()).to(torch.int32)
plan.opclass = cutlass.OpcodeClass.Simt
plan.run(input, weight_col_major, bias, result_tensor, print_module=print_module)

op = plan.construct()
linear_layer_gemm = cutlass.emit.pytorch(op, name='linear_layer', cc=plan.cc, sourcedir='linear', jit=True)
```

When I run the Python script, I get the following traceback:

```
Traceback (most recent call last):
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 2107, in _run_ninja_build
    subprocess.run(
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/subprocess.py", line 516, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "linear_layer.py", line 55, in <module>
    linear_layer_gemm = cutlass.emit.pytorch(op, name='linear_layer', cc=plan.cc, sourcedir='linear', jit=True)
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass/emit/pytorch.py", line 927, in pytorch
    return _pytorch_gemm(device_op, name, cc, jit, sourcedir)
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass/emit/pytorch.py", line 775, in _pytorch_gemm
    return _jit(name, cc, cpp_file, cuda_file)
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass/emit/pytorch.py", line 697, in _jit
    jitmodule = load(
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1309, in load
    return _jit_compile(
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1719, in _jit_compile
    _write_ninja_file_and_build_library(
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1832, in _write_ninja_file_and_build_library
    _run_ninja_build(
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 2123, in _run_ninja_build
    raise RuntimeError(message) from e
RuntimeError: Error building extension 'linear_layer': [1/3] c++ -MMD -MF linear_layer.o.d -DTORCH_EXTENSION_NAME=linear_layer -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/include -I/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/tools/util/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/TH -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/THC -isystem /home/miniconda3/envs/cuda11-8/include -isystem /home/miniconda3/envs/cuda11-8/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/reproai/reprai/extensions/linear/linear_layer.cpp -o linear_layer.o
[2/3] /home/miniconda3/envs/cuda11-8/bin/nvcc --generate-dependencies-with-compile --dependency-output linear_layer_kernel.cuda.o.d -DTORCH_EXTENSION_NAME=linear_layer -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/include -I/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/tools/util/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/TH -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/THC -isystem /home/miniconda3/envs/cuda11-8/include -isystem /home/miniconda3/envs/cuda11-8/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_89,code=sm_89 --compiler-options '-fPIC' -std=c++17 -c /home/reproai/reprai/extensions/linear/linear_layer_kernel.cu -o linear_layer_kernel.cuda.o
FAILED: linear_layer_kernel.cuda.o
/home/miniconda3/envs/cuda11-8/bin/nvcc --generate-dependencies-with-compile --dependency-output linear_layer_kernel.cuda.o.d -DTORCH_EXTENSION_NAME=linear_layer -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/include -I/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/tools/util/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include -isystem /home//miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /home//miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/TH -isystem /home//miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/THC -isystem /home//miniconda3/envs/cuda11-8/include -isystem /home//miniconda3/envs/cuda11-8/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_89,code=sm_89 --compiler-options '-fPIC' -std=c++17 -c /home//reproai/reprai/extensions/linear/linear_layer_kernel.cu -o linear_layer_kernel.cuda.o
/home//reproai/reprai/extensions/linear/linear_layer_kernel.cu(106): error: namespace "torch" has no member "I32"

1 error detected in the compilation of "/home//reproai/reprai/extensions/linear/linear_layer_kernel.cu".
ninja: build stopped: subcommand failed.
```

jackkosaian commented 1 month ago

Thanks for reporting this issue. I think it can be resolved by fixing the mapping of CUTLASS Python dtype to PyTorch C++ dtype here so that "I8" and "I32" are "kI8" and "kI32", respectively.
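As a rough illustration of the kind of mapping fix described here, the sketch below maps element types to the dtype spelling used in emitted C++. The `DataType` enum is a local stand-in so the snippet runs on its own, and the table name `TORCH_CPP_DTYPE` is hypothetical; the actual table in the CUTLASS Python package may be named and structured differently. `torch::kI8` and `torch::kI32` are the PyTorch C++ aliases for `torch::kInt8` and `torch::kInt32`.

```python
from enum import Enum


class DataType(Enum):
    """Local stand-in for the CUTLASS Python element-type enum (assumption)."""
    f16 = "f16"
    f32 = "f32"
    f64 = "f64"
    s8 = "s8"
    s32 = "s32"


# Hypothetical table: element type -> dtype constant as spelled in emitted C++.
TORCH_CPP_DTYPE = {
    DataType.f16: "torch::kF16",
    DataType.f32: "torch::kF32",
    DataType.f64: "torch::kF64",
    DataType.s8: "torch::kI8",    # "torch::I8" is not a member of namespace torch
    DataType.s32: "torch::kI32",  # "torch::I32" is what the compiler rejected above
}


def torch_cpp_dtype(element: DataType) -> str:
    """Return the C++ dtype constant to emit for a given element type."""
    return TORCH_CPP_DTYPE[element]


print(torch_cpp_dtype(DataType.s32))
```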

We'll push out a fix for this.
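Until such a fix is released, one possible stopgap is to patch the emitted source text before compiling it. This is an assumption, not something proposed in this thread: `patch_torch_dtypes` is a hypothetical helper that rewrites only the two invalid spellings named above, the sample line is illustrative rather than copied from the emitted file, and obtaining and rebuilding the emitted `.cu` file is left out.

```python
import re

# Matches the invalid spellings torch::I8 and torch::I32,
# but leaves already-correct names such as torch::kI32 alone.
_BAD_DTYPE = re.compile(r"\btorch::(I8|I32)\b")


def patch_torch_dtypes(source: str) -> str:
    """Rewrite torch::I8 / torch::I32 to torch::kI8 / torch::kI32."""
    return _BAD_DTYPE.sub(r"torch::k\1", source)


# Illustrative line only; the real emitted code may look different.
emitted = "auto opts = torch::TensorOptions().dtype(torch::I32);"
print(patch_torch_dtypes(emitted))
```

Applying the helper twice gives the same result as applying it once, so it is safe to run on a file that has already been patched.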

sycz00 commented 1 month ago

Thanks for the fast reply. I really appreciate this :)

sycz00 commented 1 month ago

Hey @jackkosaian, when I try to export a convolution kernel with similar settings using the following code:

```python
import torch
import cutlass

print_module = True

N, H, W, C = [32, 28, 28, 64]

K, R, S = [128, 3, 3]

stride = (2, 2)
padding = (1, 1)
dilation = (1, 1)

N, P, Q, K = cutlass.Conv2d.output_size((N, H, W, C), (K, R, S, C), padding, stride, dilation)

dtype = torch.int8
type_A = torch.int8
type_B = torch.int8
type_C = torch.int8
type_D = torch.int32

torch.manual_seed(1234)

input = torch.randint(-3, 3, size=(N, C, H, W), dtype=type_A, device="cuda").to(memory_format=torch.channels_last)
weight = torch.randint(-3, 3, size=(K, C, R, S), dtype=type_B, device="cuda").to(memory_format=torch.channels_last)
tensor_C = torch.randint(-3, 3, size=(N, K, P, Q), dtype=type_C, device="cuda").to(memory_format=torch.channels_last)
output = torch.zeros_like(tensor_C).to(type_D)

alpha = 1
beta = 0

plan = cutlass.Conv2dFprop(element=dtype, element_input=type_A, element_weight=type_B, element_C=type_C ,element_output=type_D, element_accumulator=type_D)

op = plan.construct()
conv_layer = cutlass.emit.pytorch(op, name='conv_layer', cc=plan.cc, sourcedir='conv', jit=True)
```
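For reference, the output extents `P` and `Q` in the script follow from the usual convolution size formula. The sketch below assumes that `cutlass.Conv2d.output_size` applies this standard formula, which the thread does not confirm; `conv_out_dim` is a hypothetical helper name.

```python
def conv_out_dim(size: int, kernel: int, pad: int, stride: int, dilation: int) -> int:
    """Standard convolution output extent along one spatial axis."""
    return (size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1


H = W = 28   # input height and width from the script
R = S = 3    # filter height and width from the script

P = conv_out_dim(H, R, pad=1, stride=2, dilation=1)
Q = conv_out_dim(W, S, pad=1, stride=2, dilation=1)
print(P, Q)  # 14 14
```

Under that assumption, `tensor_C` and `output` in the script have logical shape `(32, 128, 14, 14)` in `(N, K, P, Q)` order.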

I get the traceback below. Do you have any idea what is going wrong here?

Traceback (most recent call last):
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 2107, in _run_ninja_build
    subprocess.run(
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/subprocess.py", line 516, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "conv_float.py", line 46, in <module>
    conv_layer = cutlass.emit.pytorch(op, name='conv_layer', cc=plan.cc, sourcedir='conv', jit=True)
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass/emit/pytorch.py", line 931, in pytorch
    return _pytorch_conv2d(device_op, name, cc, jit, sourcedir)
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass/emit/pytorch.py", line 899, in _pytorch_conv2d
    return _jit(name, cc, cpp_file, cuda_file)
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass/emit/pytorch.py", line 697, in _jit
    jitmodule = load(
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1309, in load
    return _jit_compile(
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1719, in _jit_compile
    _write_ninja_file_and_build_library(
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 1832, in _write_ninja_file_and_build_library
    _run_ninja_build(
  File "/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/utils/cpp_extension.py", line 2123, in _run_ninja_build
    raise RuntimeError(message) from e
RuntimeError: Error building extension 'conv_layer': [1/3] c++ -MMD -MF conv_layer.o.d -DTORCH_EXTENSION_NAME=conv_layer -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/include -I/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/tools/util/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/TH -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/THC -isystem /home/miniconda3/envs/cuda11-8/include -isystem /home/miniconda3/envs/cuda11-8/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -c /home/reproai/reprai/extensions/conv/conv_layer.cpp -o conv_layer.o 
[2/3] /home/miniconda3/envs/cuda11-8/bin/nvcc --generate-dependencies-with-compile --dependency-output conv_layer_kernel.cuda.o.d -DTORCH_EXTENSION_NAME=conv_layer -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/include -I/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/tools/util/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/TH -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/THC -isystem /home/miniconda3/envs/cuda11-8/include -isystem /home/miniconda3/envs/cuda11-8/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_89,code=sm_89 --compiler-options '-fPIC' -std=c++17 -c /home/reproai/reprai/extensions/conv/conv_layer_kernel.cu -o conv_layer_kernel.cuda.o 
FAILED: conv_layer_kernel.cuda.o 
/home/miniconda3/envs/cuda11-8/bin/nvcc --generate-dependencies-with-compile --dependency-output conv_layer_kernel.cuda.o.d -DTORCH_EXTENSION_NAME=conv_layer -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/include -I/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/tools/util/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/TH -isystem /home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/torch/include/THC -isystem /home/miniconda3/envs/cuda11-8/include -isystem /home/miniconda3/envs/cuda11-8/include/python3.8 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_89,code=sm_89 --compiler-options '-fPIC' -std=c++17 -c /home/reproai/reprai/extensions/conv/conv_layer_kernel.cu -o conv_layer_kernel.cuda.o 
/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h(348): error: incomplete type is not allowed
          detected during:
            instantiation of "void cutlass::epilogue::threadblock::PredicatedTileIterator<ThreadMap_, Element_, ScatterD, PermuteDLayout, UseCUDAStore>::load_with_byte_offset(cutlass::epilogue::threadblock::PredicatedTileIterator<ThreadMap_, Element_, ScatterD, PermuteDLayout, UseCUDAStore>::Fragment &, int64_t) const [with ThreadMap_=cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, 256, 16, 32>, Element_=int32_t, ScatterD=false, PermuteDLayout=cutlass::layout::NoPermute, UseCUDAStore=false]" 
(381): here
            instantiation of "void cutlass::epilogue::threadblock::PredicatedTileIterator<ThreadMap_, Element_, ScatterD, PermuteDLayout, UseCUDAStore>::load(cutlass::epilogue::threadblock::PredicatedTileIterator<ThreadMap_, Element_, ScatterD, PermuteDLayout, UseCUDAStore>::Fragment &) const [with ThreadMap_=cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, 256, 16, 32>, Element_=int32_t, ScatterD=false, PermuteDLayout=cutlass::layout::NoPermute, UseCUDAStore=false]" 
/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/include/cutlass/epilogue/threadblock/epilogue.h(276): here
            instantiation of "void cutlass::epilogue::threadblock::Epilogue<Shape_, WarpMmaOperator_, PartitionsK, OutputTileIterator_, AccumulatorFragmentIterator_, WarpTileIterator_, SharedLoadIterator_, OutputOp_, Padding_, FragmentsPerPartition, IterationsUnroll>::SourceAspectNeeded::load() [with Shape_=cutlass::gemm::GemmShape<256, 128, 64>, WarpMmaOperator_=cutlass::gemm::warp::MmaTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, int8_t, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<8, 64>, int8_t, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<8, 64>, int32_t, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 8, 32>, 32, int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, int, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>, cutlass::MatrixShape<1, 1>>, 1, false, __nv_bool>, PartitionsK=1, OutputTileIterator_=cutlass::epilogue::threadblock::PredicatedTileIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, 256, 16, 32>, int32_t, false, cutlass::layout::NoPermute, false>, AccumulatorFragmentIterator_=cutlass::epilogue::warp::FragmentIteratorTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, int, cutlass::Array<int, 4, true>, cutlass::layout::RowMajor>, WarpTileIterator_=cutlass::epilogue::warp::TileIteratorTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, int32_t, cutlass::layout::RowMajor>, SharedLoadIterator_=cutlass::epilogue::threadblock::SharedLoadIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, 256, 16, 32>::CompactedThreadMap, int32_t, 64>, 
OutputOp_=cutlass::epilogue::thread::LinearCombination<int32_t, 16, int32_t, int32_t, cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, int32_t>, Padding_=cutlass::MatrixShape<0, 8>, FragmentsPerPartition=1, IterationsUnroll=1]" 
/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/include/cutlass/epilogue/threadblock/epilogue.h(488): here
            instantiation of "void cutlass::epilogue::threadblock::Epilogue<Shape_, WarpMmaOperator_, PartitionsK, OutputTileIterator_, AccumulatorFragmentIterator_, WarpTileIterator_, SharedLoadIterator_, OutputOp_, Padding_, FragmentsPerPartition, IterationsUnroll>::operator()(const cutlass::epilogue::threadblock::Epilogue<Shape_, WarpMmaOperator_, PartitionsK, OutputTileIterator_, AccumulatorFragmentIterator_, WarpTileIterator_, SharedLoadIterator_, OutputOp_, Padding_, FragmentsPerPartition, IterationsUnroll>::OutputOp &, cutlass::epilogue::threadblock::Epilogue<Shape_, WarpMmaOperator_, PartitionsK, OutputTileIterator_, AccumulatorFragmentIterator_, WarpTileIterator_, SharedLoadIterator_, OutputOp_, Padding_, FragmentsPerPartition, IterationsUnroll>::OutputTileIterator, const cutlass::epilogue::threadblock::Epilogue<Shape_, WarpMmaOperator_, PartitionsK, OutputTileIterator_, AccumulatorFragmentIterator_, WarpTileIterator_, SharedLoadIterator_, OutputOp_, Padding_, FragmentsPerPartition, IterationsUnroll>::AccumulatorTile &, SourceAspect) [with Shape_=cutlass::gemm::GemmShape<256, 128, 64>, WarpMmaOperator_=cutlass::gemm::warp::MmaTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, int8_t, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<8, 64>, int8_t, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<8, 64>, int32_t, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 8, 32>, 32, int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, int, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>, cutlass::MatrixShape<1, 1>>, 1, false, __nv_bool>, PartitionsK=1, OutputTileIterator_=cutlass::epilogue::threadblock::PredicatedTileIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, 256, 16, 32>, int32_t, false, 
cutlass::layout::NoPermute, false>, AccumulatorFragmentIterator_=cutlass::epilogue::warp::FragmentIteratorTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, int, cutlass::Array<int, 4, true>, cutlass::layout::RowMajor>, WarpTileIterator_=cutlass::epilogue::warp::TileIteratorTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, int32_t, cutlass::layout::RowMajor>, SharedLoadIterator_=cutlass::epilogue::threadblock::SharedLoadIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, 256, 16, 32>::CompactedThreadMap, int32_t, 64>, OutputOp_=cutlass::epilogue::thread::LinearCombination<int32_t, 16, int32_t, int32_t, cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, int32_t>, Padding_=cutlass::MatrixShape<0, 8>, FragmentsPerPartition=1, IterationsUnroll=1, SourceAspect=cutlass::epilogue::threadblock::Epilogue<cutlass::gemm::GemmShape<256, 128, 64>, cutlass::gemm::warp::MmaTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, int8_t, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<8, 64>, int8_t, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<8, 64>, int32_t, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 8, 32>, 32, int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, int, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>, cutlass::MatrixShape<1, 1>>, 1, false, __nv_bool>, 1, cutlass::epilogue::threadblock::PredicatedTileIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, 256, 16, 32>, int32_t, false, cutlass::layout::NoPermute, false>, 
cutlass::epilogue::warp::FragmentIteratorTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, int, cutlass::Array<int, 4, true>, cutlass::layout::RowMajor>, cutlass::epilogue::warp::TileIteratorTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, int32_t, cutlass::layout::RowMajor>, cutlass::epilogue::threadblock::SharedLoadIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, 256, 16, 32>::CompactedThreadMap, int32_t, 64>, cutlass::epilogue::thread::LinearCombination<int32_t, 16, int32_t, int32_t, cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, int32_t>, cutlass::MatrixShape<0, 8>, 1, 1>::SourceAspectNeeded]" 
/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/include/cutlass/epilogue/threadblock/epilogue.h(408): here
            instantiation of "void cutlass::epilogue::threadblock::Epilogue<Shape_, WarpMmaOperator_, PartitionsK, OutputTileIterator_, AccumulatorFragmentIterator_, WarpTileIterator_, SharedLoadIterator_, OutputOp_, Padding_, FragmentsPerPartition, IterationsUnroll>::operator()(const cutlass::epilogue::threadblock::Epilogue<Shape_, WarpMmaOperator_, PartitionsK, OutputTileIterator_, AccumulatorFragmentIterator_, WarpTileIterator_, SharedLoadIterator_, OutputOp_, Padding_, FragmentsPerPartition, IterationsUnroll>::OutputOp &, cutlass::epilogue::threadblock::Epilogue<Shape_, WarpMmaOperator_, PartitionsK, OutputTileIterator_, AccumulatorFragmentIterator_, WarpTileIterator_, SharedLoadIterator_, OutputOp_, Padding_, FragmentsPerPartition, IterationsUnroll>::OutputTileIterator, const cutlass::epilogue::threadblock::Epilogue<Shape_, WarpMmaOperator_, PartitionsK, OutputTileIterator_, AccumulatorFragmentIterator_, WarpTileIterator_, SharedLoadIterator_, OutputOp_, Padding_, FragmentsPerPartition, IterationsUnroll>::AccumulatorTile &, cutlass::epilogue::threadblock::Epilogue<Shape_, WarpMmaOperator_, PartitionsK, OutputTileIterator_, AccumulatorFragmentIterator_, WarpTileIterator_, SharedLoadIterator_, OutputOp_, Padding_, FragmentsPerPartition, IterationsUnroll>::OutputTileIterator) [with Shape_=cutlass::gemm::GemmShape<256, 128, 64>, WarpMmaOperator_=cutlass::gemm::warp::MmaTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, int8_t, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<8, 64>, int8_t, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<8, 64>, int32_t, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 8, 32>, 32, int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, int, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>, cutlass::MatrixShape<1, 1>>, 1, false, __nv_bool>, PartitionsK=1, 
OutputTileIterator_=cutlass::epilogue::threadblock::PredicatedTileIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, 256, 16, 32>, int32_t, false, cutlass::layout::NoPermute, false>, AccumulatorFragmentIterator_=cutlass::epilogue::warp::FragmentIteratorTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, int, cutlass::Array<int, 4, true>, cutlass::layout::RowMajor>, WarpTileIterator_=cutlass::epilogue::warp::TileIteratorTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, int32_t, cutlass::layout::RowMajor>, SharedLoadIterator_=cutlass::epilogue::threadblock::SharedLoadIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, 256, 16, 32>::CompactedThreadMap, int32_t, 64>, OutputOp_=cutlass::epilogue::thread::LinearCombination<int32_t, 16, int32_t, int32_t, cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, int32_t>, Padding_=cutlass::MatrixShape<0, 8>, FragmentsPerPartition=1, IterationsUnroll=1]" 
/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/include/cutlass/conv/kernel/implicit_gemm_convolution.h(425): here
            instantiation of "void cutlass::conv::kernel::ImplicitGemmConvolution<Mma_, Epilogue_, ThreadblockSwizzle_, ConvOperator, ConvProblemSize_, GroupMode_>::operator()(const cutlass::conv::kernel::ImplicitGemmConvolution<Mma_, Epilogue_, ThreadblockSwizzle_, ConvOperator, ConvProblemSize_, GroupMode_>::Params &, cutlass::conv::kernel::ImplicitGemmConvolution<Mma_, Epilogue_, ThreadblockSwizzle_, ConvOperator, ConvProblemSize_, GroupMode_>::SharedStorage &) [with Mma_=cutlass::conv::threadblock::ImplicitGemmMultistage<cutlass::gemm::GemmShape<256, 128, 64>, cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic<cutlass::MatrixShape<256, 64>, int8_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<64, 256>, 256, cutlass::PitchLinearShape<4, 8>, 16>, cutlass::AlignedArray<int8_t, 16, 16>, cutlass::conv::GroupMode::kNone>, cutlass::transform::threadblock::RegularTileAccessIterator<cutlass::MatrixShape<256, 64>, int8_t, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<8, 64>, 0, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<64, 256>, 256, cutlass::PitchLinearShape<4, 8>, 16>, 16>, cutlass::arch::CacheOperation::Always, cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic<cutlass::MatrixShape<64, 128>, int8_t, cutlass::layout::TensorNHWC, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<64, 128>, 256, cutlass::PitchLinearShape<4, 8>, 16>, cutlass::AlignedArray<int8_t, 16, 16>, cutlass::conv::GroupMode::kNone, false>, cutlass::transform::threadblock::RegularTileAccessIterator<cutlass::MatrixShape<64, 128>, int8_t, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<8, 64>, 1, cutlass::transform::PitchLinearWarpRakedThreadMap<cutlass::PitchLinearShape<64, 128>, 256, cutlass::PitchLinearShape<4, 8>, 16>, 16>, cutlass::arch::CacheOperation::Global, 
cutlass::gemm::threadblock::MmaPolicy<cutlass::gemm::warp::MmaTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, int8_t, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<8, 64>, int8_t, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<8, 64>, int32_t, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 8, 32>, 32, int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, int, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>, cutlass::MatrixShape<1, 1>>, 1, false, __nv_bool>, cutlass::MatrixShape<0, 0>, cutlass::MatrixShape<0, 0>, 1>, 3, __nv_bool>, Epilogue_=cutlass::epilogue::threadblock::Epilogue<cutlass::gemm::GemmShape<256, 128, 64>, cutlass::gemm::warp::MmaTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, int8_t, cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<8, 64>, int8_t, cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<8, 64>, int32_t, cutlass::layout::RowMajor, cutlass::gemm::warp::MmaTensorOpPolicy<cutlass::arch::Mma<cutlass::gemm::GemmShape<16, 8, 32>, 32, int8_t, cutlass::layout::RowMajor, int8_t, cutlass::layout::ColumnMajor, int, cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddSaturate>, cutlass::MatrixShape<1, 1>>, 1, false, __nv_bool>, 1, cutlass::epilogue::threadblock::PredicatedTileIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, 256, 16, 32>, int32_t, false, cutlass::layout::NoPermute, false>, cutlass::epilogue::warp::FragmentIteratorTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, int, cutlass::Array<int, 4, true>, cutlass::layout::RowMajor>, cutlass::epilogue::warp::TileIteratorTensorOp<cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, int32_t, cutlass::layout::RowMajor>, 
cutlass::epilogue::threadblock::SharedLoadIterator<cutlass::epilogue::threadblock::OutputTileOptimalThreadMap<cutlass::epilogue::threadblock::OutputTileShape<128, 8, 4, 1, 1>, cutlass::epilogue::threadblock::OutputTileShape<1, 8, 1, 1, 8>, 256, 16, 32>::CompactedThreadMap, int32_t, 64>, cutlass::epilogue::thread::LinearCombination<int32_t, 16, int32_t, int32_t, cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, int32_t>, cutlass::MatrixShape<0, 8>, 1, 1>, ThreadblockSwizzle_=cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, ConvOperator=cutlass::conv::Operator::kFprop, ConvProblemSize_=cutlass::conv::Conv2dProblemSize, GroupMode_=cutlass::conv::GroupMode::kNone]" 
/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/include/cutlass/device_kernel.h(73): here
            instantiation of "void cutlass::Kernel<Operator>(Operator::Params) [with Operator=Conv2dFpropKernel]" 
/home/miniconda3/envs/cuda11-8/lib/python3.8/site-packages/cutlass_library/source/include/cutlass/conv/device/implicit_gemm_convolution.h(268): here
            instantiation of "cutlass::Status cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_>::initialize(const cutlass::conv::device::ImplicitGemmConvolution<ImplicitGemmKernel_>::Arguments &, void *, cudaStream_t, cutlass::CudaHostAdapter *) [with ImplicitGemmKernel_=Conv2dFpropKernel]" 
/home/reproai/reprai/extensions/conv/conv_layer_kernel.cu(126): here

1 error detected in the compilation of "/home/reproai/reprai/extensions/conv/conv_layer_kernel.cu".
ninja: build stopped: subcommand failed.
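One detail in the trace may be worth checking, though nothing in this thread confirms it as the cause. The epilogue is instantiated as `LinearCombination<int32_t, 16, ...>` and the output thread map as `OutputTileOptimalThreadMap<..., 256, 16, 32>`, which reads as 16 elements of 32 bits per access, while the int8 operands use `AlignedArray<int8_t, 16, 16>`. The arithmetic below compares those access widths with a 16-byte (128-bit) vector load, taken here as an assumed upper bound for a single global load on sm_89. The helper name is hypothetical.

```python
def access_bytes(elements_per_access: int, element_bits: int) -> int:
    """Size in bytes of one vectorized access."""
    return elements_per_access * element_bits // 8


MAX_VECTOR_LOAD_BYTES = 16  # 128-bit vector load; assumed limit for sm_89

int8_operand = access_bytes(16, 8)     # AlignedArray<int8_t, 16, 16>
int32_epilogue = access_bytes(16, 32)  # LinearCombination<int32_t, 16, ...>

print(int8_operand, int8_operand <= MAX_VECTOR_LOAD_BYTES)      # 16 True
print(int32_epilogue, int32_epilogue <= MAX_VECTOR_LOAD_BYTES)  # 64 False
```

If that reading is right, an alignment of 4 elements for the int32 C/D tensors would bring the epilogue access back to 16 bytes. Whether and how the Python interface exposes that setting for this kernel is not established in this thread.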

Additionally, I can't find any configuration for C=3, i.e., RGB input images.
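A common workaround for small channel counts is to zero-pad the channel dimension of both the input and the filter up to a multiple of the kernel's alignment, since zero channels contribute nothing to the convolution sum. It is offered here only as an assumption, because the thread does not answer the question. The helper below only computes the padded channel count; its name and the alignment values tried are illustrative.

```python
def padded_channels(channels: int, alignment: int) -> int:
    """Round `channels` up to the next multiple of `alignment`."""
    if channels <= 0 or alignment <= 0:
        raise ValueError("channels and alignment must be positive")
    return -(-channels // alignment) * alignment  # ceiling division, then scale


for alignment in (4, 8, 16):
    c = padded_channels(3, alignment)
    print(f"alignment={alignment}: pad C=3 to C={c} ({c - 3} zero channels)")
```

The cost is extra memory traffic and arithmetic on the zero channels, which grows with the alignment chosen.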

github-actions[bot] commented 2 weeks ago

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.