HanGuo97 / flute

Fast Matrix Multiplications for Lookup Table-Quantized LLMs
https://arxiv.org/abs/2407.10960
Apache License 2.0

Some questions about kernel #6

Open lswzjuer opened 3 weeks ago

lswzjuer commented 3 weeks ago

What is the difference between qgemm_kernel_generated and qgemm_kernel_raw_generated? I see that the code contains these two categories of functions.

lswzjuer commented 3 weeks ago

Can't open this link

HanGuo97 commented 3 weeks ago

Great question! FLUTE instantiates a lot of templates and dispatches inputs to them based on their shapes. (This behavior will change soon, since it's getting hard to maintain.) qgemm_kernel_raw_generated differs from qgemm_kernel_generated in that the former does not include the dispatching utility, while the latter does. The former is used for tuning only and is not part of the release (it's empty).
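
To make the distinction concrete, here is a minimal sketch with hypothetical, simplified signatures (not FLUTE's actual code): the raw entry point takes an explicit template_id, while the dispatching entry point picks one from the problem shape.

```cpp
#include <stdexcept>

// Placeholder for one instantiation of the quantized GEMM; in practice this
// would launch a CUDA kernel built with the tile/stage configuration that
// TemplateId identifies.
template <int TemplateId>
void qgemm_host(const void*, const void*, void*, int /*M*/, int /*N*/, int /*K*/) {}

// "Raw" path (tuning only): the caller chooses the template explicitly.
void qgemm_raw(const void* A, const void* B, void* C,
               int M, int N, int K, int template_id) {
  switch (template_id) {
    case 0: qgemm_host<0>(A, B, C, M, N, K); break;
    case 1: qgemm_host<1>(A, B, C, M, N, K); break;
    default: throw std::invalid_argument("unknown template_id");
  }
}

// "Dispatching" path (release): the template is selected from the shape,
// e.g. using a table produced offline by the tuner.
void qgemm(const void* A, const void* B, void* C, int M, int N, int K) {
  const int template_id = (N >= 8192) ? 1 : 0;  // placeholder heuristic
  qgemm_raw(A, B, C, M, N, K, template_id);
}
```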

lswzjuer commented 3 weeks ago

Can you provide complete tuning instructions? I want to adapt FLUTE to new shapes on new platforms (such as the A100).

lswzjuer commented 2 weeks ago

It would be good to add support for typical diffusion-model shapes to broaden the usage scenarios, for example the following shapes (a rough tuning-table sketch follows the list):

# PixArt-sigma
(8,1152,256),
(8,1152,1152),
(8,6912,256),
(300,1152,4096),
(300,1152,1152),
(1024,3456,1152),
(1024,1152,1152),
(449,2304,1152),
(1024,4608,1152),
(1024,1152,4608),
(1024,32,1152),
# SD1.5
(2,1280,320),
(2,1280,1280),
(2,320,1280),
(4096,320,320),
(77,320,768),
(4096,2560,320),
(4096,320,1280),
(2,640,1280),
(1024,640,640),
(77,640,768),
(1024,5120,640),
(1024,640,2560),
(256,1280,1280),
(77,1280,768),
(256,10240,1280),
(256,1280,5120),
(64,1280,1280),
(64,10240,1280),
(64,1280,5120),
# SDXL
(2,1280,320),
(2,1280,1280),
(2,1280,2816),
(2,320,1280),
(2,640,1280),
(4096,640,640),
(77,640,2048),
(4096,5120,640),
(4096,640,2560),
(1024,1280,1280),
(77,1280,2048),
(1024,10240,1280),
(1024,1280,5120),
# SD3
(2,1536,256),
(2,1536,1536),
(2,1536,2048),
(333,1536,4096),
(2,9216,1536),
(4096,1536,1536),
(333,1536,1536),
(4096,6144,1536),
(4096,1536,6144),
(333,6144,1536),
(333,1536,6144),
(2,3072,1536),
(4096,64,1536),
# DiT
(4,1152,256),
(4,1152,1152),
(4,6912,1152),
(1024,1152,1152),
(1024,4608,1152),
(1024,1152,4608),
(4,2304,1152),
(1024,32,1152),
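
For instance, these could be collected into a simple table for a tuner to sweep. A hypothetical C++ sketch using a subset of the PixArt-sigma entries above (the GemmShape struct and the m/n/k field names are an assumed interpretation of the triples):

```cpp
#include <cstddef>

// Hypothetical tuning table; entries copied from the PixArt-sigma list
// above as an illustrative subset. Field ordering m/n/k is assumed.
struct GemmShape { int m, n, k; };

constexpr GemmShape kPixArtSigmaShapes[] = {
    {8, 1152, 256},
    {8, 1152, 1152},
    {8, 6912, 256},
};

constexpr std::size_t kNumShapes =
    sizeof(kPixArtSigmaShapes) / sizeof(kPixArtSigmaShapes[0]);
```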

HanGuo97 commented 2 weeks ago

Sorry for getting back to you a bit late, but thanks for the feedback!

Yes, I agree on providing better tuning instructions (we don't have any right now, unfortunately). I'm also working on a just-in-time tuning feature so we don't have to tune every possible shape in the world.

I will get back to you when this is done (hopefully won't take too long)!

lswzjuer commented 2 weeks ago

You don't need to cover every shape at all. Some personal suggestions:

1. Make _qgemm_raw replace qgemm_kernel_generated as the kernel entry file and expose template_id through the Torch API, so the best template_id can be selected online for each shape during a warm-up stage in the Torch layer. This is more flexible and scalable.
2. Using the SM count as a template parameter is unwise; it scales very poorly.
3. Template functions should not nest specializations of other template functions, as _qgemm_raw currently does with qgemm_host. Instead, first specialize the possible qgemm_host instances, then specialize _qgemm_raw in separate files to reduce compilation pressure (a rough sketch follows this list).
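
Something like the following layout, with hypothetical file names and simplified signatures (not FLUTE's actual code):

```cpp
// qgemm_host.hpp: declaration only, shared by every translation unit.
template <int TemplateId>
void qgemm_host(const void* A, const void* B, void* C, int M, int N, int K);

// Suppress implicit instantiation; the definitions live in dedicated files.
extern template void qgemm_host<0>(const void*, const void*, void*, int, int, int);
extern template void qgemm_host<1>(const void*, const void*, void*, int, int, int);

// qgemm_host_0.cu: one explicit instantiation per file, so the expensive
// template bodies compile in parallel and only once, e.g.:
//   #include "qgemm_host_impl.cuh"   // the template definition
//   template void qgemm_host<0>(const void*, const void*, void*, int, int, int);

// _qgemm_raw.cu: only switches over template_id and links against the
// pre-built qgemm_host instantiations instead of re-instantiating them.
```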

HanGuo97 commented 2 weeks ago

Thank you for the suggestions --- these are great!

Yes, (1) and (2) are on our roadmap. I need to think a bit about how (1) would complicate compatibility with torch.compile. Regarding (2), we plan to query the SM count at runtime.
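
For reference, the SM count can be queried with the CUDA runtime API at load time; a standalone sketch (not tied to FLUTE's code):

```cpp
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  int device = 0;
  if (cudaGetDevice(&device) != cudaSuccess) return 1;

  // Query the number of streaming multiprocessors at runtime instead of
  // baking it into the kernel templates as a compile-time parameter.
  int num_sms = 0;
  if (cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount,
                             device) != cudaSuccess) {
    return 1;
  }

  std::printf("SMs on device %d: %d\n", device, num_sms);
  return 0;
}
```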

Could you elaborate a bit on (3)?