IST-DASLab / gptq

Code for the ICLR 2023 paper "GPTQ: Accurate Post-training Quantization of Generative Pretrained Transformers".
https://arxiv.org/abs/2210.17323
Apache License 2.0

How to apply 3/4-bit quantization to computer vision models? #14

Closed. zshn25 closed this issue 1 year ago.

Godofnothing commented 1 year ago

You can do it in a similar fashion.

Take the dataset of interest and select a subset of it, say 1024 images. Then propagate the activations the same way as is done here for the OPT, BLOOM, and LLaMA models. For ViTs the procedure is very similar; for ResNets you would need to wrap the conv layers with the GPTQ class and run quantization on them.

However, for vision models 3/4-bit quantization would likely degrade performance significantly, and there is usually not much need for it, since these models are already quite small (unless you quantize something like the recent ViT-22B).
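For the ResNet case, the key point is that a convolution is a matrix multiplication in disguise, so the same per-layer machinery applies once the kernel and inputs are reshaped. Below is a minimal pure-Python sketch, assuming a 2D convolution with stride 1, no padding, and no bias. `flatten_kernel` turns the kernel into an (out_channels, in_channels*kh*kw) matrix, `unfold` turns each input image into columns of patches (im2col, the idea behind `torch.nn.Unfold`), and `accumulate_hessian` keeps a running average of 2XX^T over calibration columns, the Hessian of the layer-wise squared error that GPTQ minimizes. All helper names are hypothetical and are not the repository's GPTQ class.

```python
"""Sketch: view a conv layer as a matmul so a GPTQ-style solver can treat it like a linear layer."""
from typing import List

Matrix = List[List[float]]


def flatten_kernel(weight) -> Matrix:
    """Reshape a kernel indexed [out][in][kh][kw] into rows of length in*kh*kw."""
    return [[w for channel in filt for row in channel for w in row] for filt in weight]


def unfold(image, kh: int, kw: int) -> Matrix:
    """im2col for one image indexed [in][H][W] (stride 1, no padding).

    Row order is (channel, kernel row, kernel col), matching flatten_kernel;
    there is one column per output position, in row-major order.
    """
    c_in, h, w = len(image), len(image[0]), len(image[0][0])
    positions = [(i, j) for i in range(h - kh + 1) for j in range(w - kw + 1)]
    return [
        [image[c][i + di][j + dj] for (i, j) in positions]
        for c in range(c_in)
        for di in range(kh)
        for dj in range(kw)
    ]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def conv2d_direct(image, weight) -> Matrix:
    """Reference cross-correlation (PyTorch's conv convention), one row per output channel."""
    c_in, h, w = len(image), len(image[0]), len(image[0][0])
    kh, kw = len(weight[0][0]), len(weight[0][0][0])
    return [
        [
            sum(filt[c][di][dj] * image[c][i + di][j + dj]
                for c in range(c_in) for di in range(kh) for dj in range(kw))
            for i in range(h - kh + 1)
            for j in range(w - kw + 1)
        ]
        for filt in weight
    ]


def accumulate_hessian(hessian: Matrix, n_seen: int, cols: Matrix):
    """Update a running average of 2 * x x^T with every column x of `cols`."""
    d, n_new = len(cols), len(cols[0])
    total = n_seen + n_new
    xxt = matmul(cols, [list(r) for r in zip(*cols)])
    updated = [[(hessian[r][c] * n_seen + 2.0 * xxt[r][c]) / total for c in range(d)]
               for r in range(d)]
    return updated, total


if __name__ == "__main__":
    image = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],      # 2 input channels, 3x3 each
             [[0, 1, 0], [1, 0, 1], [0, 1, 0]]]
    weight = [[[[1, 0], [0, -1]], [[2, 0], [0, 0]]],  # 2 output channels, 2x2 kernels
              [[[0, 1], [1, 0]], [[0, 0], [0, 1]]]]
    W, X = flatten_kernel(weight), unfold(image, 2, 2)  # 2x8 and 8x4
    assert matmul(W, X) == conv2d_direct(image, weight)
    H, n = accumulate_hessian([[0.0] * len(X) for _ in X], 0, X)
    print("conv == W @ X; H is", len(H), "x", len(H[0]), "from", n, "columns")
```

Each patch position acts as one calibration sample here, so even a modest image subset yields many columns for H; the quantization step itself then works on the flattened W exactly as it would for a linear layer.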
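To give a rough feel for why 3/4 bits is aggressive: a uniform asymmetric min-max grid (a common choice, similar to the per-channel grids GPTQ rounds onto) has only 2^b levels per row, so its step size, and with it the worst-case rounding error, roughly doubles for every bit removed. The sketch below is plain round-to-nearest on a made-up weight row, not GPTQ's error-compensating update; `quantize_row` and friends are hypothetical helpers.

```python
import math


def quantize_row(row, bits):
    """Asymmetric min-max round-to-nearest quantization of one weight row.

    Returns (codes, scale, zero) with integer codes in [0, 2**bits - 1].
    """
    levels = 2 ** bits - 1
    lo, hi = min(min(row), 0.0), max(max(row), 0.0)  # make sure 0 is inside the range
    scale = (hi - lo) / levels if hi > lo else 1.0
    zero = round(-lo / scale)
    codes = [min(max(round(w / scale) + zero, 0), levels) for w in row]
    return codes, scale, zero


def dequantize_row(codes, scale, zero):
    return [scale * (q - zero) for q in codes]


def max_abs_error(row, bits):
    codes, scale, zero = quantize_row(row, bits)
    return max(abs(w - v) for w, v in zip(row, dequantize_row(codes, scale, zero)))


if __name__ == "__main__":
    row = [0.1 * math.sin(0.37 * i) for i in range(64)]  # synthetic weights
    for bits in (8, 4, 3):
        step = quantize_row(row, bits)[1]
        print(f"{bits}-bit: step={step:.5f}, max error={max_abs_error(row, bits):.5f}")
```

Going from 4 to 3 bits shrinks the grid from 16 to 8 levels, so the step grows by a factor of 15/7; GPTQ reduces the resulting output error but cannot remove the coarseness of the grid itself.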

efrantar commented 1 year ago

Adding to what Denis said, you may also want to take a look at our OBC work https://github.com/IST-DASLab/OBC, where we also study vision model compression with a more accurate but less scalable predecessor of GPTQ.

ThisisBillhe commented 1 year ago

There is a quant_cuda_kernel for the quant_linear module, but it cannot be applied to conv layers, right? So only fake quantization is possible for conv layers?

efrantar commented 1 year ago

Yes, our sample kernels here are designed for the memory-bound, single-token, linear-layer case that occurs in very large generative NLP models. Probably the simplest option for running quantized conv layers would be to fully decompress them on the fly right before executing the corresponding convolution. This would still give you memory savings during inference (you only need to hold a single dequantized layer in memory at a time) but will not give you any speedups. In any case, vision models are usually not memory-bound, so weight-only quantization is unlikely to speed them up.
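A minimal sketch of the "decompress on the fly" option, assuming 4-bit codes packed two per byte with one scale and zero point per output channel. This layout is chosen for illustration only and is not the format used by the repository's quant_cuda kernels; `PackedConvWeight`, `pack_4bit`, and `weight_bytes` are hypothetical. Only the packed bytes stay resident; the float kernel is rebuilt right before the convolution and can be dropped afterwards, so memory shrinks while the convolution itself runs at full precision and is no faster.

```python
def pack_4bit(codes):
    """Pack a flat list of 4-bit codes (0..15) into bytes, two codes per byte."""
    if len(codes) % 2:
        codes = codes + [0]  # pad to an even count
    return bytes((codes[i] & 0xF) | ((codes[i + 1] & 0xF) << 4)
                 for i in range(0, len(codes), 2))


def unpack_4bit(packed, count):
    """Inverse of pack_4bit: low nibble first, then high nibble."""
    codes = []
    for byte in packed:
        codes.append(byte & 0xF)
        codes.append(byte >> 4)
    return codes[:count]


class PackedConvWeight:
    """Hypothetical container: packed 4-bit codes plus per-output-channel (scale, zero)."""

    def __init__(self, packed, shape, scales, zeros):
        self.packed, self.shape = packed, shape
        self.scales, self.zeros = scales, zeros

    def dequantize(self):
        """Rebuild the float kernel [out][in][kh][kw] just before the convolution runs."""
        c_out, c_in, kh, kw = self.shape
        per_filter = c_in * kh * kw
        codes = unpack_4bit(self.packed, c_out * per_filter)
        kernel = []
        for o in range(c_out):
            flat = [self.scales[o] * (q - self.zeros[o])
                    for q in codes[o * per_filter:(o + 1) * per_filter]]
            # flat index of (c, i, j) is (c * kh + i) * kw + j
            kernel.append([[flat[(c * kh + i) * kw:(c * kh + i + 1) * kw] for i in range(kh)]
                           for c in range(c_in)])
        return kernel


def weight_bytes(shape, bits):
    """Storage for the raw weights only (scales and zeros not counted)."""
    n = 1
    for d in shape:
        n *= d
    return (n * bits + 7) // 8
```

For a 3x3 conv with 64 input and 64 output channels, `weight_bytes((64, 64, 3, 3), 16)` is 73728 bytes versus 18432 bytes at 4 bits, plus a small overhead for the per-channel scales and zeros. The saving only holds if each dequantized kernel is discarded after its convolution, which matches the "single dequantized layer in memory at a time" point above.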