IST-DASLab / marlin

FP16xINT4 LLM inference kernel that can achieve near-ideal ~4x speedups up to medium batch sizes of 16-32 tokens.

[QST] Weight Format & GEMM #18

Open jeromeku opened 5 months ago

jeromeku commented 5 months ago

@efrantar

Awesome work -- always enjoy your research on and implementation of efficient model inference.

I was hoping you could shed some light on the logic of the packing step.

Many thanks!

efrantar commented 5 months ago

Hi, Marlin only uses ldmatrix for the activations, as the weights are already preshuffled optimally for both dequantization and the tensor core fragment layouts. You can find a more detailed description of how this format works here: https://github.com/IST-DASLab/marlin/issues/12.
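Roughly, the difference looks like the following simplified sketch (not the actual Marlin code; the helper names `load_frag_a`/`load_frag_b` are just for illustration): the activation tile goes through `ldmatrix` so each thread ends up holding the fragment layout that `mma.sync` expects, while the preshuffled 4-bit weights only need an ordinary vectorized load.

```cuda
// Simplified sketch, not the actual Marlin kernel code.
#include <cstdint>

__device__ inline void load_frag_a(uint32_t frag[4], const void* smem_ptr) {
  // ldmatrix takes a shared-memory address in the .shared state space.
  uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
      : "=r"(frag[0]), "=r"(frag[1]), "=r"(frag[2]), "=r"(frag[3])
      : "r"(addr));
}

__device__ inline void load_frag_b(uint32_t frag[4], const uint32_t* ptr) {
  // No shuffle instruction needed for the weights: the packed INT4 values were
  // permuted offline, so a plain 128-bit load already yields them in the
  // per-thread order that dequantization and mma.sync consume.
  uint4 v = *reinterpret_cast<const uint4*>(ptr);
  frag[0] = v.x; frag[1] = v.y; frag[2] = v.z; frag[3] = v.w;
}
```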

Marlin is completely independent of GPTQ: the model just needs to be quantized symmetrically, either with group size 128 or row-wise (how you produced the model doesn't matter to Marlin); then you can preprocess the weights and use the Marlin kernels. Zero-points are currently not supported; the reasons for this are discussed here: https://github.com/IST-DASLab/marlin/issues/5#issuecomment-1934082099.
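In pseudocode, the expected format is just symmetric rounding with one scale per group of 128 input elements (row-wise means one group spanning the whole input dimension) and no zero-point. Something along these lines (a rough sketch, not the reference implementation; the function name is illustrative):

```cpp
// Rough sketch of symmetric, zero-point-free INT4 quantization with one scale
// per group, which is the input format the packing step then rearranges.
#include <algorithm>
#include <cmath>
#include <cstdint>

// Quantize `group_size` consecutive weights that share a single scale.
void quantize_group_sym_int4(const float* w, int group_size,
                             int8_t* q, float* scale) {
  float amax = 0.f;
  for (int i = 0; i < group_size; ++i) amax = std::max(amax, std::fabs(w[i]));
  *scale = amax / 7.f;  // symmetric: map the max magnitude to the INT4 range [-7, 7]
  float inv = amax > 0.f ? 1.f / *scale : 0.f;
  for (int i = 0; i < group_size; ++i) {
    int v = static_cast<int>(std::lround(w[i] * inv));
    q[i] = static_cast<int8_t>(std::clamp(v, -7, 7));
  }
  // Dequantization is simply q[i] * scale, so the kernel never has to add a zero-point.
}
```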

jeromeku commented 5 months ago

@efrantar

Thank you for taking the time to explain.

Have you looked into CUTLASS, specifically the 3.x API that introduced the CuTe abstractions for thread-value tensor manipulation / mapping? I'm wondering whether it could help generalize / extend the handcrafted code in Marlin without sacrificing performance.
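For concreteness, the kind of thing I have in mind is expressing a fixed permutation as a static CuTe layout composition instead of handwritten index arithmetic. A toy example (assumes CUTLASS 3.x / CuTe headers; not related to Marlin's actual permutations):

```cpp
// Toy example only: a compile-time permutation expressed via CuTe layouts.
#include <cute/tensor.hpp>
using namespace cute;

int main() {
  // An 8x4 tile stored row-major: coordinate (i, j) -> offset 4*i + j.
  auto tile = make_layout(make_shape(Int<8>{}, Int<4>{}),
                          make_stride(Int<4>{}, Int<1>{}));
  // A reindexing layout; composing it with `tile` yields the transposed view,
  // with all of the index math resolved statically.
  auto perm = make_layout(make_shape(Int<4>{}, Int<8>{}),
                          make_stride(Int<8>{}, Int<1>{}));
  auto composed = composition(tile, perm);
  print_layout(tile);      // prints the coordinate -> offset table
  print_layout(composed);  // (4,8):(1,4), i.e. the permuted view
  return 0;
}
```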