pytorch / ao

PyTorch native quantization and sparsity for training and inference

GPTQ implementation with tensor subclasses #577

Open HDCharles opened 1 month ago

HDCharles commented 1 month ago

Problem: The current implementation of GPTQ (a technique used to improve quantization accuracy) relies on model tracing, which cannot handle certain op sequences found in various models; for example, the kv_cache update in the gpt-fast llama implementation causes an issue. We worked around this in our version with a flag that selects between the fast op sequence and the one that works with tracing.

Goal: re-implement GPTQ using tensor subclasses to avoid this issue.

Background: For each linear, GPTQ requires us to:

1. track all activations going to that linear,
2. apply the GPTQ algorithm to update that linear's weight,
3. use the updated weight to generate outputs for that linear, and
4. repeat for the next linear.

A toy sketch of this loop follows.
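To make the loop concrete, here is a minimal sketch of steps 1-4, under the simplifying assumption that the model is nothing but a stack of linears (`gptq_update` is a hypothetical placeholder for the actual GPTQ solve, not a name from the codebase). Real models interleave other ops between the linears, which is exactly what the multi-tensor machinery below has to handle:

```python
import torch
import torch.nn as nn

def gptq_update(weight: torch.Tensor, H: torch.Tensor) -> torch.Tensor:
    # Hypothetical placeholder: the real GPTQ solve quantizes the weight
    # column by column, using H^-1 to propagate quantization error into
    # the not-yet-quantized columns. Here it just returns the weight.
    return weight

def calibrate(linears, inputs):
    # `inputs`: the n calibration activations reaching the FIRST linear.
    acts = list(inputs)
    for linear in linears:                        # in forward-pass order
        # 1) the activations reaching this linear are in `acts`
        # 2) build GPTQ's Hessian approximation H = sum_i x_i^T x_i and
        #    update the weight
        H = sum(x.t() @ x for x in acts)
        linear.weight.data = gptq_update(linear.weight.data, H)
        # 3) recompute outputs with the UPDATED weight ...
        with torch.no_grad():
            acts = [linear(x) for x in acts]
        # 4) ... and they become the next linear's activations

linears = [nn.Linear(16, 16) for _ in range(3)]
calibrate(linears, [torch.randn(4, 16) for _ in range(8)])
```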

The main complication is that, to track n activations going to each linear, we'd normally run the model on n inputs. But if we update the weight of the 1st linear (and want the activations of the 2nd linear to reflect that update), we'd then have to run the model n more times. That means n*L runs in total (L = number of linears); with, say, n = 100 calibration inputs and L = 80 linears, that's 8,000 forward passes, which is extremely slow.

Instead, we want to run the model on all inputs at once, but ONLY up to the first linear, then pause, run the algorithm to update the weight, compute that linear's outputs with the updated weight, then unpause and continue on until we hit the next linear, and so on.

Previously we did this using tracing and a custom fx.Interpreter, but as mentioned above, that approach is a bit fragile. Instead we want to use a tensor subclass that operates as a "multi-tensor", containing the tensors from several inputs. Each time it encounters a normal op, it applies that op to each constituent tensor and produces a multi-tensor output, so it propagates through the network. When it encounters a linear op, it instead performs the GPTQ weight update outlined above. A minimal sketch of this dispatch pattern follows.
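Below is a minimal sketch of the dispatch pattern, not the gist's actual code: the real class will differ in naming and detail, nested container args aren't handled, all multi-tensors are assumed to carry the same count, and the GPTQ update itself is left as a comment at the hook point:

```python
import torch
import torch.nn.functional as F

class MultiTensor(torch.Tensor):
    """Wraps one tensor per calibration input and fans ops out over them."""

    @staticmethod
    def __new__(cls, values):
        values = list(values)
        # The wrapper borrows shape/dtype/device metadata from its first
        # constituent; it holds no storage of its own.
        return torch.Tensor._make_wrapper_subclass(
            cls, values[0].shape, dtype=values[0].dtype, device=values[0].device
        )

    def __init__(self, values):
        self.values = list(values)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(a, i):
            # Plain-tensor args (e.g. weights) are shared across all inputs.
            return a.values[i] if isinstance(a, MultiTensor) else a

        flat = list(args) + list(kwargs.values())
        n = max((len(a.values) for a in flat if isinstance(a, MultiTensor)), default=1)

        if func is F.linear:
            # GPTQ hook: args[0].values holds ALL n activations feeding this
            # linear. The real implementation would run the GPTQ weight
            # update here, so the outputs computed below use the updated
            # weight (the update itself is omitted in this sketch).
            pass

        # Fan out: run func once per constituent, on plain tensors, then
        # rewrap so the multi-tensor keeps propagating through the model.
        with torch._C.DisableTorchFunctionSubclass():
            outs = [
                func(
                    *[unwrap(a, i) for a in args],
                    **{k: unwrap(v, i) for k, v in kwargs.items()},
                )
                for i in range(n)
            ]
        return MultiTensor(outs) if isinstance(outs[0], torch.Tensor) else outs[0]
```

Running a model on a `MultiTensor` then looks like an ordinary forward pass:

```python
x = MultiTensor([torch.randn(2, 8) for _ in range(4)])  # 4 calibration inputs
w = torch.randn(16, 8)
y = F.linear(x, w)  # hits the hook; y is a MultiTensor holding 4 outputs
```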

Starting Point: We have a really good starting point for this: here is a gist with a working implementation of a MultiTensor subclass, showing how it is intended to work as far as running the model while propagating a multi-tensor (it doesn't contain the weight update code). The task is to take this building block, which handles all the multi-tensor propagation, and add in the code for the actual GPTQ algorithm that runs when it encounters a linear op; this can be taken more or less directly from the existing implementation. We also need to adapt the other scaffolding, like GPTQQuantizer etc., to get it all working for int4 weight-only quantization, ideally getting this test to work with the updated API (though perhaps with a less trivial calibration limit set).
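For a sense of the intended end state, a calibration run might look something like this (a sketch only: `Int4WeightOnlyGPTQQuantizer` exists in torchao's current GPTQ code, but the exact constructor arguments and `quantize` signature shown here are assumptions, not the settled API):

```python
import torch
import torch.nn as nn

# Hypothetical end-to-end usage; signatures below are assumptions.
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
calibration_inputs = [torch.randn(1, 64) for _ in range(8)]

quantizer = Int4WeightOnlyGPTQQuantizer()
# No tracing required: the inputs get wrapped in a MultiTensor, a single
# forward pass drives the per-linear GPTQ updates, and int4 weight-only
# linears are swapped in afterwards.
model = quantizer.quantize(model, calibration_inputs)
```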

danielpatrickhug commented 1 month ago

@HDCharles Hi, I would like to work on this issue. I'm currently reviewing the MultiTensor class and the current GPTQ implementation.