pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

[RFC] Add Auto-Round support #533

Closed yiliu30 closed 1 month ago

yiliu30 commented 2 months ago

Hi, this is the INC team from Intel. Thank you for developing this amazing project.

Motivation

Our team has developed Auto-Round, a new weight-only quantization algorithm. It has achieved superior accuracy compared to GPTQ, AWQ, and OmniQuant across 11 tasks, and it does particularly well at low bit widths (e.g., 2-bit and 3-bit). Auto-Round supports quantization from 2 to 8 bits, has a low tuning cost, and adds no overhead at inference time. Key results are summarized below, with detailed information available in our paper, GitHub repository, and Hugging Face low-bit quantization leaderboard.

[Image: autoround_res]

We would like to contribute this quantization algorithm to torchao to let users benefit from its high accuracy.

The Key Idea of Auto-Round

To quantize a given tensor, Auto-Round introduces three trainable parameters (V, α and β) to adjust the rounding value and the clipping range. For a given transformers model, Auto-Round quantizes the decoder blocks one by one, using the block-wise output reconstruction error as the loss to train these parameters.
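As a rough illustration of where V, α and β enter the computation, here is a forward-pass-only sketch in plain Python. It is a simplified reading of the idea, not the auto-round implementation: the function name `fake_quantize`, the per-tensor asymmetric scheme, the half-up rounding helper, and the toy weights are all assumptions made for this example.

```python
import math


def round_half_up(x):
    # Python's built-in round() rounds halves to even; use plain half-up here.
    return math.floor(x + 0.5)


def fake_quantize(weights, v, alpha=1.0, beta=1.0, bits=4):
    """Quantize then dequantize a list of floats (hypothetical helper).

    v     -- one rounding offset per weight, expected in [-0.5, 0.5]
    alpha -- multiplier on the max value, shrinks the upper clipping bound
    beta  -- multiplier on the min value, shrinks the lower clipping bound
    """
    qmax = 2 ** bits - 1
    w_max = max(max(weights), 0.0) * alpha
    w_min = min(min(weights), 0.0) * beta
    scale = (w_max - w_min) / qmax
    if scale == 0.0:
        scale = 1.0  # all-zero tensor: any scale works
    zero_point = round_half_up(-w_min / scale)
    out = []
    for w, v_i in zip(weights, v):
        q = round_half_up(w / scale + v_i) + zero_point  # V shifts the rounding
        q = min(max(q, 0), qmax)                         # clip to the integer grid
        out.append((q - zero_point) * scale)             # dequantize
    return out


weights = [-1.0, -0.2, 0.3, 0.5]
plain = fake_quantize(weights, v=[0.0, 0.0, 0.0, 0.0], bits=2)
nudged = fake_quantize(weights, v=[0.0, 0.0, -0.2, 0.0], bits=2)
print(plain)
print(nudged)
```

With these toy values, the offset of -0.2 moves the weight 0.3 from the 0.5 level down to the 0.0 level. Training decides, per weight, which of the two neighboring levels gives the lower block output error.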

[Image: autoround_overview]

The Modeling User API

We propose the following flow for quantizing a model with Auto-Round, which is similar to the flow of static quantization requiring calibration:

# Step 1. Replace each block with an observed block
# Similar to `insert_observers_`, but operates on whole blocks
insert_observers_for_block_(m, block_observer, is_block)

# Step 2. Calibrate / train
# Run calibration data through the model to capture each block's input
for _ in range(10):
    m(*example_inputs)

# Step 3. Quantize the observed blocks
quantize_(m, apply_auto_round, is_observed_block)
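To make steps 1 and 2 of this flow concrete, here is a minimal sketch of input capture on a toy model. `InputCapture` and `insert_capture_` are hypothetical stand-ins for the proposed `ObservedBlock` and `insert_observers_for_block_`, and the sketch assumes each block takes a single tensor argument; real decoder layers also receive keyword arguments such as attention masks.

```python
import torch
from torch import nn


class InputCapture(nn.Module):
    """Hypothetical wrapper that records every input passed to a block."""

    def __init__(self, block):
        super().__init__()
        self.block = block
        self.captured = []

    def forward(self, x):
        # Keep a detached copy so later training can replay the block's input.
        self.captured.append(x.detach().clone())
        return self.block(x)


def insert_capture_(model, is_block):
    """Wrap every submodule matching `is_block`, modifying `model` in place."""
    for name, child in list(model.named_children()):
        if is_block(child):
            setattr(model, name, InputCapture(child))
        else:
            insert_capture_(child, is_block)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

# Step 1: swap the "blocks" (here simply the Linear layers) for wrappers.
insert_capture_(model, lambda m: isinstance(m, nn.Linear))

# Step 2: calibration passes; each wrapper stores the inputs it sees.
with torch.no_grad():
    for _ in range(3):
        model(torch.randn(2, 8))

captured_shapes = [tuple(t.shape) for t in model[0].captured]
print(captured_shapes)
```

After calibration, each wrapper holds the inputs needed to optimize its block in isolation, which is what step 3 would consume.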

Implementation Overview

The high-level idea to implement the above flow is:

1) Replace each of the model's decoder blocks with an ObservedBlock that captures the block's input.
2) Calibrate with the user-provided dataset and capture each block's input.
3) Compute the reconstruction error and update V, α and β, then replace the Linear layers in the observed block with QuantizedLinear layers by applying the optimal V, α and β.
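Step 3 can be sketched as a small optimization loop over one block. This is an illustrative toy, not the auto-round code: the "block" is a single `nn.Linear`, `fake_quant` and `ste_round` are hypothetical helpers, the learning rate, step count and clamp ranges are arbitrary, and the update is a signed gradient step, as suggested by the title of the Auto-Round paper.

```python
import torch
from torch import nn
import torch.nn.functional as F


def ste_round(x):
    # Round in the forward pass; pass the gradient straight through.
    return (x.round() - x).detach() + x


def fake_quant(w, v, alpha, beta, bits=4):
    """Quantize-dequantize `w` with rounding offsets `v` and clipping factors."""
    qmax = 2 ** bits - 1
    w_max = w.max().clamp(min=0) * alpha
    w_min = w.min().clamp(max=0) * beta
    scale = (w_max - w_min).clamp(min=1e-8) / qmax
    zp = ste_round(-w_min / scale)
    q = torch.clamp(ste_round(w / scale + v) + zp, 0, qmax)
    return (q - zp) * scale


torch.manual_seed(0)
block = nn.Linear(16, 16)          # stand-in for a decoder block
x = torch.randn(64, 16)            # stand-in for captured block inputs
with torch.no_grad():
    target = block(x)              # float output to reconstruct

w = block.weight.detach()
bias = block.bias.detach()
v = torch.zeros_like(w, requires_grad=True)
alpha = torch.ones((), requires_grad=True)
beta = torch.ones((), requires_grad=True)


def reconstruction_loss():
    return F.mse_loss(F.linear(x, fake_quant(w, v, alpha, beta), bias), target)


initial_loss = reconstruction_loss().item()
best_loss = initial_loss
lr = 5e-3
for _ in range(50):
    grads = torch.autograd.grad(reconstruction_loss(), [v, alpha, beta])
    with torch.no_grad():
        for p, g in zip((v, alpha, beta), grads):
            p -= lr * g.sign()     # signed gradient step
        v.clamp_(-0.5, 0.5)
        alpha.clamp_(0.5, 1.0)
        beta.clamp_(0.5, 1.0)
    best_loss = min(best_loss, reconstruction_loss().item())

print(initial_loss, best_loss)
```

The sketch keeps the lowest loss seen rather than the last iterate, since signed steps do not decrease the loss monotonically. In the proposed flow, the parameters from the best step would then be baked into the QuantizedLinear weights.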

The main functions and classes to implement this flow are defined below:

from typing import Callable, Optional

import torch

class ObservedBlock(torch.nn.Module):
    # e.g., replace `transformers.models.llama.modeling_llama.LlamaDecoderLayer`
    pass

class QuantizedBlock(torch.nn.Module):
    """All Linear layers are replaced with QuantizedLinear layers."""
    pass

class ModuleInputCapture(torch.nn.Module):
    """Capture the input of the given module."""
    pass

def insert_observers_for_block_(
    model: torch.nn.Module,
    block_observer: ModuleInputCapture,
    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
) -> None:
    # Modifies `model` in place: each block matched by `filter_fn` becomes an ObservedBlock
    replacement_fn = lambda m: ObservedBlock.from_float(m, block_observer)
    _replace_with_custom_fn_if_matches_filter(model, replacement_fn, filter_fn)

def apply_auto_round(observed_block: ObservedBlock) -> QuantizedBlock:
    # Call auto-round to run the optimization process
    import auto_round

    # Start the training process to update V, alpha and beta
    auto_round.quant_block_(observed_block)

Note: we would prefer to add auto-round as a dependency, but we are also willing to integrate all of the Auto-Round source code directly into torchao.

Your feedback is important. Please feel free to comment on the flow described above or to suggest other approaches :). Thank you in advance!

cc @thuang6 @ftian1 @wenhuach21 @hshen14 @jgong5

jerryzh168 commented 2 months ago

This is great to see, awesome work @yiliu30! For the integration, I think the flow makes sense. I'm also wondering if it would make sense to have an even tighter integration in terms of the quantized kernels and quantization primitive ops.

Today in torchao we have Affine Quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization), and there are a few things in the stack (from highest level to lowest level):

So ideally, I think we would like to integrate at as high a level of the stack as possible (tensor subclass > quant primitive op > uint dtype tensors). The benefit is that your implementation will pick up whatever perf improvements we make in the lower-level infrastructure. We can also optimize these things together as a community, and all of that optimization work can benefit other projects as well.

So for dtype tensors (e.g. the uint4 dtype Tensor), I think you should always be able to integrate once that is ready. For quant primitives, I looked at your quant ops: https://github.com/intel/auto-round/blob/e24b9074af6cdb099e31c92eb81b7f5e9a4a244e/auto_round/data_type/int.py#L21 I feel you could probably rewrite that to reuse our quant primitives https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py#L23-L25 (or expand our quant primitives to support that), including the mxfp one: https://github.com/intel/auto-round/blob/e24b9074af6cdb099e31c92eb81b7f5e9a4a244e/auto_round/data_type/mxfp.py (we have mx support in prototype as well: https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats)

If this step is successful, we could probably reuse the same AffineQuantizedTensor in the end as well.

I only took a brief look, so I may have missed things regarding the difficulty or feasibility of reusing our quant primitives. Please let me know if you have any thoughts on this.

msaroufim commented 2 months ago

Thank you for the thorough issue @yiliu30!

For now we are hoping ao can continue being a project with no dependencies (#371), since this makes it easier for other repos to take a dependency on us. But here are a few other options instead:

  1. We can take a dependency on auto-round in a tutorial only, but that example would still need to make it clear to end users why installing both torchao and auto-round is worthwhile and that the integration isn't just a shallow pass-through.
  2. Thank you for sharing the low-bit quantization leaderboard. I was not aware of it, and it's really encouraging to see you propose such a strong algorithm.
  3. The flow you described for static quantization makes a lot of sense, so please work with @jerryzh168 here. What we can do is basically rewrite the auto-round algorithm to use as much torchao infra as possible, and once that's ready we can merge it in the prototype namespace. Once it's there we can maybe write a blog post, and if people use it and it starts needing BC guarantees, we'll move it out of prototype.

Lmk if that makes sense!

yiliu30 commented 2 months ago

Hi @jerryzh168 and @msaroufim, thanks for sharing your knowledge about the torchao infra. There are indeed a lot of things that auto-round can reuse. We also understand the need to be particularly strict about new dependencies to avoid hindering adoption.

How about we divide the integration into two phases:

  1. Add auto-round to the prototype namespace, use it to obtain the optimized qweight (with scale and zero-point values) to meet the accuracy goal, and leverage torchao's quantized kernels for performance improvements.
  2. Gradually rewrite auto-round to remove dependencies other than pytorch.

I can raise a PR to demonstrate more details.

msaroufim commented 2 months ago

I can take a look at 1, but I can't guarantee I'll merge it. Some considerations I'd worry about are the size of the auto-round package and whether it causes any issues when installing or importing torchao. But if it doesn't take long for you to produce that example, it might make the discussion on 2 go faster.
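One common way to limit the install and import concerns is to defer the third-party import until the feature is actually used, so that importing the host package never touches the optional one. The sketch below shows that pattern with the standard library only; `optional_import` and `load_auto_round` are hypothetical helper names, not part of torchao or auto-round.

```python
import importlib
import importlib.util


def optional_import(name, purpose):
    """Import top-level package `name` on demand, with a clear error if it is missing."""
    if importlib.util.find_spec(name) is None:
        raise ImportError(
            f"'{name}' is required for {purpose}; "
            "install it separately to use this feature."
        )
    return importlib.import_module(name)


def load_auto_round():
    # Only called from the Auto-Round code path, never at package import time.
    return optional_import("auto_round", "Auto-Round quantization")


# Usage with a package that is always present, to show the happy path.
json_module = optional_import("json", "the demo")
print(json_module.dumps({"ok": True}))
```

With this shape, users who never call the Auto-Round path pay no install or import cost, and users who do get an actionable error message instead of a failure at package import.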