Open jeromeku opened 5 months ago
Hi, Marlin only uses `ldmatrix` for the activations, as the weights are already preshuffled optimally for both dequantization and tensor core fragment layouts. You can find a more detailed description of how this format works here: https://github.com/IST-DASLab/marlin/issues/12.
Marlin is completely independent of GPTQ: the model needs to be quantized symmetrically, either with group size 128 or row-wise (how you produced the model doesn't matter to Marlin); then you can preprocess the weights and use the Marlin kernels. Zero-points are currently not supported; the reasons for this are discussed here: https://github.com/IST-DASLab/marlin/issues/5#issuecomment-1934082099.
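For concreteness, the symmetric, zero-point-free format described above can be sketched in NumPy. This is a minimal illustration of group-wise symmetric int4 quantization with group size 128; the function names and shapes are my own, not Marlin's actual preprocessing API:

```python
import numpy as np

def quantize_sym_int4(w: np.ndarray, group_size: int = 128):
    """Symmetric (no zero-point) int4 quantization, per-group along rows.

    Illustrative only -- Marlin's real preprocessing additionally
    repacks/permutes the quantized weights for its kernel layout.
    """
    rows, cols = w.shape
    g = w.reshape(rows // group_size, group_size, cols)
    # Symmetric scale: map the per-group max |w| onto the int4 extreme 7.
    scales = np.abs(g).max(axis=1, keepdims=True) / 7.0
    q = np.clip(np.round(g / scales), -8, 7).astype(np.int8)
    return q.reshape(rows, cols), scales.squeeze(1)

def dequantize_sym_int4(q: np.ndarray, scales: np.ndarray, group_size: int = 128):
    rows, cols = q.shape
    g = q.reshape(rows // group_size, group_size, cols).astype(np.float32)
    return (g * scales[:, None, :]).reshape(rows, cols)

w = np.random.randn(256, 64).astype(np.float32)
q, s = quantize_sym_int4(w)
w_hat = dequantize_sym_int4(q, s)
```

Because the scheme is symmetric, zero quantizes exactly to zero and no per-group zero-point tensor is needed, which is what makes the dequantization on the GPU side so cheap.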
@efrantar
Thank you for taking the time to explain.
Have you looked into CUTLASS, specifically the 3.x API that introduced the CuTe abstractions for thread-value tensor manipulation/mapping? I'm wondering whether it could help generalize/extend the handcrafted code in Marlin without sacrificing performance.
@efrantar
Awesome work -- always enjoy your research on and implementation of efficient model inference.
I was hoping that you could shed some light on the logic of the packing step?
My understanding is that the individual int4 values need to be rearranged in order to use the fast unpack / convert functions from FasterTransformer.
Is the subsequent interleaving done so that `ldmatrix` can be used on these packed values, such that each thread ends up holding the values it needs for `mma.sync`? Typically `ldmatrix` is used on `fp16` / `bf16` types, but in this case the weights are sub-byte types, hence the additional preprocessing required for an efficient shared -> register copy. I know FasterTransformer has its own formatting logic as a workaround for this issue; I have yet to find a general solution for efficiently leveraging tensor core primitives on sub-byte types without preprocessing the weights into a custom format.

Theoretically, if I were to preprocess the weights of a non-GPTQ `int4` model using the packing function -- i.e., any groupwise quantization method that yields `4b` weights along with group scales and zeros -- would I be able to use the Marlin kernel on such a model? If not, what changes would need to be made?

Many thanks!