NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

Using CUTLASS to benchmark plain CUDA performance #130

richardschulze closed this issue 4 years ago

richardschulze commented 4 years ago

Hi,

I am trying to find the best performance achievable for row-major SGEMM on 4 specific input sizes using only plain CUDA (no tensor cores, no intrinsics). I want to use this as a baseline for comparisons with my own CUDA implementation, and I think CUTLASS is currently the best-performing CUDA implementation of GEMM. I tried using the CUTLASS profiler, but it does not seem to support row-major output matrices. I therefore copied the 00_basic_gemm example and edited it to use row-major matrices. This works, but CUTLASS reaches quite low performance on a V100 for these input sizes:

M=16, N=1000, K= 2048: 167 GFLOP/s (1.1% of peak single-precision performance)
M= 1, N=1000, K= 2048:  10 GFLOP/s (0.1% of peak single-precision performance)
M=16, N=4096, K=25088: 706 GFLOP/s (4.5% of peak single-precision performance)
M= 1, N=4096, K=25088:  44 GFLOP/s (0.3% of peak single-precision performance)
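The percentages above come from simple arithmetic. Below is a minimal C++ sketch, assuming a V100 FP32 peak of about 15.7 TFLOP/s (the SXM2 figure; PCIe cards are lower). The FLOP and byte formulas are chosen to reproduce the Bytes and FLOPs fields the profiler prints later in this thread for m=16, n=1000, k=2048 (8387072 bytes, 65568000 flops). The function names are hypothetical helpers, not part of CUTLASS.

```cpp
#include <cassert>
#include <cstdint>

// FLOP count matching the profiler log in this thread: 2*M*N*K for the
// multiply-adds plus 2*M*N (presumably the epilogue). Inferred from the
// logged numbers, not taken from CUTLASS documentation.
inline std::uint64_t gemm_flops(std::uint64_t m, std::uint64_t n, std::uint64_t k) {
  return 2 * m * n * k + 2 * m * n;
}

// Bytes for FP32 with beta = 0: read A (M x K) and B (K x N), write C (M x N).
inline std::uint64_t gemm_bytes_f32(std::uint64_t m, std::uint64_t n, std::uint64_t k) {
  return (m * k + k * n + m * n) * sizeof(float);
}

// Throughput in GFLOP/s from a runtime given in milliseconds.
inline double gflops_per_s(std::uint64_t flops, double runtime_ms) {
  assert(runtime_ms > 0.0);
  return static_cast<double>(flops) / (runtime_ms * 1e-3) / 1e9;
}

// Percentage of an assumed peak (both arguments in GFLOP/s).
inline double percent_of_peak(double gflops, double peak_gflops) {
  assert(peak_gflops > 0.0);
  return 100.0 * gflops / peak_gflops;
}
```

For example, gflops_per_s(gemm_flops(16, 1000, 2048), 0.271053) gives roughly 241.9 GFLOP/s, which matches the Math field the profiler reports for the 128x64 kernel.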

I would appreciate your help in figuring out whether I used CUTLASS correctly and whether this is the highest performance CUTLASS can reach using plain CUDA.

Many thanks in advance!

hwu36 commented 4 years ago

Hi,

You can change this line https://github.com/NVIDIA/cutlass/blob/4dac7490e6a846893595de352c57f2bf118796c1/tools/library/scripts/generator.py#L115 to the following: (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),

Your problem has a small M. You can look at the tile sizes used in https://github.com/NVIDIA/cutlass/blob/4dac7490e6a846893595de352c57f2bf118796c1/test/unit/gemm/device/simt_sgemm_tt_sm50.cu and change the tile sizes in the profiler accordingly (https://github.com/NVIDIA/cutlass/blob/4dac7490e6a846893595de352c57f2bf118796c1/tools/library/scripts/generator.py#L138-L143). I think small tile sizes in the M dimension can help you.
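Why small M tiles help can be seen by counting threadblocks and padding. The following is a rough C++ sketch with hypothetical helpers (not CUTLASS code) that treats the CTA tile as a tile_m x tile_n block of the output and ignores split-K and swizzling. For M=16, N=1000, a 128x64 tile launches only 16 CTAs, and roughly 88% of each tile's area is padding; a hypothetical 16x32 tile launches 32 CTAs with almost no padding.

```cpp
#include <cassert>
#include <cstdint>

// Number of threadblocks (CTAs) for an M x N output when each CTA computes a
// tile_m x tile_n block of C (ceiling division in each dimension).
inline std::uint64_t cta_count(std::uint64_t m, std::uint64_t n,
                               std::uint64_t tile_m, std::uint64_t tile_n) {
  assert(tile_m > 0 && tile_n > 0);
  return ((m + tile_m - 1) / tile_m) * ((n + tile_n - 1) / tile_n);
}

// Fraction of the computed tile area that lies inside C; the remainder is
// padding in partially filled edge tiles.
inline double tile_efficiency(std::uint64_t m, std::uint64_t n,
                              std::uint64_t tile_m, std::uint64_t tile_n) {
  const double computed = static_cast<double>(cta_count(m, n, tile_m, tile_n)) *
                          static_cast<double>(tile_m * tile_n);
  return static_cast<double>(m * n) / computed;
}
```

A V100 has 80 SMs, so grids of 8 to 32 CTAs leave most SMs idle for this shape; splitting the K dimension multiplies the CTA count.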

I think you also need to turn on split-K (https://github.com/NVIDIA/cutlass/tree/master/examples/06_splitK_gemm) for small M by setting split_k_slices (see the profiler's help message).
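Split-K trades an extra reduction for parallelism. With small M and N there are only a few output tiles, but K (2048 or 25088 here) is long, so cutting K into slices creates more independent work. Here is a minimal serial C++ sketch of the math. The function name is hypothetical, and this is not how CUTLASS implements it; see examples/06_splitK_gemm for that.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Serial CPU sketch of split-K for row-major C (M x N) = A (M x K) * B (K x N).
// K is divided into `slices` chunks; each chunk yields a partial M x N product
// (on a GPU, different CTAs would compute different slices in parallel), and a
// final reduction sums the partials.
std::vector<float> splitk_gemm(const std::vector<float>& A, const std::vector<float>& B,
                               std::size_t M, std::size_t N, std::size_t K,
                               std::size_t slices) {
  assert(slices >= 1 && A.size() == M * K && B.size() == K * N);
  const std::size_t chunk = (K + slices - 1) / slices;  // K extent per slice
  std::vector<float> partials(slices * M * N, 0.0f);    // one M x N workspace per slice

  for (std::size_t s = 0; s < slices; ++s) {
    const std::size_t k_begin = std::min(K, s * chunk);
    const std::size_t k_end = std::min(K, k_begin + chunk);
    float* P = partials.data() + s * M * N;
    for (std::size_t i = 0; i < M; ++i)
      for (std::size_t j = 0; j < N; ++j) {
        float acc = 0.0f;
        for (std::size_t k = k_begin; k < k_end; ++k) acc += A[i * K + k] * B[k * N + j];
        P[i * N + j] = acc;
      }
  }

  // Reduction step: sum the per-slice partial products into C.
  std::vector<float> C(M * N, 0.0f);
  for (std::size_t s = 0; s < slices; ++s)
    for (std::size_t idx = 0; idx < M * N; ++idx) C[idx] += partials[s * M * N + idx];
  return C;
}
```

In the profiler, the number of slices is the split_k_slices argument, which appears as --split_k_slices=1 in the argument list printed later in this thread.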

You can also try changing this line https://github.com/NVIDIA/cutlass/blob/4dac7490e6a846893595de352c57f2bf118796c1/tools/library/scripts/generator.py#L39 to swizzling_functor = SwizzlingFunctor.Identity1, because I don't think Identity8 helps you here.

However, for these problem sizes, I don't think you will see very high performance.
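A simple roofline estimate shows why. The sketch below is an illustration, not a CUTLASS facility. It assumes every operand crosses DRAM exactly once and uses roughly 15.7 TFLOP/s FP32 and 900 GB/s HBM2 for a V100 SXM2 as assumed hardware figures.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// FLOP per byte of DRAM traffic for FP32 GEMM, assuming A, B and C each cross
// DRAM exactly once (beta = 0, perfect on-chip reuse).
inline double arithmetic_intensity(std::uint64_t m, std::uint64_t n, std::uint64_t k) {
  assert(m > 0 && n > 0 && k > 0);
  const double flops = 2.0 * static_cast<double>(m * n * k);
  const double bytes = 4.0 * static_cast<double>(m * k + k * n + m * n);
  return flops / bytes;
}

// Basic roofline bound: min(peak compute, intensity * memory bandwidth).
// Units: GFLOP/s for compute, GB/s for bandwidth.
inline double roofline_gflops(double intensity, double peak_gflops, double bw_gb_per_s) {
  return std::min(peak_gflops, intensity * bw_gb_per_s);
}
```

For M=16, N=1000, K=2048 the intensity is about 7.8 FLOP/byte, which gives a bound of roughly 7 TFLOP/s. For M=1 it is about 0.5 FLOP/byte, so even a perfect kernel could not exceed about 450 GFLOP/s (around 3% of the assumed peak). All four shapes are memory-bound, and the M=1 cases are effectively matrix-vector products.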

richardschulze commented 4 years ago

Many thanks for your detailed comment. I tried changing line 115 in generator.py, but after running cmake and make I get the following errors: log.txt

I also tried adding a new line instead of changing the existing one. That compiles, but I still cannot profile with row-major C.

hwu36 commented 4 years ago

I tried it locally and it looks fine.

[haichengw@computelab-build-02 build70]$ make cutlass_profiler -j8
Scanning dependencies of target cutlass_library_objs
[ 12%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/all_gemm_operations.cu.o
[ 12%] Building CXX object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/initialize_all.cpp.o
[ 25%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_128x32_8x2_nn_align1.cu.o
[ 25%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_64x64_8x2_nn_align1.cu.o
[ 25%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_128x64_8x2_nn_align1.cu.o
[ 25%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_64x128_8x2_nn_align1.cu.o
[ 25%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_128x128_8x2_nn_align1.cu.o
[ 25%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_32x128_8x2_nn_align1.cu.o
[ 25%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_128x128_8x2_nt_align1.cu.o
[ 25%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_128x64_8x2_nt_align1.cu.o
[ 25%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_64x128_8x2_nt_align1.cu.o
[ 37%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_64x64_8x2_nt_align1.cu.o
[ 37%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_128x32_8x2_nt_align1.cu.o
[ 37%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_32x128_8x2_nt_align1.cu.o
[ 37%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_128x128_8x2_tn_align1.cu.o
[ 37%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_128x64_8x2_tn_align1.cu.o
[ 50%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_64x128_8x2_tn_align1.cu.o
[ 50%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_64x64_8x2_tn_align1.cu.o
[ 50%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_128x32_8x2_tn_align1.cu.o
[ 50%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_32x128_8x2_tn_align1.cu.o
[ 50%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_128x128_8x2_tt_align1.cu.o
[ 62%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_128x64_8x2_tt_align1.cu.o
[ 62%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_64x128_8x2_tt_align1.cu.o
[ 62%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_64x64_8x2_tt_align1.cu.o
[ 62%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_128x32_8x2_tt_align1.cu.o
[ 62%] Building CUDA object tools/library/CMakeFiles/cutlass_library_objs.dir/generated/gemm/cutlass_simt_sgemm_32x128_8x2_tt_align1.cu.o
[ 75%] Built target cutlass_library_objs
[ 75%] Linking CUDA device code CMakeFiles/cutlass_lib.dir/cmake_device_link.o
[ 75%] Linking CXX shared library libcutlass.so
[ 75%] Built target cutlass_lib
[ 75%] Linking CUDA device code CMakeFiles/cutlass_profiler.dir/cmake_device_link.o
[ 75%] Linking CXX executable cutlass_profiler
[100%] Built target cutlass_profiler

[haichengw@computelab-build-02 build70]$ cat tools/library/generated/gemm/cutlass_simt_sgemm_64x128_8x2_nn_align1.cu

/*
  Generated by gemm_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/arch/wmma.h"
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "gemm_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////

// Gemm operator cutlass_simt_sgemm_64x128_8x2_nn_align1
using cutlass_simt_sgemm_64x128_8x2_nn_align1_base = 
  typename cutlass::gemm::kernel::DefaultGemmUniversal<
    float, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,    // transposed B operand
    float, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,    // transposed A operand
    float, cutlass::layout::RowMajor,
    float,
    cutlass::arch::OpClassSimt,
    cutlass::arch::Sm50,
    cutlass::gemm::GemmShape<64, 128, 8>,
    cutlass::gemm::GemmShape<32, 64, 8>,
    cutlass::gemm::GemmShape<1, 1, 1>,
    cutlass::epilogue::thread::LinearCombination<
      float,
      1,
      float,
      float
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
    2,
    cutlass::arch::OpMultiplyAdd
>::GemmKernel;

// Define named type
struct cutlass_simt_sgemm_64x128_8x2_nn_align1 : 
  public cutlass_simt_sgemm_64x128_8x2_nn_align1_base { };

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace library {

///////////////////////////////////////////////////////////////////////////////////////////////////

void initialize_cutlass_simt_sgemm_64x128_8x2_nn_align1(Manifest &manifest) {

  manifest.append(new GemmUniversalOperation<
      cutlass::gemm::device::GemmUniversalAdapter<cutlass_simt_sgemm_64x128_8x2_nn_align1>
    >("cutlass_simt_sgemm_64x128_8x2_nn_align1"));

}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////

My cmake command is: cmake .. -DCUTLASS_NVCC_ARCHS=70 -DCUTLASS_LIBRARY_KERNELS=sgemm

d-k-b commented 4 years ago

@richard-wwu, if you haven't already, try deleting your build directory after making the change. Incremental builds don't always work properly when the CMake-generated intermediate files change.

richardschulze commented 4 years ago

Thank you for the suggestions. Unfortunately, the problem persists. I started from scratch with a clean repository:

git clone https://github.com/NVIDIA/cutlass.git
cd cutlass
# edit generator.py
mkdir build && cd build
cmake .. -DCUTLASS_NVCC_ARCHS=70 -DCUTLASS_LIBRARY_KERNELS=sgemm
make cutlass_profiler

I still get the exact same error log I posted before. The file tools/library/generated/gemm/cutlass_simt_sgemm_64x128_8x2_nn_align1.cu looks exactly the same as the one @hwu36 posted.

I am on CentOS 7 using GCC 8.2.0 and CUDA 10.1.

hwu36 commented 4 years ago

I can reproduce your failure when I switch to CUDA 10.1. The failure occurs for the TT layout rather than the NN layout, which is very strange.

You can call cmake like this to compile only the NN kernels: cmake .. -DCUTLASS_NVCC_ARCHS=70 -DCUTLASS_LIBRARY_KERNELS=sgemm*nn

Alternatively, you can upgrade to CUDA 11.0, which should also be better for performance.

richardschulze commented 4 years ago

I upgraded to CUDA 11.0 but still get the same error.

Compiling only the NN kernels builds without errors, but it does not seem to produce a kernel for row-major A, B, and C. Executing

./cutlass_profiler --kernels=gemm --m=16 --n=1000 --k=2048 --A=f32:row --B=f32:row --C=f32:row

returns immediately, and

./cutlass_profiler --kernels=gemm --m=16 --n=1000 --k=2048 --A=f32 --B=f32 --C=f32

outputs

...

=============================
  Problem ID: 1

        Provider: CUTLASS
   OperationKind: gemm
       Operation: cutlass_simt_sgemm_128x64_8x2_nn_align1

          Status: Success
    Verification: ON
     Disposition: Passed

          cuBLAS: Passed

       Arguments: --gemm_kind=universal --m=16 --n=1000 --k=2048 --A=f32:column --B=f32:column --C=f32:column --alpha=1  \
                  --beta=0 --split_k_slices=1 --batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=64 --cta_k=8  \
                  --stages=2 --warps_m=2 --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024  \

           Bytes: 8387072  bytes
           FLOPs: 65568000  flops

         Runtime: 0.271053  ms
          Memory: 28.8175 GiB/s

            Math: 241.901 GFLOP/s

...

Problem,Provider,OperationKind,Operation,Disposition,Status,gemm_kind,m,n,k,A,B,C,alpha,beta,split_k_slices,batch_count,op_class,accum,cta_m,cta_n,cta_k,stages,warps_m,warps_n,warps_k,inst_m,inst_n,inst_k,min_cc,max_cc,Bytes,Flops,Runtime,GB/s,GFLOPs
1,CUTLASS,gemm,cutlass_simt_sgemm_128x128_8x2_nn_align1,passed,success,universal,16,1000,2048,f32:column,f32:column,f32:column,1,0,1,1,simt,f32,128,128,8,2,4,2,1,1,1,1,50,1024,8387072,65568000,0.383601,20.3625,170.928
1,CUTLASS,gemm,cutlass_simt_sgemm_128x64_8x2_nn_align1,passed,success,universal,16,1000,2048,f32:column,f32:column,f32:column,1,0,1,1,simt,f32,128,64,8,2,2,2,1,1,1,1,50,1024,8387072,65568000,0.271104,28.8121,241.856
1,CUTLASS,gemm,cutlass_simt_sgemm_64x128_8x2_nn_align1,passed,success,universal,16,1000,2048,f32:column,f32:column,f32:column,1,0,1,1,simt,f32,64,128,8,2,2,2,1,1,1,1,50,1024,8387072,65568000,0.39383,19.8336,166.488
1,CUTLASS,gemm,cutlass_simt_sgemm_64x64_8x2_nn_align1,passed,success,universal,16,1000,2048,f32:column,f32:column,f32:column,1,0,1,1,simt,f32,64,64,8,2,2,1,1,1,1,1,50,1024,8387072,65568000,0.330906,23.6051,198.147
1,CUTLASS,gemm,cutlass_simt_sgemm_128x32_8x2_nn_align1,passed,success,universal,16,1000,2048,f32:column,f32:column,f32:column,1,0,1,1,simt,f32,128,32,8,2,2,1,1,1,1,1,50,1024,8387072,65568000,0.245883,31.7674,266.664
1,CUTLASS,gemm,cutlass_simt_sgemm_32x128_8x2_nn_align1,passed,success,universal,16,1000,2048,f32:column,f32:column,f32:column,1,0,1,1,simt,f32,32,128,8,2,1,2,1,1,1,1,50,1024,8387072,65568000,0.388731,20.0938,168.672

Does this mean the profiler does not execute row-major GEMMs, even though tools/library/generated/gemm/cutlass_simt_sgemm_64x128_8x2_nn_align1.cu specifies row-major matrices? Am I missing something?

hwu36 commented 4 years ago

Sorry, forget what I said earlier; I think you don't need to change any code.

Here is the story. cuBLAS only supports column-major output, and CUTLASS only supports row-major output. To match the two libraries, CUTLASS uses this mapping:

C^T = (A x B)^T = B^T x A^T

When cuBLAS runs the nn layout, it computes A:col x B:col -> C:col, but internally the CUTLASS profiler maps this to B:row x A:row -> C:row.

So you just need to run the nn layout to get CUTLASS performance for row x row -> row. The layouts shown on screen are there to match cuBLAS.
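The mapping can be checked on the CPU. The sketch below uses a hypothetical reference GEMM (not CUTLASS code). Calling it with column-major A, B, and C, then calling it again on the same two buffers swapped and treated as row-major, produces identical bytes. The column-major B buffer read as row-major is B^T, the column-major A buffer read as row-major is A^T, and the row-major result is C^T, which occupies the same memory as column-major C.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

enum class Layout { RowMajor, ColumnMajor };

// Offset of element (r, c) in a rows x cols matrix stored with layout l.
inline std::size_t offset(Layout l, std::size_t r, std::size_t c,
                          std::size_t rows, std::size_t cols) {
  return l == Layout::RowMajor ? r * cols + c : c * rows + r;
}

// Reference GEMM: C (m x n) = A (m x k) * B (k x n), each buffer with its own layout.
std::vector<float> gemm(const std::vector<float>& A, Layout la,
                        const std::vector<float>& B, Layout lb, Layout lc,
                        std::size_t m, std::size_t n, std::size_t k) {
  assert(A.size() == m * k && B.size() == k * n);
  std::vector<float> C(m * n, 0.0f);
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (std::size_t p = 0; p < k; ++p)
        acc += A[offset(la, i, p, m, k)] * B[offset(lb, p, j, k, n)];
      C[offset(lc, i, j, m, n)] = acc;
    }
  return C;
}
```

In the second call, m and n trade places. By this identity, a column-major m x n x k GEMM is, byte for byte, a row-major n x m x k GEMM with the operands swapped.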

richardschulze commented 4 years ago

I see. Just to make sure I fully understand, when I run

./cutlass_profiler --kernels=gemm --m=16 --n=1000 --k=2048 --A=f32:column --B=f32:column --C=f32:column

that is CUTLASS' time for GEMM on row-major A, B, C, correct?

hwu36 commented 4 years ago

Yes, you are correct. Sorry for the confusion.

richardschulze commented 4 years ago

No worries, many thanks for your help.

manishucsd commented 4 years ago

This comment is helpful here: https://github.com/NVIDIA/cutlass/blob/master/media/docs/gemm_api.md#efficient-epilogue