pytorch / ao

Custom data types and layouts for training and inference
BSD 3-Clause "New" or "Revised" License

[RFC] Plans for torchao #47

Open supriyar opened 4 months ago

supriyar commented 4 months ago

Summary

Last year, we released pytorch-labs/torchao to provide acceleration of Generative AI models using native PyTorch techniques. Torchao added support for running quantization on GPUs, including int8 dynamic quantization (W8A8) and weight-only quantization (int8 and int4) that were composable with torch.compile. Combined, the APIs launched in torchao were able to power SOTA generative AI models across multiple modalities: Segment Anything, Stable Diffusion, and LLaMa. The results were showcased in these blog posts - https://pytorch.org/blog/accelerating-generative-ai/, https://pytorch.org/blog/accelerating-generative-ai-2/, https://pytorch.org/blog/accelerating-generative-ai-3/

Our investment in torchao is to accelerate Generative AI using native PyTorch features while ensuring composability with torch.compile.

In 2024, we plan to adopt the following strategy for development of torchao

Let’s dive deeper into some of the coverage areas mentioned above.

Emerging dtypes

Dtypes like NF4, MX4, and groupwise quantized int4 are used to implement various optimization techniques in models. Last year, we posted a plan on how we wish to support these dtypes in PyTorch. In torchao, we will host tensor-subclass-based implementations of dtypes; existing examples include uint4 and NF4, which users can adopt for their own quantization techniques or override to support other dtypes that might be useful. Moreover, users don't need to write triton or cuda kernels for their custom dtypes: the implementation can be in Python, and torch.compile will take care of generating performant kernels under the hood.
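
As a toy illustration of the mechanism, here is a minimal sketch of a tensor-subclass-based dtype (the class name and the simple int8 weight-only scheme are hypothetical and much simpler than the actual uint4/NF4 subclasses):

```python
import torch

class Int8WeightOnlyTensor(torch.Tensor):
    """Hypothetical subclass holding int8 data plus a per-channel scale (sketch only)."""

    @staticmethod
    def __new__(cls, int_data, scale):
        # The outer shape/dtype mimic the original float weight.
        return torch.Tensor._make_wrapper_subclass(
            cls, int_data.shape, dtype=scale.dtype, device=int_data.device
        )

    def __init__(self, int_data, scale):
        self.int_data = int_data  # int8 payload
        self.scale = scale        # per-output-channel scale

    @classmethod
    def from_float(cls, w):
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-6) / 127.0
        int_data = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
        return cls(int_data, scale)

    def dequantize(self):
        return self.int_data.to(self.scale.dtype) * self.scale

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.nn.functional.linear:
            x, w, *rest = args
            # Plain Python dequant + matmul; torch.compile can generate fused kernels for it.
            return torch.nn.functional.linear(x, w.dequantize(), *rest)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)
```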

Quantization techniques

Quantization can be applied to weights only, or to both weights and activations. LLM quantization at batch size 1 (memory-bandwidth bound) typically uses weight-only quantization. For larger batch sizes, longer context lengths, or generally throughput-bound models, quantizing the activations is also beneficial. Quantization, however, impacts model accuracy, and researchers have published techniques to mitigate this accuracy impact, which currently exist externally as one repository per technique.
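
To make the distinction concrete, here is a small sketch of both flavors with symmetric int8 scaling (the scale choices and the use of torch._int_mm are illustrative, not a prescribed recipe):

```python
import torch

def quantize_weight_per_channel_int8(w):
    """Weight-only: one symmetric scale per output channel, computed offline."""
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-6) / 127.0
    return torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8), scale

def quantize_activation_per_tensor_int8(x):
    """Dynamic activation quantization: scale computed from the live tensor at runtime."""
    scale = x.abs().amax().clamp(min=1e-6) / 127.0
    return torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8), scale

def int8_dynamic_linear(x, w_int8, w_scale):
    """W8A8: quantize activations on the fly, integer matmul, then rescale the output."""
    x_int8, x_scale = quantize_activation_per_tensor_int8(x)
    acc = torch._int_mm(x_int8, w_int8.t())  # int8 x int8 -> int32 accumulation
    return acc.to(x.dtype) * x_scale * w_scale.t()
```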

In torchao, we plan to support the following classes of techniques using PyTorch, made available via a simple UX and following the one-file-per-technique principle.

LLM weight only quantization techniques

Post training quantization The two most popular techniques externally are GPTQ and AWQ, available via AutoGPTQ and AutoAWQ, which include the technique as well as performant kernels for faster quantization ops. To that end, we will start by re-implementing the GPTQ and AWQ techniques in torchao using PyTorch, via a simple/intuitive UX that supports saving/loading of quantized models while realizing the memory savings on disk. Some open questions we need to address here include:

  • How much VRAM will be required for different quantization techniques?
  • How do we convert to/from weights quantized for different backends? (CPU and GPU today use different weight packing formats.)

In the future, as more interesting and cutting edge techniques are introduced, researchers can directly implement them in torchao or our team can re-implement them in PyTorch.

Weight and activation quantization techniques

Post training quantization We’ve already implemented W8A8 quantization via the int_mm kernel in core. This has shown speedups on models like SAM and SDXL without any impact on model accuracy, and it can be turned on via a simple one-line UX implemented via module swap or tensor subclass.

However, the challenge here is that some smaller layer shapes might not benefit from quantization due to the overhead of quantizing and dequantizing the activation tensors. Users can either statically exclude these layers from quantization or rely on a higher-level API that figures out which layers are sensitive to quantization. We plan to provide such a higher-level API via the auto quantizer, which applies the technique only to the layers that stand to benefit the most, so users get the benefits of quantization without having to worry too much about which configs to use.
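
As a rough illustration of the "skip the layers that don't benefit" idea, a purely shape-based filter might look like the sketch below (the helper name and threshold are made up; the actual auto quantizer is meant to figure out the beneficial layers automatically rather than rely on a fixed heuristic like this):

```python
import torch.nn as nn

def select_layers_to_quantize(model: nn.Module, min_dim: int = 1024):
    """Hypothetical heuristic: only pick Linear layers whose dimensions are large
    enough for the activation quant/dequant overhead to pay off."""
    selected = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and min(module.in_features, module.out_features) >= min_dim:
            selected.append(name)  # names to hand to the quantization API of choice
    return selected
```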

Quantization aware training Techniques here require access to fine-tuning, to adapt the model and reduce the accuracy impact of quantization. Recent research like LLM-QAT is promising, showing that we can go down to W4A8 and a 4-bit KV cache for LLMs. Moreover, newer lower-bit techniques like AQLM and QuIP# also include a fine-tuning component to improve model accuracy.

We will include the APIs and workflow to enable users to do QAT on LLMs, starting with implementing the LLM-QAT paper in torchao and further extending it to support other dtypes like MX4.
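
As a sketch of the basic mechanism QAT builds on, here is a generic straight-through-estimator fake-quant for weights (this is the textbook recipe, not the LLM-QAT implementation; activation and KV-cache fake-quant would follow the same pattern):

```python
import torch

class FakeQuantize(torch.autograd.Function):
    """Quantize-dequantize in the forward pass, straight-through estimator in backward."""

    @staticmethod
    def forward(ctx, w, n_bits=4):
        qmax = 2 ** (n_bits - 1) - 1
        scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / qmax
        w_q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax)
        return w_q * scale  # the "fake quantized" weight used during training

    @staticmethod
    def backward(ctx, grad_output):
        # STE: gradients flow as if quantize-dequantize were the identity.
        return grad_output, None

def qat_linear(x, weight, bias=None, n_bits=4):
    """Linear that trains against the quantization noise it will see at inference."""
    return torch.nn.functional.linear(x, FakeQuantize.apply(weight, n_bits), bias)
```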

Optimized kernels

Kernels Optimized kernels are key to making models run faster during inference. Today, in core, we already have performant kernels like int_mm and 4-bit weight quantization kernels for cpu (via intel) and gpu (via tinygemm). torchao will host performant kernels that work with different backends, along with a guide on how to plug these kernels into PyTorch models via the custom ops API. These kernels will compose with torch.compile, with the expectation that the user writes a meta kernel implementation for them. For executorch, the expectation is that if the user provides a kernel that works with executorch, it should also work in eager mode.
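
As a rough sketch of that plug-in flow with the torch.library custom ops API (the namespace, op name, and the Python body standing in for a real CUDA kernel are all hypothetical):

```python
import torch
from torch.library import Library

_lib = Library("myao", "DEF")  # hypothetical namespace
_lib.define("int8_mm(Tensor a, Tensor b) -> Tensor")

def int8_mm_cuda(a, b):
    # Placeholder standing in for a hand-written CUDA/Triton kernel.
    return torch._int_mm(a, b)

def int8_mm_meta(a, b):
    # Meta ("fake") kernel: shapes/dtypes only, so torch.compile can trace through it.
    return a.new_empty((a.shape[0], b.shape[1]), dtype=torch.int32)

_lib.impl("int8_mm", int8_mm_cuda, "CUDA")
_lib.impl("int8_mm", int8_mm_meta, "Meta")

# The op is then callable as torch.ops.myao.int8_mm(a, b) inside compiled models.
```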

We will also directly engage with the community, to upstream their performant kernels into torchao.

Autotuner

In order to use any CUDA kernel efficiently, we'll need to pick the right kernel hyperparameters; the same is true for eager mode kernels. A kernel autotuner will help here. We expect that the auto quantizer, along with the kernel autotuner, will make int8 dynamic quantization and int8/int4 weight-only quantization more usable and performant. A WIP example of what this might look like can be found here.
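
A toy sketch of the idea: time a handful of candidate configurations for the shapes actually seen and cache the winner per shape (the kernel callable and config space are stand-ins, e.g. a Triton launch with different block sizes):

```python
import torch

def _time_cuda(fn, iters=50):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn()  # warmup
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

_best_config = {}  # (kernel, input shapes) -> fastest config seen so far

def autotuned_call(kernel, configs, *args):
    """Hypothetical autotuner: benchmark each config once per shape, then reuse the best."""
    key = (kernel.__name__, tuple(a.shape for a in args))
    if key not in _best_config:
        timings = {cfg: _time_cuda(lambda: kernel(cfg, *args)) for cfg in configs}
        _best_config[key] = min(timings, key=timings.get)
    return kernel(_best_config[key], *args)
```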

Release engineering

Shipping optimized, custom kernels requires extensibility mechanisms and release channels. We have custom operator support that integrates broadly, but our release mechanism might need to be optimized. It can be quite difficult to ship custom binaries across a broad range of operating systems and accelerators.

Conversion to-from popular model formats

We can add a conversion util from popular model storage formats like gguf into PyTorch’s state_dict format. This will enable users to take a pre-existing quantized model from llama.cpp and have it run via PyTorch eager mode for desktop cpu/gpu and executorch for on-device cases. We’ll share more details here soon.
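
As a very rough sketch of the shape such a util could take, the core is a renaming/conversion loop over the tensors read from the file; the (name, numpy array) input and the key mapping below are hypothetical, and real gguf quantized blocks would need to be repacked into the matching torchao dtype rather than copied as-is:

```python
import numpy as np
import torch

def tensors_to_state_dict(named_arrays, name_map=None):
    """Hypothetical converter core: given (gguf tensor name, numpy array) pairs,
    build a PyTorch state_dict with keys renamed to match the target nn.Module."""
    name_map = name_map or {}
    state_dict = {}
    for gguf_name, array in named_arrays:
        torch_name = name_map.get(gguf_name, gguf_name)
        # Raw or dequantized data; quantized gguf blocks would instead be repacked.
        state_dict[torch_name] = torch.from_numpy(np.ascontiguousarray(array))
    return state_dict
```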

Pruning

In addition to quantization, we’ve also seen promising results with sparsity on GPUs. We will share more updates on what torchao will host in the space of sparsity/pruning in the near future.

We'd love to hear any feedback or questions from the OSS community on this RFC. Thank you!

cc @msaroufim @cpuhrsch @jerryzh168 @HDCharles @andrewor14 @jcaip @jisaacso

mobicham commented 4 months ago

How about HQQ ?

supriyar commented 4 months ago

How about HQQ ?

  • No calibration needed
  • Supports 8,4,3,2,1 bits
  • Very fast quantization, instead of waiting for hours with GPTQ/AWQ
  • The quality is on-par or better than GPTQ/AWQ, especially at low bits
  • bit-unpacking and dequantization CUDA kernels available for all bits
  • Supports backprop for QLoRA training
  • Works with FSDP for distributed QLoRA training

@mobicham HQQ looks great and it would be great to add a PyTorch implementation of this to torchao. Supporting 3, 2, 1 bits is pretty neat and the QLoRA support is useful for us too (we have support for NF4 tensor).

We're currently in the process of adding GPTQ to enable running quantized models on GPU, CPU and executorch (on-device).

Would you be interested in contributing? We have a lightweight API recommendation for new techniques like - https://github.com/pytorch-labs/ao/blob/main/torchao/quantization/quant_api.py#L50. We can also add the kernels into torchao so users can take full advantage of the e2e inference speedups.

jph00 commented 4 months ago

It would be nice to ensure that FSDP works out-of-the-box with quantized models -- we've explained the steps needed to make this work in these articles (and have provided a demonstration script, along with the needed modifications to HQQ and bitsandbytes):

jph00 commented 4 months ago

I'd be interested in hearing more details on how these changes will be implemented, and how extensible things will be. Will quantization support be added to triton? And will torch.compile be able to convert the python API to triton?

Will there be some fairly low level python API that we could pass all the needed quant state to, so that new algorithms could be largely implemented in pure python?

mobicham commented 4 months ago

Would you be interested in contributing? We have a lightweight API recommendation for new techniques like - https://github.com/pytorch-labs/ao/blob/main/torchao/quantization/quant_api.py#L50. We can also add the kernels into torchao so users can take full advantage of the e2e inference speedups.

Great thanks, I will take a look at it @supriyar !

mobicham commented 4 months ago

I think one important thing to consider is having a standardized PyTorch way of bitpacking/unpacking that all quantization methods can use. Currently, each quantization technique implements its own bitpacking logic, which makes the CUDA kernels incompatible between different (linear quant) methods.
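
To make this concrete, even a shared pack/unpack pair along these lines would help (4-bit values packed into uint8; the layout is just illustrative, and a real standard would also need to fix the packing order, group size, etc.):

```python
import torch

def pack_uint4(x):
    """Pack pairs of 4-bit values (stored as uint8 in range 0-15) into single bytes."""
    assert x.dtype == torch.uint8 and x.shape[-1] % 2 == 0
    return (x[..., ::2] << 4) | x[..., 1::2]

def unpack_uint4(packed):
    """Inverse of pack_uint4: recover the two 4-bit values from each byte."""
    out = torch.empty(*packed.shape[:-1], packed.shape[-1] * 2,
                      dtype=torch.uint8, device=packed.device)
    out[..., ::2] = packed >> 4
    out[..., 1::2] = packed & 0x0F
    return out
```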

supriyar commented 4 months ago

Hi @jph00, thanks for your interest in the RFC! Let me try to answer your questions here

Will quantization support be added to triton?

triton already supports 8-bit, and can also be used for 4-bit. We gave this a shot last year but saw significantly worse perf. CUTLASS might be a better option here (torch.compile should be able to pick the right kernels for us).

And torch.compile will be able to convert the python API to triton?

Yes, here is an example that @cpuhrsch just added https://github.com/pytorch-labs/ao/pull/60/files. The part where we do the quantization and call torch.compile is here

Will there be some fairly low level python API that we could pass all the needed quant state too, so that new algorithms could be largely implemented in pure python?

Yes, that is the goal with the proposed API: users can implement their quantize function in python, and when they torch.compile it, it should be able to automatically generate triton code for it as long as we have the underlying dtype support in core (like int8, int4). We'll be working on releasing more examples of how to add custom dtypes, what the basic interface to implement is, and how to get it to work with FSDP and torch.compile.
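
For example, a weight-only int8 linear written entirely in python and then handed to torch.compile (a sketch; the generated code and its quality depend on the compiler version and backend):

```python
import torch

def int8_weight_only_linear(x, w_int8, scale, bias=None):
    # Pure-python dequant + matmul; torch.compile traces this and emits triton kernels.
    w = w_int8.to(x.dtype) * scale
    out = x @ w.t()
    return out + bias if bias is not None else out

compiled_linear = torch.compile(int8_weight_only_linear)
```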

cc @cpuhrsch @msaroufim in case you'd like to add anything else to this.

mobicham commented 4 months ago

triton already supports 8-bits, and can also be used to do 4-bits too. We gave this a shot last year but saw significantly worse perf. CUTLASS might be a better option here (torch.compile should be able to pick the right kernels for us).

@supriyar Initializing an empty tensor first like this is faster than using torch.cat to bit-unpack tensors. torch.compile still struggles with int32 bitpacking, mainly used for 3-bit. We have a bar chart that compares the inference speed of a quantized Llama2-7B model using torch.compile for the dequantization step vs. using CUDA kernels, for reference: https://github.com/mobiusml/hqq/tree/master?tab=readme-ov-file#backend
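
For reference, this is roughly the comparison, using a simple 4-bit-in-uint8 layout (HQQ's actual kernels and layouts differ):

```python
import torch

def unpack_4bit_cat(packed):
    # torch.cat version: builds intermediates and pays for an extra copy.
    hi, lo = packed >> 4, packed & 0x0F
    return torch.cat([hi.unsqueeze(-1), lo.unsqueeze(-1)], dim=-1).flatten(-2)

def unpack_4bit_empty(packed):
    # Preallocate the output once and write both halves into strided views.
    out = torch.empty(*packed.shape[:-1], packed.shape[-1] * 2,
                      dtype=packed.dtype, device=packed.device)
    out[..., ::2] = packed >> 4
    out[..., 1::2] = packed & 0x0F
    return out
```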

mklasby commented 4 months ago

What is the relationship between this project and torch.ao? Is torchao a separate project / development repo?

Happy to have found torchao in any case, lots of goodies...

msaroufim commented 4 months ago

Hi @mklasby (we met last year at NeurIPS). Eventually, once things get proven out in torchao, they would get upstreamed to torch.ao, so the goal is for this to be a standalone repo with higher development velocity.

mklasby commented 4 months ago

Thanks @msaroufim! Looking forward to contributing to the effort!

ngc92 commented 4 months ago

One more thing to add to the list might be data layouts. For example, for unstructured weight sparsity, storing activations as [batch, ..., features] is much less efficient than [..., batch] on GPUs, because the latter allows for coalesced access patterns. E.g., the Sputnik paper spares this only a single comment:

To enable coalesced memory accesses into all input and output matrices, we store dense matrices in row-major layout and sparse matrices in CSR format

But this actually means that you might need to add many transpose operations if the sparse matmuls are interspersed with standard, non-pointwise operations. Similarly, I think, e.g., mixing convolution and attention layers currently would be quite problematic, because one uses [batch, feature, seq] and the other [batch, seq, feature] storage order. And of course, there is the old NCHW vs NHWC channels first/last issue.

mklasby commented 3 months ago

Is there any interest in achieving better alignment of ao.pruning with torch.nn.utils.prune functionalities? For example, torch.nn.utils.prune reparameterizes modules with <param_name>_orig and sets the original param name to the pruned tensor. In contrast, ao.pruning reparameterizes modules with parametrizations.<param_name>.original.

My assumption is that the sparsifier is intended for dynamic mask updates, which is somewhat hacky to perform on top of torch.nn.utils.prune functions currently. However, I think aligning these modules, where possible, will lead to a more compelling and intuitive user experience.

cpuhrsch commented 3 months ago

@mklasby - I think a set of functional pruners (similar to torch.nn.functional) would be fairly universal. Then we can build modules that track the state needed for incremental pruning and such. Do you think that'd fit the requirements?

mklasby commented 3 months ago

@cpuhrsch Yes, that is essentially what I am envisioning. The functional pruners would be a sophisticated topk function that scores parameters based on the specific pruning algorithm and returns the updated mask. Any state/buffers required to score the params can be passed from the caller to the pruners.
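
A minimal sketch of what I mean, with magnitude scoring as the example (the names are illustrative, not an existing torch or torchao API):

```python
import torch

def magnitude_prune_mask(weight, sparsity):
    """Functional pruner sketch: score by absolute magnitude and return a boolean
    mask keeping the largest (1 - sparsity) fraction of entries."""
    n_keep = max(1, int(round(weight.numel() * (1.0 - sparsity))))
    idx = weight.abs().flatten().topk(n_keep).indices
    mask = torch.zeros(weight.numel(), dtype=torch.bool, device=weight.device)
    mask[idx] = True
    return mask.view_as(weight)

# A stateful sparsifier (or an optimizer wrapper, as in jaxpruner/cerebras) would
# own any buffers and call functional pruners like this between training steps.
linear = torch.nn.Linear(256, 256)
linear.weight.data.mul_(magnitude_prune_mask(linear.weight, sparsity=0.9))
```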

I note that jaxpruner and the cerebras pruning library wrap or subclass the optimizer, respectively, for dynamic sparse training. This is a potential route to consider as well for the modules that track state if we feel that having an additional sparsifier object is less than ideal.

zhexinli commented 3 months ago

Hi, your work is fantastic. Do you have plans to support static quantization? That is to say, not computing the amax of activations when running inference, but instead using calibration to pre-compute the quantization scale to reduce the dynamic scaling overhead? And do you have plans to support more ops like conv2d? Thanks!

cpuhrsch commented 3 months ago

Hi @zhexinli - Yes, we want to create a design that can separate calibration from quantization and that should include this as well. We can also add support for conv2d. We have limited support for 1-by-1 convolutions by swapping them for linears.
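
For reference, a sketch of the 1-by-1 conv to linear rewrite, written out from the math (not necessarily how the swap is implemented internally):

```python
import torch
import torch.nn as nn

def conv1x1_as_linear(x, conv):
    """A 1x1, stride-1, padding-0, groups-1 Conv2d is a Linear over the channel dim."""
    assert conv.kernel_size == (1, 1) and conv.stride == (1, 1)
    assert conv.padding == (0, 0) and conv.groups == 1
    n, c, h, w = x.shape
    x_flat = x.permute(0, 2, 3, 1).reshape(-1, c)        # (N*H*W, C_in)
    weight = conv.weight.reshape(conv.out_channels, c)   # (C_out, C_in)
    out = torch.nn.functional.linear(x_flat, weight, conv.bias)
    return out.reshape(n, h, w, conv.out_channels).permute(0, 3, 1, 2)

x = torch.randn(2, 16, 8, 8)
conv = nn.Conv2d(16, 32, kernel_size=1)
torch.testing.assert_close(conv1x1_as_linear(x, conv), conv(x))
```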

supriyar commented 3 months ago

Hi, your work is fantastic. Do you have plans to support static quantization? That is to say, not computing the amax of activations when running inference, but instead using calibration to pre-compute the quantization scale to reduce the dynamic scaling overhead? And do you have plans to support more ops like conv2d? Thanks!

Hi @zhexinli are you looking for static quant of specific ops like conv/linear or general graph based quantization support? And on what backends?

In addition to what @cpuhrsch said, we have a PT2 export based quantization flow in PyTorch that's based on full-graph capture that you can use to run models on x86 CPU and edge runtimes (https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_x86_inductor.html)

msaroufim commented 3 months ago

@mobicham are you on CUDA MODE? If not, is there an email you could share? Mine is firstnamelastname@meta.com. We're quite excited to see an HQQ contribution, so we wanted to see where your head is at and how we could collaborate.

zhexinli commented 3 months ago

Hi, your work is fantastic. Do you have plans to support static quantization? That is to say, not computing the amax of activations when running inference, but instead using calibration to pre-compute the quantization scale to reduce the dynamic scaling overhead? And do you have plans to support more ops like conv2d? Thanks!

Hi @zhexinli are you looking for static quant of specific ops like conv/linear or general graph based quantization support? And on what backends?

In addition to what @cpuhrsch said, we have a PT2 export based quantization flow in PyTorch that's based on full-graph capture that you can use to run models on x86 CPU and edge runtimes (https://pytorch.org/tutorials/prototype/pt2e_quant_ptq_x86_inductor.html)

Hi, thanks for the introduction. Currently I'm looking for a CUDA backend for static quantization.

mobicham commented 3 months ago

@msaroufim sure, would love to do that! I will send you an email for a follow-up!