NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[QST] How to use cutlass in tensorrt_llm plugin? #1220

Closed yuanjiechen closed 1 month ago

yuanjiechen commented 9 months ago

What is your question? Hello, thanks for your project. CUTLASS version: 2.10; device: RTX 3090. I want to implement W4A4 conv quantization in tensorrt_llm using cutlass. Following the examples and documentation, I wrote the plugin and cutlass code, but it produces a huge number of internal compile errors, while the standalone cutlass build test passes. I have attached the cutlass-related code; can you give some advice for debugging? I use TensorRef in the plugin because the pointer provided by TensorRT is a device-side tensor; I want to construct the object and pass it into Arguments. Thanks. quantConvKernels.cu.txt quantConvKernels.h.txt err.txt
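
A hedged sketch of the pattern being described here (wrapping a raw device pointer handed to the plugin's enqueue() in a cutlass::TensorRef before building the conv Arguments); the element type, layout, and function name are illustrative assumptions, not taken from the attached files:

```cpp
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/layout/tensor.h"

// Assumed types for an int4 NHWC activation tensor.
using ElementA = cutlass::int4b_t;
using LayoutA  = cutlass::layout::TensorNHWC;

// Wrap a device pointer provided by TensorRT in a TensorRef so it can be
// passed into the implicit-GEMM conv Arguments struct.
cutlass::TensorRef<ElementA, LayoutA> make_activation_ref(void* device_ptr,
                                                          cutlass::Tensor4DCoord extent) {
  // TensorNHWC::packed() computes the strides of a densely packed NHWC tensor.
  LayoutA layout = LayoutA::packed(extent);
  return {static_cast<ElementA*>(device_ptr), layout};
}
```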

hwu36 commented 9 months ago

first, on sm80 you can follow this to choose tile shapes: https://github.com/NVIDIA/cutlass/blob/main/test/unit/conv/device/conv2d_fprop_implicit_gemm_s4nhwc_s4nhwc_s32nhwc_tensor_op_s32_sm80.cu#L105-L107
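
For reference, those lines define three tile-shape aliases along the lines of the sketch below; the values shown are a plausible combination for int4 Tensor Core conv on SM80 and should be verified against the linked file:

```cpp
#include "cutlass/gemm/gemm.h"

// Threadblock / warp / instruction tile shapes (M, N, K) for s4 inputs with
// s32 accumulation on SM80 Tensor Cores.
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>;
using WarpShape        = cutlass::gemm::GemmShape<64, 64, 128>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
```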

as to your error, do you use nvcc to compile?

yuanjiechen commented 9 months ago

Thanks for your reply! I use the compile script provided by tensorrt_llm; it depends on cutlass 2.10.

hwu36 commented 9 months ago

If it uses make, you can run `make VERBOSE=1` to see the actual command line. Your error looks like your compiler does not recognize `<<< >>>`, which is used for device kernel launches.

rawnhenry commented 9 months ago

Hi @yuanjiechen - It sounds like you are using the release branch of TRT-LLM. We have already updated TRT-LLM to depend on CUTLASS 3.3 in the main branch here.

I would suggest using the main branch for your development.

yuanjiechen commented 9 months ago

Hi @hwu36, thanks for your reply. Here is the output with VERBOSE=1: err.txt. I also tried the easy Python API to run a cutlass kernel, and copied a small tool from trt-llm to compress int4 data into an int8 tensor, but it looks like cutlass Python does not support int4? Or is there some way to trick the compiler into treating it as a true int4 dtype rather than int8?

```python
type_A = cutlass.DataType.s4
type_B = cutlass.DataType.s4
type_C = cutlass.DataType.s32
type_D = cutlass.DataType.s32

torch_dtype_A = torch.int8
torch_dtype_C = torch.int32

input = torch.ceil(
    torch.empty(size=(N, C, H, W), dtype=torch_dtype_A, device="cuda")).to(memory_format=torch.channels_last)
weight = torch.ceil(
    torch.empty(size=(K, C, R, S), dtype=torch_dtype_A, device="cuda")).to(memory_format=torch.channels_last)
tensor_C = torch.ceil(
    torch.empty(size=(N, K, P, Q), dtype=torch_dtype_C, device="cuda")).to(memory_format=torch.channels_last)

output = torch.zeros_like(tensor_C)

plan = cutlass.Conv2d(kind="fprop", element_A=cutlass.DataType.s4, element_B=cutlass.DataType.s4,
                      element_C=cutlass.DataType.s32, element_D=cutlass.DataType.s32,
                      element_accumulator=cutlass.DataType.s32)

plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha=1.0, beta=0.0, print_module=print_module)
```

Reported:

```
Traceback (most recent call last):
  File "/home/chenyj/cutlass_conv.py", line 46, in <module>
    plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha=1.0, beta=0.0, print_module=print_module)
  File "/home/chenyj/anaconda3/lib/python3.9/site-packages/cutlass/op/conv.py", line 820, in run
    self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
  File "/home/chenyj/anaconda3/lib/python3.9/site-packages/cutlass/op/conv.py", line 653, in compile
    self.operation = self.construct(
  File "/home/chenyj/anaconda3/lib/python3.9/site-packages/cutlass/op/conv.py", line 572, in construct
    op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C)[0]
  File "/home/chenyj/anaconda3/lib/python3.9/site-packages/cutlass/library_defaults.py", line 128, in operations
    raise Exception(
Exception: No operations of alignment 32 32 4 found for data type and layout combination (<DataType.s4: 8>, <DataType.s4: 8>, <DataType.s32: 11>) (<LayoutType.TensorNHWC: 9>, <LayoutType.TensorNHWC: 9>). Tried to fall back to alignment 4 4 4, but that was also not compatible. Compatible alignments are dict_keys(['32 32 8', '32 32 16'])
```

Hi @rawnhenry, Cool, I'll try it!

hwu36 commented 9 months ago

cd /root/TensorRT-LLM-release-0.5.0/cpp/build/tensorrt_llm/plugins && /usr/bin/c++ -DENABLE_BF16 -DENABLE_FP8 -DNVTX_DISABLE -DOMPI_SKIP_MPICXX -DTORCH_CUDA=1 -Dnvinfer_plugin_tensorrt_llm_EXPORTS -I/include -I/root/TensorRT-LLM-release-0.5.0/3rdparty/cutlass/include -I/root/TensorRT-LLM-release-0.5.0/3rdparty/NVTX/include -I/root/TensorRT-LLM-release-0.5.0/3rdparty/json/include -I/root/TensorRT-LLM-release-0.5.0/cpp -I/usr/local/cuda/include -I/usr/include/python3.10 -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/cutlass_extensions/include -I/root/TensorRT-LLM-release-0.5.0/cpp/include -I/opt/hpcx/ompi/include -I/opt/hpcx/ompi/include/openmpi -I/opt/hpcx/ompi/include/openmpi/opal/mca/hwloc/hwloc201/hwloc/include -I/opt/hpcx/ompi/include/openmpi/opal/mca/event/libevent2022/libevent -I/opt/hpcx/ompi/include/openmpi/opal/mca/event/libevent2022/libevent/include -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/bertAttentionPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/gptAttentionCommon -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/gptAttentionPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/identityPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/layernormPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/rmsnormPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/gemmPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/smoothQuantGemmPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/quantizePerTokenPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/quantizeTensorPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/layernormQuantizationPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/rmsnormQuantizationPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/weightOnlyGroupwiseQuantMatmulPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/weightOnlyQuantMatmulPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/lookupPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/quantConvPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/ncclPlugin -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/common -I/root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins -isystem /usr/local/lib/python3.10/dist-packages/torch/include -isystem /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -isystem /usr/local/tensorrt/include -Wno-deprecated-declarations  -DBUILD_SYSTEM=cmake_oss -DENABLE_MULTI_DEVICE=1 -O3 -DNDEBUG -std=c++17 -fPIC -D_GLIBCXX_USE_CXX11_ABI=1 -MD -MT tensorrt_llm/plugins/CMakeFiles/nvinfer_plugin_tensorrt_llm.dir/api/tllmPlugin.cpp.o -MF CMakeFiles/nvinfer_plugin_tensorrt_llm.dir/api/tllmPlugin.cpp.o.d -o CMakeFiles/nvinfer_plugin_tensorrt_llm.dir/api/tllmPlugin.cpp.o -c /root/TensorRT-LLM-release-0.5.0/cpp/tensorrt_llm/plugins/api/tllmPlugin.cpp

I am not sure whether you are compiling with multiple threads. Is the command line above the one that causes the error? It uses /usr/bin/c++, not nvcc. Also, your CUDA file needs to use the .cu extension instead of .cpp.
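
For context, the triple-chevron launch syntax is a CUDA language extension that only a CUDA-aware compiler such as nvcc understands, which is why a file containing it must be a .cu file compiled by nvcc. A minimal, hypothetical illustration (not code from the attached plugin):

```cpp
// minimal_launch.cu -- compiles with nvcc; a plain host compiler such as
// /usr/bin/c++ reports a syntax error at the <<< >>> below.
__global__ void scale_kernel(float* data, float factor, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        data[idx] *= factor;
    }
}

void launch_scale(float* device_data, float factor, int n, cudaStream_t stream) {
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    scale_kernel<<<blocks, threads, 0, stream>>>(device_data, factor, n);
}
```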

yuanjiechen commented 9 months ago

I updated my kernel file that contains the cutlass code from .cpp to .cu, but the problem still exists. I also found the root cause of the make error. The function call sequence is: quantConvPlugin (enqueue) -> quantConvKernels (invokeGeneralQuantConv) -> cutlass kernels. If I comment out the include of quantConvKernels at quantConvPlugin.h line 21, the make passes. But comparing the make logs before and after that change, quantConvKernels.cu/.h had already been compiled in both cases. I also changed the kernel namespace to cutlass::conv::kernels, similar to cutlass, but it still does not work. quantConvplugin.zip

Another question: I successfully ran cutlass Python with A(int4), B(int4), C(int8), D(int8), accumulator(int32). How do I run the same dtypes as the C++ example, A(int4), B(int4), C(int32), D(int32), accumulator(int32)? That configuration reports an error in cutlass Python:

```
Exception: No operations of alignment 32 32 4 found for data type and layout combination (<DataType.s4: 8>, <DataType.s4: 8>, <DataType.s32: 11>) (<LayoutType.TensorNHWC: 9>, <LayoutType.TensorNHWC: 9>). Tried to fall back to alignment 4 4 4, but that was also not compatible. Compatible alignments are dict_keys(['32 32 8', '32 32 16'])
```

jackkosaian commented 9 months ago

@yuanjiechen , what values of N H W C K R S P Q are you using in your Python example?

yuanjiechen commented 9 months ago

Hi @jackkosaian, N = 128, H = 32, W = 32, C = 32, K = 128, R = 3, S = 3, P = 32, Q = 32.

Update: I solved the problem with A(int4), B(int4), C(int32), D(int32). First, because torch and numpy do not support int4, I needed to compress the int4 data into int8 and have the pointer interpreted as int4. Second, I forced the A and B types to cutlass.DataType.s4 and set alignment_C to 8 (though I don't know what alignment means; the usual value of alignment_C, 4, reports an error). I compared the conv result with the torch output and they are the same. So, are there any future plans to make int4 easier to support in Python? Thanks.
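
For readers, the compression step described above (two int4 values per int8 byte) looks roughly like the hedged sketch below. The actual tool copied from trt-llm operates on torch tensors; this C++ version only illustrates the packing scheme, and the low-nibble-first ordering is an assumption to check against that tool:

```cpp
#include <cstdint>
#include <cstddef>
#include <vector>

// Pack pairs of int4 values (stored one per byte, range [-8, 7]) into a single
// byte each, with element 0 in the low nibble. The packed int8 buffer can then
// be reinterpreted as int4 data by the kernel.
std::vector<std::int8_t> pack_int4_pairs(const std::vector<std::int8_t>& vals) {
  std::vector<std::int8_t> packed(vals.size() / 2);
  for (std::size_t i = 0; i < packed.size(); ++i) {
    std::uint8_t lo = static_cast<std::uint8_t>(vals[2 * i])     & 0x0F;
    std::uint8_t hi = static_cast<std::uint8_t>(vals[2 * i + 1]) & 0x0F;
    packed[i] = static_cast<std::int8_t>(lo | (hi << 4));
  }
  return packed;
}
```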

jackkosaian commented 9 months ago

Thanks for the details.

Can you also tell me the stride, dilation, and padding values you used?

I'll look into the alignment issue that you mentioned.

Regarding easier support for int4: the CUTLASS Python interface doesn't currently provide data type definitions -- input tensors are expected to be numpy/torch/etc. tensors with an element type defined by those libraries. Right now, if numpy/torch/etc. does not provide an implementation of a given data type, a technique similar to what you did must be used. We may reconsider this in the future if need be.

yuanjiechen commented 9 months ago

Hi @jackkosaian, stride, dilation, and padding are all (1, 1).

Update: is the minimum channel count 32? If my input has only 3 channels, is there a way to run int4 conv2d other than padding the input to 32 channels?

github-actions[bot] commented 8 months ago

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

hwu36 commented 6 months ago

Channel count 3 is bad for int4 since it is only 12 bits; I would imagine it needs to be at least 32 bits. cutlass has some special kernels for small channel sizes: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/conv/convolution.h#L115. Grep for this word in the conv/device unit tests for examples, and pay attention to the alignment settings.
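
For reference, the enumeration that link points into is cutlass::conv::IteratorAlgorithm; a hedged sketch of the relevant values is below. Which value (and which alignment) is actually instantiated for int4 inputs with C = 3 is an assumption to confirm against the conv/device unit tests mentioned above:

```cpp
#include "cutlass/conv/convolution.h"

// kFixedChannels and kFewChannels are the iterator algorithms aimed at small
// input channel counts; they are selected through the iterator-algorithm
// template parameter of the conv kernel builders (e.g. DefaultConv2dFprop).
constexpr auto kSmallChannelAlgorithm =
    cutlass::conv::IteratorAlgorithm::kFixedChannels;
```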

github-actions[bot] commented 2 months ago

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.