pytorch / ao

PyTorch native quantization and sparsity for training and inference

Refactor smoothquant implementation to use tensor subclasses #528

Open jerryzh168 opened 1 month ago

jerryzh168 commented 1 month ago

SmoothQuant implements input-weight equalization. The current implementation in torchao uses module swap, but it can be refactored to use tensor subclasses and AffineQuantizedTensor so that we can consolidate the performance optimizations in one place. We can use the static quantization flow as an example: https://github.com/pytorch/ao/pull/487.

The main benefits of the refactor would be: (1) aligning the model-level APIs, and (2) a simpler deserialization story (https://pytorch.org/ao/stable/serialization.html#what-happens-when-deserializing-an-optimized-model): you can load the quantized state dict into the original model directly and get a model ready for inference.
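
To make (2) concrete, here is roughly what the serialization story looks like with tensor subclasses. This is just a sketch using int8_weight_only as a stand-in for the future smoothquant config and a toy model:

```python
import torch
from torchao.quantization import quantize_, int8_weight_only

model = torch.nn.Sequential(torch.nn.Linear(64, 64))  # toy stand-in for the real model
quantize_(model, int8_weight_only())                   # stand-in for the future smoothquant config
torch.save(model.state_dict(), "quantized_sd.pt")

# later: no module swap needed, load the quantized state dict straight into a float model
fresh_model = torch.nn.Sequential(torch.nn.Linear(64, 64))
state_dict = torch.load("quantized_sd.pt", weights_only=False)
fresh_model.load_state_dict(state_dict, assign=True)   # weights become quantized tensor subclasses
```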

Overview

Here is the top level API for smoothquant: https://github.com/pytorch/ao/tree/main/torchao/quantization#to-be-moved-to-prototype-a8w8-dynamic-quantization-with-smoothquant

It follows our calibration flow (static quant flow) pretty closely: https://github.com/pytorch/ao/blob/afde1755d906ad644e04835675e7856d72c3c87b/tutorials/calibration_flow/static_quant.py#L121-L134

How to implement it in torchao

Similar to the static quantization flow, at a high level there are two steps.

Step 1. Inserting Observers

The first step is to insert observers that record the running absolute max value: https://github.com/pytorch/ao/blob/afde1755d906ad644e04835675e7856d72c3c87b/torchao/quantization/smoothquant.py#L146-L147

We can create a function insert_smoothquant_observers_ similar to https://github.com/pytorch/ao/blob/afde1755d906ad644e04835675e7856d72c3c87b/tutorials/calibration_flow/static_quant.py#L37
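
A rough sketch of what the observed module and the insertion helper could look like (all names here are illustrative, not the final API; the real version should reuse the observer utilities from the static quant tutorial):

```python
import torch
import torch.nn.functional as F

class ObservedLinear(torch.nn.Linear):
    """Illustrative observed linear: records the running per-input-channel abs max
    of the activations seen during calibration (the stat smoothquant needs)."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        self.register_buffer("act_abs_max", torch.zeros(in_features))

    def forward(self, x):
        # reduce over every dim except the last (input-channel) dim
        reduce_dims = tuple(range(x.dim() - 1))
        cur_abs_max = x.detach().abs().amax(dim=reduce_dims)
        self.act_abs_max = torch.maximum(self.act_abs_max.to(x.device), cur_abs_max)
        return F.linear(x, self.weight, self.bias)

    @classmethod
    def from_float(cls, float_linear: torch.nn.Linear):
        observed = cls(float_linear.in_features, float_linear.out_features, float_linear.bias is not None)
        observed.weight = float_linear.weight
        observed.bias = float_linear.bias
        return observed

def insert_smoothquant_observers_(model: torch.nn.Module):
    """Swap every nn.Linear for an ObservedLinear (calibration-time only)."""
    for name, child in model.named_children():
        if isinstance(child, torch.nn.Linear) and not isinstance(child, ObservedLinear):
            setattr(model, name, ObservedLinear.from_float(child))
        else:
            insert_smoothquant_observers_(child)
```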

Step 2. Convert to AffineQuantizedTensor with a new layout

After we have collected the stats, we can convert the floating point weight to AffineQuantizedTensor with a new LayoutType and AQTLayout that carry an extra equalization_scale Tensor. I think this can share the same implementation as AWQ, although with different dtypes (int8). Example conversion code: https://github.com/pytorch/ao/blob/afde1755d906ad644e04835675e7856d72c3c87b/tutorials/calibration_flow/static_quant.py#L46-L63

In terms of the model-level API, we can implement a helper function like https://github.com/pytorch/ao/blob/afde1755d906ad644e04835675e7856d72c3c87b/torchao/quantization/quant_api.py#L363 to support different configurations.
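
Here is a minimal sketch of the conversion math plus a config-style helper on top of it, in plain torch. The alpha value, the per-channel scheme, and all names are illustrative assumptions; the real version would build an AffineQuantizedTensor whose layout carries the equalization_scale and would plug into quantize_:

```python
import torch
from functools import partial

def convert_observed_linear(observed_linear, alpha: float = 0.5):
    """Illustrative Step 2: fold the equalization scale into the weight and quantize it
    to int8 per output channel. In the refactor this would instead return an
    AffineQuantizedTensor with a new smoothquant LayoutType/AQTLayout."""
    w = observed_linear.weight.detach()            # [out_features, in_features]
    act_abs_max = observed_linear.act_abs_max      # [in_features], collected in Step 1
    w_abs_max = w.abs().amax(dim=0)                # [in_features]

    # smoothquant equalization scale: s_j = max|X_j|^alpha / max|W_j|^(1 - alpha)
    equalization_scale = (
        act_abs_max.clamp(min=1e-5).pow(alpha) / w_abs_max.clamp(min=1e-5).pow(1 - alpha)
    )

    # migrate quantization difficulty from activations to weights: W' = W * s, X' = X / s
    w_eq = w * equalization_scale

    # symmetric per-output-channel int8 quantization of the equalized weight
    weight_scale = w_eq.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / 127.0
    int_weight = torch.clamp(torch.round(w_eq / weight_scale), -128, 127).to(torch.int8)
    return int_weight, weight_scale, equalization_scale

def smoothquant_int8(alpha: float = 0.5):
    """Illustrative config factory, analogous to the factories in quant_api.py: it captures
    the configuration and returns the per-module transform that the model-level API
    (eventually quantize_) would apply to every observed linear."""
    return partial(convert_observed_linear, alpha=alpha)
```

The intended end state is the usual flow: insert observers, run calibration data through the model, then call the model-level API (eventually quantize_) with the smoothquant config.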

Logistics (Code Location, Test and Benchmarks)

Please create a smoothquant folder under https://github.com/pytorch/ao/tree/main/torchao/prototype. The flow and layout implementations can live in separate files, e.g. flow.py and layout.py (there might be some missing extension points in AffineQuantizedTensor, but we'll work on those at the same time).

For testing, please create a test_smoothquant.py in https://github.com/pytorch/ao/tree/main/test/prototype and move the tests from https://github.com/pytorch/ao/blob/afde1755d906ad644e04835675e7856d72c3c87b/test/integration/test_integration.py#L159 to that file.

For the e2e flow demo, please add a smoothquant.py in https://github.com/pytorch/ao/tree/main/tutorials/calibration_flow following the static quant example. Please show the benchmarking results as well (since we are using optimized kernels), following https://github.com/pytorch/ao/tree/main/torchao/quantization#quantization-flow-example
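
For the benchmarking part, a simple latency helper along these lines should be enough for the demo (a sketch, assuming torch.compile and torch.utils.benchmark are acceptable here):

```python
import torch
from torch.utils.benchmark import Timer

def benchmark_model(model, example_input, label):
    """Print the median latency (ms) of one forward pass, after compiling and warming up."""
    compiled = torch.compile(model, mode="max-autotune")
    with torch.no_grad():
        compiled(example_input)  # warm up / trigger compilation
        timer = Timer(
            stmt="compiled(example_input)",
            globals={"compiled": compiled, "example_input": example_input},
        )
        print(f"{label}: {timer.blocked_autorange().median * 1e3:.3f} ms")
```

Running this on the float baseline and on the quantized model gives the numbers to put in the tutorial/README.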

The last step is to test this with llama2/llama3 following the instructions in https://github.com/pytorch/ao/tree/main/torchao/_models/llama and measure the metrics in https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks if you have GPU machines. For smoothquant, you can also test on CPU machines and add the results to the quantization README.


jerryzh168 commented 1 month ago

cc @Xia-Weiwen, who will work on this