NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

[QST] Question about ThreadblockShape and WarpShape relations #404

Closed JustLuoyu closed 2 years ago

JustLuoyu commented 2 years ago

I got two static assertion failures with:

// This code section describes the tile size a thread block will compute
using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>;  // Threadblock tile shape

// This code section describes tile size a warp will compute
using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>;          // Warp tile shape

The failures are:

static_assert(Iterations::kCount, "Number of iterations must be non-zero");
static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1");

Is there any documentation about their relations and limitations?

hwu36 commented 2 years ago

Would you please show us more about your kernel configuration and your problem size? Restrictions are different for different types of kernels.

Common threadblock tile sizes are 256x128, 128x256, 128x128, 128x64, 64x128, 256x64, 64x256, and 64x64. Common warp tile sizes are 64x64, 64x32, 32x64, and 32x32. We try to launch 4 or 8 warps. The bigger the warp tile size, the better, but for some kernels large tile sizes use too many registers or too much shared memory, so we have to deal with them case by case.

As for the K dimension, it is usually 4x the instruction K before Ampere. On Ampere, it can be 2x or 4x the instruction K.
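
For illustration, a minimal sketch of one combination that follows these rules (assuming SM80 fp16 tensor cores and the 16x8x16 mma instruction; this is just an example, not a configuration from this thread):

#include "cutlass/gemm/gemm.h"   // cutlass::gemm::GemmShape

// 128x128x32 threadblock split into (128/64) x (128/64) = 4 warps of 64x64x32;
// threadblock K = 32 is 2x the instruction K of 16, per the Ampere guideline above.
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;  // Threadblock tile shape
using WarpShape        = cutlass::gemm::GemmShape<64, 64, 32>;    // Warp tile shape
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;     // TensorCore instruction shape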

JustLuoyu commented 2 years ago

I don't quite follow, because these are compile-time errors, so the problem size (which is unknown at this point) may not be the issue. My instruction tile is:

using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; // TensorCore instruction shape

I think both meet your recommendations: a 32x32 warp tile and 4 warps. I can bypass the assertion with this WarpShape:

using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; // Warp tile shape

But I don't know why this works.

JustLuoyu commented 2 years ago

My main purpose is to create a kernel for convolutions with large NHW and small C and K. So I'd like to enlarge the M dimension of the GEMM and reduce its N and K dimensions.

hwu36 commented 2 years ago

I can help you pick the tile size if you tell me your problem size. From what you said, I suggest a 256x64 tb size with a 64x64 warp size, or a 128x64 tb size with a 64x32 warp size.
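
For example, the second suggestion could look like this (a sketch assuming the SM75 16x8x8 fp16 instruction you are already using; threadblock K = 32 follows the 4x-instruction-K guideline for pre-Ampere):

#include "cutlass/gemm/gemm.h"   // cutlass::gemm::GemmShape

// 128x64 threadblock with 64x32 warps: (128/64) x (64/32) = 4 warps.
using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;  // Threadblock tile shape
using WarpShape        = cutlass::gemm::GemmShape<64, 32, 32>;   // Warp tile shape
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;     // TensorCore instruction shape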

You can also use our cutlass profiler to run all eligible kernels for your problem size and pick the best. The profiler documentation is at https://github.com/NVIDIA/cutlass/blob/master/media/docs/profiler.md

JustLuoyu commented 2 years ago

The problem size is: activation shape (NHWC) = 1x540x960x16, filter shape (KRSC) = 16x3x3x16, dilation = 1, stride = 1, pad = 1.
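
(For reference, under the implicit GEMM fprop mapping this corresponds roughly to GEMM M = N*P*Q = 1*540*960 = 518400, GEMM N = K = 16, and GEMM K = R*S*C = 3*3*16 = 144, since stride 1 and pad 1 keep the output at 540x960; that is why only the M tile wants to be large.)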

How do I run all eligible kernels, other than setting cta_m, cta_k, and cta_n manually?

JustLuoyu commented 2 years ago

With cutlass_profiler.exe --mode=enumerate I only see Conv2d kernels with a 256x128 tb size for f16. So I started creating CUTLASS kernels manually to profile this problem size, and that is when I ran into this question.

hwu36 commented 2 years ago

By default, we only enable the largest tile size for each type of kernel in the profiler, because compiling all supported kernels takes too much time.

For example, to enable all SM80 fprop kernels that use fp16 input, fp32 accumulation, and fp16 output, you can do

cmake .. -DCUTLASS_NVCC_ARCHS=80 -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*fprop_optimized_f16_*_align8

You can change the above command based on your arch, data type, etc.
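
After rebuilding with those kernels enabled, the profiler can be pointed at your problem size directly, e.g. (an illustrative invocation; see the profiler doc linked above for the full set of Conv2d flags):

cutlass_profiler --operation=Conv2d --n=1 --h=540 --w=960 --c=16 --k=16 --r=3 --s=3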

JustLuoyu commented 2 years ago

I get your point, but I'd like an explanation of the restrictions on tb size and warp size. I'm trying to implement a convolution-layer tuning process like TensorRT's, so I need to know exactly which sizes are valid.

JustLuoyu commented 2 years ago

Compilation passed, but the result is wrong:

// This code section describes the tile size a thread block will compute
using ThreadblockShape = cutlass::gemm::GemmShape<128, 16, 16>;  // Threadblock tile shape

// This code section describes tile size a warp will compute
using WarpShape = cutlass::gemm::GemmShape<128, 16, 16>;         // Warp tile shape

Compilation passed, and the result is right:

// This code section describes the tile size a thread block will compute
using ThreadblockShape = cutlass::gemm::GemmShape<128, 32, 16>;  // Threadblock tile shape

// This code section describes tile size a warp will compute
using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>;         // Warp tile shape

hwu36 commented 2 years ago

When we wrote these kernels, we only had the tile sizes I listed above in mind. However, our implementation is general enough that other tile sizes sometimes work for some kernels, and they are used by folks who are willing to take the extra effort. So we don't want to add static_asserts to rule them out.

If you want to be safe, you can just use the tile sizes used by the profiler.

If you point me to the lines of code of the asserts you mentioned earlier, I can explain a little bit more about them.

JustLuoyu commented 2 years ago

Good to hear you say so! These combinations seemed reasonable to me but resulted in static_assert failures.

Case 1: 128x16x16 tb size, 32x16x16 warp size, 8x8x4 inst size.

assert1: pitch_linear_thread_map.h(251): error : static assertion failed with "ShapeInAccesses must be divisible by WarpThreadArrangement."
assert2: pitch_linear_thread_map.h(289): error : static assertion failed with "Number of iterations must be non-zero"
assert3: conv2d_fprop_activation_tile_access_iterator_optimized.h(97): error : static assertion failed with "Require Iterations::kContiguous == 1"
assert4: conv2d_fprop_filter_tile_access_iterator_optimized.h(95): error : static assertion failed with "Require Iterations::kContiguous == 1"
assert5: mma_tensor_op_sm70.h(130): error : static assertion failed with "Shape must be a multiple of InterleavedTileShape."

Why did sm70 code get pulled into my compilation when I only selected sm75?

Case 2: 128x16x16 tb size, 32x16x16 warp size, 16x8x8 inst size.

assert1: pitch_linear_thread_map.h(289): error : static assertion failed with "Number of iterations must be non-zero"
assert2: conv2d_fprop_filter_tile_access_iterator_optimized.h(95): error : static assertion failed with "Require Iterations::kContiguous == 1"

warnings:
conv2d_fprop_filter_tile_access_iterator_optimized.h(200): warning #179-D: right operand of "%" is zero
regular_tile_access_iterator_tensor_op.h(546): warning #39-D: division by zero

hwu36 commented 2 years ago

128x16x16 is not one of the tile sizes I listed above. The main problem is that the B threadblock tile, which is 16x16, is too small for every thread to have some data to load.

There are 16x16 = 256 elements per threadblock tile. You have 128 threads because (128/32) x (16/16) = 4 warps. Then every thread has 256 / 128 = 2 elements to load. 2 fp16 elements are only 4B, and we require every thread to load at least 16B of data.
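
A back-of-the-envelope version of that check, as a sketch (the numbers mirror the reasoning above; this is not the exact CUTLASS thread-map code):

#include <cstdio>

// Rough per-thread load check for the B-operand threadblock tile (illustrative only).
int main() {
  constexpr int kTileN = 16, kTileK = 16;           // B tile of a 128x16x16 threadblock
  constexpr int kWarps = (128 / 32) * (16 / 16);    // 4 warps for a 32x16x16 warp tile
  constexpr int kThreads = kWarps * 32;             // 128 threads
  constexpr int kBytesPerThread =
      kTileN * kTileK * 2 / kThreads;               // 256 fp16 elements * 2B / 128 threads = 4B
  constexpr int kRequiredBytes = 16;                // 128-bit vectorized loads per thread
  std::printf("bytes per thread = %d, required = %d\n", kBytesPerThread, kRequiredBytes);
  return (kBytesPerThread >= kRequiredBytes) ? 0 : 1;
}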

JustLuoyu commented 2 years ago

we require every thread to load at least 16B data.

This is very useful information, thanks a lot!

foreverlms commented 1 week ago

Nice explanation.