NVIDIA / cutlass

CUDA Templates for Linear Algebra Subroutines

How to implement int8 complex GemmBatched? #856

Closed zhangyilalala closed 1 year ago

zhangyilalala commented 1 year ago

Hi,

I want to implement an int8 complex GemmBatched for my project, running on an sm70 device (uint8 * uint8 = uint32).

May I ask what's the best way to do it?

hwu36 commented 1 year ago

If you need to know how complex gemm works, I think you can take a look at how complex<float> works.

As to the implementation, the global memory -> shared memory -> register file path is the same as the fp32 simt gemm since the data are all 32-bit. After the data are loaded into the register file, you need to unpack them into real and imaginary parts and do the mma, which is similar to complex<float>.
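For intuition, the per-element math ends up looking roughly like the scalar sketch below (illustrative only; the real kernel does this on register fragments inside the mma, not one element at a time):

```cpp
#include "cutlass/complex.h"

// Simplified scalar view of one output element of a complex int8 GEMM:
// every complex<char> product expands into four integer multiplies whose
// results accumulate into a complex<int>, analogous to complex<float>.
cutlass::complex<int> complex_dot(
    cutlass::complex<char> const *a,   // one row of A
    cutlass::complex<char> const *b,   // one column of B
    int k) {

  int acc_re = 0;
  int acc_im = 0;

  for (int i = 0; i < k; ++i) {
    // Unpack the real and imaginary parts and widen to 32-bit.
    int ar = a[i].real(), ai = a[i].imag();
    int br = b[i].real(), bi = b[i].imag();

    acc_re += ar * br - ai * bi;   // real part of the product
    acc_im += ar * bi + ai * br;   // imaginary part of the product
  }

  return cutlass::complex<int>(acc_re, acc_im);
}
```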

zhangyilalala commented 1 year ago

Thanks for the reply ~

I tried using cutlass::gemm::device::GemmBatched with the data types set as below, and the result is correct.

using ElementInputA = cutlass::complex<char>;
using ElementInputB = cutlass::complex<char>;
using ElementOutput = cutlass::complex<int>;
using ElementAccumulator = cutlass::complex<int>;
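
Roughly, the surrounding setup looks like the sketch below (column-major layouts and the default SIMT/Sm70 configuration; the pointers, leading dimensions, and batch strides are just parameters here, not my exact values):

```cpp
#include "cutlass/complex.h"
#include "cutlass/gemm/device/gemm_batched.h"

using ElementInputA      = cutlass::complex<char>;
using ElementInputB      = cutlass::complex<char>;
using ElementOutput      = cutlass::complex<int>;
using ElementAccumulator = cutlass::complex<int>;

// SIMT batched GEMM; sm70 has no tensor core path for this data type.
using Gemm = cutlass::gemm::device::GemmBatched<
    ElementInputA, cutlass::layout::ColumnMajor,
    ElementInputB, cutlass::layout::ColumnMajor,
    ElementOutput, cutlass::layout::ColumnMajor,
    ElementAccumulator,
    cutlass::arch::OpClassSimt,
    cutlass::arch::Sm70>;

cutlass::Status run_batched_gemm(
    int m, int n, int k, int batch_count,
    ElementInputA const *ptr_A, int lda, long long batch_stride_A,
    ElementInputB const *ptr_B, int ldb, long long batch_stride_B,
    ElementOutput *ptr_C, int ldc, long long batch_stride_C) {

  ElementAccumulator alpha(1, 0);   // D = alpha * A*B + beta * C
  ElementAccumulator beta(0, 0);

  Gemm gemm_op;

  // C doubles as the source accumulator (ignored since beta = 0) and as D.
  return gemm_op({
      {m, n, k},
      {ptr_A, lda}, batch_stride_A,
      {ptr_B, ldb}, batch_stride_B,
      {ptr_C, ldc}, batch_stride_C,
      {ptr_C, ldc}, batch_stride_C,
      {alpha, beta},
      batch_count});
}
```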

However, it turned out to take about the same amount of time as the float version:

using ElementInputA = cutlass::complex<float>;
using ElementInputB = cutlass::complex<float>;
using ElementOutput = cutlass::complex<float>;
using ElementAccumulator = cutlass::complex<float>;

May I ask why?

hwu36 commented 1 year ago

It is compute bound. Perf doesn't have much to do with data types.

zhangyilalala commented 1 year ago

I see. I want to use int8 to speed up batched complex matrix multiplication. Is there a suitable ready-made template interface, or do I have to do it myself?

hwu36 commented 1 year ago

You said you already got it working above, right?

zhangyilalala commented 1 year ago

Yes, I used the cutlass::gemm::device::GemmBatched interface and it worked, but my original intention was to use the int8 data type to speed up the calculation compared to float. The main thing is that there is no acceleration at the moment, so what is the best way to get it?

hwu36 commented 1 year ago

When the problem size is small, it is memory bound and int8 can speed it up. When the problem size is big, it is compute bound and int8 is not going to speed it up.
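
As a rough back-of-envelope (the exact crossover depends on the GPU): per batch, a complex GEMM moves about s·(mk + kn + mn) bytes and does about 8·mnk real operations, so

```latex
\text{arithmetic intensity} \;\approx\; \frac{8\,mnk}{s\,(mk + kn + mn)} \ \text{ops per byte}
```

where s is the size of one complex element (2 bytes for complex<char>, 8 bytes for complex<float>). When the problem is small, the byte term dominates and the 4x smaller elements help; once m, n, k are large, the limit is ALU throughput, and on sm70 the 32-bit integer multiply-adds this kernel uses run at a rate comparable to fp32 FMAs, so the data type stops mattering.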

zhangyilalala commented 1 year ago

I see. So for large multiplications like the ones in Transformers, quantizing the computation to int8 doesn't give any speedup?

hwu36 commented 1 year ago

On newer arch, int8 can be accelerated by tensor cores. Sm70 doesn't have such tensor cores.

zhangyilalala commented 1 year ago

Ok, so int8 gives a speedup over float on tensor cores, but there's no such effect on CUDA cores?

hwu36 commented 1 year ago

Correct

zhangyilalala commented 1 year ago

I see. Thanks for your reply!

zhangyilalala commented 1 year ago

Sorry, one more question.

I ran a deep learning CNN on sm70 hardware and quantized the model to int8, and the speed did improve. I think the bottom layer there is also matrix multiplication, so do you know why that case gets faster?

hwu36 commented 1 year ago

If you don't do complex gemm but scalar gemm, int8 is faster than fp32.
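
One reason: scalar int8 SIMT kernels can use the dp4a instruction (available since compute capability 6.1, so sm70 has it), which packs four int8 values per operand into a 32-bit register and computes their dot product plus an accumulator in a single instruction. The complex formulation above doesn't map to such a packed instruction. A minimal illustration of the intrinsic (dot4_int8 is just an illustrative wrapper; compile as CUDA for sm_61 or newer):

```cpp
#include <cstdint>

// Each __dp4a consumes four packed int8 values per operand and adds their
// dot product to a 32-bit accumulator: acc += a[0]*b[0] + ... + a[3]*b[3].
// This is why scalar int8 SIMT GEMM can beat fp32 even without tensor cores.
__device__ int32_t dot4_int8(int32_t packed_a, int32_t packed_b, int32_t acc) {
  return __dp4a(packed_a, packed_b, acc);
}
```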

zhangyilalala commented 1 year ago

Ok, got it. Thanks.
