harrisonvanderbyl / rwkv-cpp-accelerated

A torchless, c++ rwkv implementation using 8bit quantization, written in cuda/hip/vulkan for maximum compatibility and minimum dependencies
MIT License
306 stars 19 forks

Questions about int8 quantization #39

Open zzczzc20 opened 1 year ago

zzczzc20 commented 1 year ago

Dear authors, I am new to the project and have been reading the code in "include/rwkv/cuda/rwkv.cu". If I understand correctly, only the computations inside the functions cudac_mm8_one and cuda_mm8_threec involve int8 quantization, and their results are floating-point numbers. The calculations in sigmoid and kernel_wkvc_forward, on the other hand, are done entirely in floating point.

My question is: why are these parts not quantized? I have heard of methods that quantize nonlinear functions with a lookup table. Since the exp() function is slow, is there any way to replace these functions with a faster substitute?
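A lookup table with linear interpolation is one common way to approximate a nonlinear function such as sigmoid without calling exp() at run time. The following is a minimal C++ sketch, not code from this repository; the class name `LutSigmoid`, the table size of 1025 entries, and the clamping range [-8, 8] are illustrative assumptions.

```cpp
#include <array>
#include <cmath>
#include <cstddef>

// Hypothetical lookup-table sigmoid (not part of rwkv-cpp-accelerated).
// Sigmoid is precomputed at evenly spaced points in [-kRange, kRange];
// inputs in between are linearly interpolated, inputs outside are clamped.
class LutSigmoid {
public:
    static constexpr float kRange = 8.0f;
    static constexpr std::size_t kSize = 1025;  // entries, including both endpoints

    LutSigmoid() {
        for (std::size_t i = 0; i < kSize; ++i) {
            float x = -kRange + 2.0f * kRange * static_cast<float>(i) / (kSize - 1);
            table_[i] = 1.0f / (1.0f + std::exp(-x));
        }
    }

    float operator()(float x) const {
        // Clamp: beyond +/-8 sigmoid is within about 3.4e-4 of 0 or 1.
        if (x <= -kRange) return table_.front();
        if (x >= kRange) return table_.back();
        // Map x to a fractional table position and interpolate between neighbours.
        float pos = (x + kRange) * (kSize - 1) / (2.0f * kRange);
        std::size_t i = static_cast<std::size_t>(pos);
        if (i >= kSize - 1) i = kSize - 2;
        float frac = pos - static_cast<float>(i);
        return table_[i] + frac * (table_[i + 1] - table_[i]);
    }

private:
    std::array<float, kSize> table_{};
};
```

With this spacing the interpolation error inside the range is on the order of 1e-6, so most of the error comes from clamping at the endpoints. Whether a table lookup is actually faster than `expf` on a GPU depends on memory access patterns and would need to be benchmarked.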

Best, zzczzc20

harrisonvanderbyl commented 1 year ago

Hi! To answer your question: it's a matter of scale. The mm8 matvecs are n^2 operations (each one reads a weight matrix with on the order of n^2 entries), so int8 is used there to minimise CUDA memory usage.
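To illustrate why int8 storage pays off for matrix-vector products, here is a minimal sketch of symmetric per-row int8 weight quantization with on-the-fly dequantization. This is a generic scheme chosen for illustration; it is not necessarily the exact scale/offset layout used by the repository's mm8 kernels, and `Int8Matrix`, `quantize_rows`, and `matvec_int8` are hypothetical names.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical int8 weight storage: rows * cols one-byte codes plus one float scale per row.
struct Int8Matrix {
    std::size_t rows = 0, cols = 0;
    std::vector<std::int8_t> codes;  // rows * cols bytes
    std::vector<float> scales;       // one scale per row
};

// Symmetric per-row quantization: code = round(w / scale), scale = max|w| / 127.
Int8Matrix quantize_rows(const std::vector<float>& w, std::size_t rows, std::size_t cols) {
    Int8Matrix m{rows, cols, std::vector<std::int8_t>(rows * cols), std::vector<float>(rows)};
    for (std::size_t r = 0; r < rows; ++r) {
        float max_abs = 0.0f;
        for (std::size_t c = 0; c < cols; ++c) max_abs = std::max(max_abs, std::fabs(w[r * cols + c]));
        float scale = max_abs > 0.0f ? max_abs / 127.0f : 1.0f;
        m.scales[r] = scale;
        for (std::size_t c = 0; c < cols; ++c) {
            float q = std::round(w[r * cols + c] / scale);
            m.codes[r * cols + c] = static_cast<std::int8_t>(std::clamp(q, -127.0f, 127.0f));
        }
    }
    return m;
}

// y = W x: accumulate code * x in float, then apply the row scale.
// The output is floating point, as observed for mm8 in the question above.
std::vector<float> matvec_int8(const Int8Matrix& m, const std::vector<float>& x) {
    std::vector<float> y(m.rows, 0.0f);
    for (std::size_t r = 0; r < m.rows; ++r) {
        float acc = 0.0f;
        for (std::size_t c = 0; c < m.cols; ++c) acc += static_cast<float>(m.codes[r * m.cols + c]) * x[c];
        y[r] = acc * m.scales[r];
    }
    return y;
}
```

The weights shrink from 4 bytes to 1 byte per entry plus a small per-row overhead, while the input and output vectors stay in float.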

There's no appreciable memory advantage to quantizing the other operations, since they only work on vectors of length n.
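The scale argument can be made concrete with a quick byte count. This sketch compares the savings from storing one hypothetical n × n weight matrix in int8 instead of fp32 against the savings from doing the same for a single length-n activation vector (scale factors are ignored, and the value n = 4096 is an arbitrary example, not a specific RWKV model size).

```cpp
#include <cstdint>

// Back-of-the-envelope byte counts for one hypothetical n x n projection.
constexpr std::uint64_t fp32_matrix_bytes(std::uint64_t n) { return n * n * 4; }
constexpr std::uint64_t int8_matrix_bytes(std::uint64_t n) { return n * n * 1; }
constexpr std::uint64_t fp32_vector_bytes(std::uint64_t n) { return n * 4; }
constexpr std::uint64_t int8_vector_bytes(std::uint64_t n) { return n * 1; }

// Bytes saved by storing the matrix in int8 instead of fp32.
constexpr std::uint64_t matrix_saving(std::uint64_t n) { return fp32_matrix_bytes(n) - int8_matrix_bytes(n); }
// Bytes saved by storing one activation vector in int8 instead of fp32.
constexpr std::uint64_t vector_saving(std::uint64_t n) { return fp32_vector_bytes(n) - int8_vector_bytes(n); }
```

For n = 4096, `matrix_saving(4096)` is 50,331,648 bytes (48 MiB) while `vector_saving(4096)` is 12,288 bytes (12 KiB); the ratio between the two is always n.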

If you can get faster inference using the methods outlined, please submit a PR :)

zzczzc20 commented 1 year ago

Hi, I am willing to look into it, but I need to benchmark the model to measure the quantization loss. Is there an easy way to benchmark the model? Which benchmark do you use? Thanks very much.
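One common way to measure quantization loss (not necessarily what the maintainers use) is to run the same held-out text through a float reference and the quantized model and compare their perplexities. Below is a minimal sketch, assuming you can collect the per-token output logits from both runs; `log_prob` and `perplexity` are hypothetical helpers, not part of rwkv-cpp-accelerated.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Log-probability of `target` under softmax(logits),
// using max-subtraction for numerical stability.
double log_prob(const std::vector<float>& logits, std::size_t target) {
    float max_logit = *std::max_element(logits.begin(), logits.end());
    double sum = 0.0;
    for (float l : logits) sum += std::exp(static_cast<double>(l - max_logit));
    return static_cast<double>(logits[target] - max_logit) - std::log(sum);
}

// Perplexity = exp(mean negative log-likelihood) over next-token predictions.
// logits[t] is the model output at step t; targets[t] is the true next token.
double perplexity(const std::vector<std::vector<float>>& logits,
                  const std::vector<std::size_t>& targets) {
    double nll = 0.0;
    for (std::size_t t = 0; t < targets.size(); ++t) nll -= log_prob(logits[t], targets[t]);
    return std::exp(nll / static_cast<double>(targets.size()));
}
```

Computing `perplexity` once with float logits and once with quantized logits on the same text gives a single number for the quantization loss: the closer the two values are, the less accuracy the quantization costs.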