pytorch / ao

Create and integrate custom data types, layouts and kernels for training and inference
BSD 3-Clause "New" or "Revised" License

[RFC] Tensor Subclass based Quantization API #391

Open jerryzh168 opened 2 weeks ago

jerryzh168 commented 2 weeks ago

Status: Draft Updated: 06/17/2024

Objective

In this doc we'll describe the tensor-subclass-based quantization API for modeling users and developers.

Modeling User API

Modeling users are people who use quantization APIs to quantize their models for speedup, memory savings, power savings, etc. Our main goal for the modeling user API is to make it easy to use without requiring a full understanding of the technical details.

We expect users to use two types of APIs: (1) a manual quantization API, invoked with a direct API call, and (2) an automatic quantization API driven by some objective. They look like the following:

1. Manual API call

import torch.nn as nn
from torchao.quantization import quantize, int4_weight_only

# 1. quantize with a pre-packaged quantization technique
# the easiest way to apply quantization is to use one of the
# pre-registered techniques
m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
# current options are: int4_weight_only, int8_weight_only, int8_dynamic_activation_int8_weight
m = quantize(m, int4_weight_only())

# 2. customize the techniques
# We also provide functions for these techniques so people can customize their parameters
from torchao.quantization.quant_api import int4wo
m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
m = quantize(m, int4wo(groupsize=32))

# 3. write your own new apply_tensor_subclass
# You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor
# on weight

# let's say the user wants to use a new quantization technique/dtype that exposes a factory function: to_affine_quantized

from torchao.dtypes import to_affine_quantized

# 3.1. weight only quantization with `to_affine_quantized`
def apply_to_affine_quantized(weight):
    # settings
    return to_affine_quantized(weight, ...)

m = quantize(m, apply_to_affine_quantized)

# 3.2. dynamic quant with `to_affine_quantized` for both input and weight

from torchao.quantization.quant_api import to_linear_act_quantized
def apply_dynamic_quant_my_dtype(weight):
    # weight settings
    ...

    def input_act_quant_func(x):
        # input settings
        return to_affine_quantized(x, ...)

    weight = to_affine_quantized(weight, ...)
    weight = to_linear_act_quantized(weight, input_act_quant_func)
    return weight

m = quantize(m, apply_dynamic_quant_my_dtype)

# 3.3. static quant using `my_dtype` for both input and weight
# A working example is work in progress, we'll update this later.
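Techniques like int4wo(groupsize=32) above are factories. Calling one with hyperparameters returns a function, and that function is then applied to each weight. The sketch below shows the pattern in plain Python on a flat list of floats, using group-wise symmetric int4 quantization. It is a minimal illustration under stated assumptions, not torchao's implementation. The names `int4_weight_only_sketch` and `dequantize_sketch`, the flat-list weight, and the per-group scale convention (max |w| / 7) are all hypothetical.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-in for a quantized weight: int4 values plus one scale per group.
QuantizedWeight = Tuple[List[int], List[float]]


def int4_weight_only_sketch(groupsize: int = 128) -> Callable[[List[float]], QuantizedWeight]:
    """Factory: captures the hyperparameters and returns a weight-transform function."""
    qmin, qmax = -8, 7  # signed 4-bit range

    def apply(weight: List[float]) -> QuantizedWeight:
        if len(weight) % groupsize != 0:
            raise ValueError("weight length must be divisible by groupsize")
        q_values: List[int] = []
        scales: List[float] = []
        for start in range(0, len(weight), groupsize):
            group = weight[start:start + groupsize]
            max_abs = max(abs(w) for w in group)
            # one symmetric scale per group; use 1.0 for an all-zero group
            scale = max_abs / qmax if max_abs > 0 else 1.0
            scales.append(scale)
            for w in group:
                q = round(w / scale)
                q_values.append(max(qmin, min(qmax, q)))  # clamp into the int4 range
        return q_values, scales

    return apply


def dequantize_sketch(qw: QuantizedWeight, groupsize: int) -> List[float]:
    q_values, scales = qw
    return [q * scales[i // groupsize] for i, q in enumerate(q_values)]


apply_fn = int4_weight_only_sketch(groupsize=4)
weight = [0.3, -1.0, 0.6, 0.1, 2.0, -0.7, 1.2, 0.0]
qw = apply_fn(weight)
```

Here `apply_fn` plays the role of the object passed as the second argument to `quantize`. The closure hides `groupsize`, which is the inspectability concern raised later in the discussion.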

2. autoquantization

autoquant is a tool that automatically quantizes eligible layers. For each layer, it chooses the type of quantization (int8 weight only, int8 dynamic quant, int4 weight only, or new dtypes) that gives the best performance for that individual layer. We'll provide APIs for people to add new dtypes to the set searched by the tool.

torchao.autoquant.add_new_dtype("my_new_dtype", to_my_new_dtype)
model = autoquant(torch.compile(model, mode='max-autotune'))
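The per-layer search that autoquant performs can be pictured like this: for each eligible layer, measure every candidate and keep the cheapest one. The following is a minimal pure-Python sketch that assumes a caller-supplied cost function. The name `pick_best_per_layer`, the candidate labels, and the made-up costs are hypothetical. In the real tool, the cost would come from timing the layer itself.

```python
from typing import Callable, Dict, List


def pick_best_per_layer(
    layers: List[str],
    candidates: List[str],
    benchmark: Callable[[str, str], float],
) -> Dict[str, str]:
    """Return, for each layer, the candidate quantization type with the lowest measured cost."""
    choice: Dict[str, str] = {}
    for layer in layers:
        costs = {cand: benchmark(layer, cand) for cand in candidates}
        choice[layer] = min(costs, key=costs.get)
    return choice


# Made-up costs standing in for measured latencies (milliseconds).
fake_costs = {
    ("linear1", "int8_weight_only"): 0.9,
    ("linear1", "int8_dynamic"): 0.6,
    ("linear1", "int4_weight_only"): 0.7,
    ("linear2", "int8_weight_only"): 0.5,
    ("linear2", "int8_dynamic"): 0.8,
    ("linear2", "int4_weight_only"): 0.4,
}

plan = pick_best_per_layer(
    ["linear1", "linear2"],
    ["int8_weight_only", "int8_dynamic", "int4_weight_only"],
    lambda layer, cand: fake_costs[(layer, cand)],
)
```

In this picture, a hook like the proposed `add_new_dtype` would simply extend the `candidates` list.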

Developer API

Developers could be people doing research to find the best quantization algorithm, or people adding dtype support for emerging hardware.

Prerequisites

Our developer-facing API relies on tensor subclasses (and also on torch.compile). We'll update this section with more publicly available tutorials.

Some externally available resources:

Why Tensor Subclass?

There are multiple ways to implement quantization techniques or new dtypes. We recommend the tensor-subclass-based approach for three main reasons:
(1) It's natural to model quantization as a dtype conversion. Implementing it with a tensor subclass therefore means we are not introducing new concepts, but reusing existing ones, like dtype and layout, that already exist in PyTorch core.
(2) Since a tensor subclass intercepts computation at the torch function or aten op level, we can quantize the model as long as the same function/operator is used. This allows models that use variants of native modules (e.g. a slightly modified version of nn.Linear) to remain compatible with quantization.
(3) Tensor subclasses are also the approach adopted by other techniques like sparsity and distributed. Implementing quantization or dtype conversion with tensor subclasses therefore makes it easier to compose with these techniques.

Example Code for a new Quantization Technique or DType

Please feel free to start with https://colab.research.google.com/drive/1jqC53MwiW9dSiPS-a6hng_yo0ywdc3nH#scrollTo=Aj9ii4darSRA for an end-to-end working example that combines everything discussed here. Then come back to this doc for clarifications and documentation.

Basic Structure

A tensor subclass needs to define a few basic methods: __new__, __init__, __tensor_flatten__ and __tensor_unflatten__. It also needs the dispatch functions __torch_function__ (for torch functions) and __torch_dispatch__ (for aten ops).

Here is an example of basic structure:

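from typing import Optional

import torch
from torchao.dtypes.utils import _ATEN_OP_OR_TORCH_FN_TABLE
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine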

class MyDTypeLayout(torch.Tensor):
    # see notebook for details
    pass

class MyDtypeTensor(torch.Tensor):
    """We need to define `__new__` for constructing a new tensor subclass instance and `__init__` for initializing
    the instance. There is no requirement on what the argument list should look like here; the only requirement is
    that `__new__` must return a Tensor instance created with a `torch.Tensor._make_wrapper_subclass(cls, shape, ...)` call
    """
    @staticmethod
    def __new__(
        cls,
        layout_tensor: MyDTypeLayout,
        shape: torch.Size,
        dtype: Optional[torch.dtype] = None,
    ):
        ...
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        layout_tensor: MyDTypeLayout,
        shape: torch.Size,
        dtype: Optional[torch.dtype] = None,
    ):
        self.layout_tensor = layout_tensor

    # `__tensor_flatten__` and `__tensor_unflatten__` are used to desugar the tensor into native Tensors/attributes
    # and to reconstruct the tensor subclass instance from the desugared tensors and attributes. They are required
    # for a Tensor subclass to support torch.compile.
    def __tensor_flatten__(self):
        return ["layout_tensor"], [self.shape]

    # see https://github.com/pytorch/pytorch/blob/3bc2004f9123a32f381ef64202252d59109507f3/torch/utils/_python_dispatch.py#L289 for documentation of outer_size and outer_stride
    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        layout_tensor = tensor_data_dict["layout_tensor"]
        shape, = tensor_attributes
        return cls(
            layout_tensor,
            shape if outer_size is None else outer_size,
        )

    # classmethod that converts from a floating point Tensor (fp32/fp16/bf16) to the current dtype
    @classmethod
    def from_float(
        cls,
        input_float: torch.Tensor,
    ):
        mapping_type = MappingType.SYMMETRIC
        block_size = input_float.shape
        dtype = torch.int8  # target integer dtype, matching the int8 cast below
        scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
        int_data = (input_float / scale).to(torch.int8)
        layout_tensor = MyDTypeLayout.from_plain(int_data, scale)
        return cls(layout_tensor, input_float.shape)

    # [Optional] We can override the layout property of the Tensor to represent different packing formats
    @property
    def extended_layout(self) -> str:
        return self.layout_tensor.extended_layout

    # There are two entry points where we can modify the behavior of a pytorch op: torch_function and torch_dispatch.
    #
    # __torch_function__: called whenever a torch-level function is called on the Tensor object, for example:
    # torch.nn.functional.linear, tensor.detach, tensor.reshape, tensor.t etc.
    #
    # __torch_dispatch__: called in the C++ dispatcher when an aten operator is called on the Tensor object,
    # for example: aten.mm, aten.addmm, aten.detach.default, aten.t.default etc.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs)

        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)

        raise NotImplementedError(
            f"{cls.__name__} dispatch: attempting to run {func}, this is not supported"
        )
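The from_float method above derives a single symmetric scale for the whole tensor and truncates the scaled values to int8. That arithmetic can be checked in plain Python. This sketch assumes that the symmetric scale is max|x| divided by half the int8 range, (127 - (-128)) / 2 = 127.5. That is our reading of `choose_qparams_affine` with `MappingType.SYMMETRIC`, so please verify it against the torchao source. The sketch uses `int()`, which truncates toward zero just as a float-to-int8 cast does. The helper names are hypothetical.

```python
from typing import List, Tuple

QMIN, QMAX = -128, 127  # int8 range


def choose_symmetric_scale(x: List[float]) -> float:
    # one scale for the whole tensor (block_size == input shape); guard against an all-zero input
    max_abs = max(abs(v) for v in x)
    return max(max_abs, 1e-12) / ((QMAX - QMIN) / 2)


def quantize_per_tensor(x: List[float]) -> Tuple[List[int], float]:
    scale = choose_symmetric_scale(x)
    # int() truncates toward zero, mirroring a float -> int8 cast
    q = [int(v / scale) for v in x]
    return q, scale


def dequantize_per_tensor(q: List[int], scale: float) -> List[float]:
    return [v * scale for v in q]


x = [1.0, -0.5, 0.25, -1.0]
q, scale = quantize_per_tensor(x)
x_hat = dequantize_per_tensor(q, scale)
```

With this scale, truncation keeps the largest value at 127 instead of letting 127.5 overflow to 128. A production kernel would more typically round and then clamp.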

Operator Support

There are two types of operator support: torch functions and aten ops. For torch functions (e.g. torch.nn.functional.linear), we need to override the __torch_function__ callback in the Tensor subclass. For aten ops (e.g. torch.ops.aten.mm), we need to override the __torch_dispatch__ callback. For a new dtype, we'd like people to define the following decorator:

from torchao.dtypes.utils import _implements

def implements(aten_ops_or_torch_fns):
    return _implements(MyDtypeTensor, aten_ops_or_torch_fns)
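A decorator like implements is typically backed by a per-class table that maps an op (a torch function or an aten op) to a handler. __torch_function__ and __torch_dispatch__ then consult that table. The pure-Python sketch below models such a table, using strings as op keys. `_OP_TABLE`, `implements_for`, `dispatch`, and `MyDtypeSketch` are hypothetical names. They follow the same idea as torchao's `_ATEN_OP_OR_TORCH_FN_TABLE` / `_implements`, but they are not that code.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, Union

# _OP_TABLE[cls][op] -> handler, one sub-table per tensor subclass
_OP_TABLE: Dict[type, Dict[Any, Callable]] = defaultdict(dict)


def implements_for(cls: type, ops: Union[Any, Iterable[Any]]) -> Callable[[Callable], Callable]:
    """Register the decorated function as the handler of `ops` for `cls`."""
    if not isinstance(ops, (list, tuple)):
        ops = [ops]

    def decorator(fn: Callable) -> Callable:
        for op in ops:
            _OP_TABLE[cls][op] = fn
        return fn

    return decorator


def dispatch(cls: type, op: Any, *args: Any, **kwargs: Any) -> Any:
    """Look up the handler registered for `op`, like the table check in __torch_dispatch__."""
    if op in _OP_TABLE[cls]:
        return _OP_TABLE[cls][op](op, *args, **kwargs)
    raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {op}, this is not supported")


class MyDtypeSketch:  # hypothetical stand-in for MyDtypeTensor
    pass


def implements(ops):
    return implements_for(MyDtypeSketch, ops)


@implements(["aten.detach.default", "aten.clone.default"])
def _(func, x):
    return f"{func} handled {x}"
```

For example, `dispatch(MyDtypeSketch, "aten.detach.default", "w")` calls the registered handler. An unregistered op raises NotImplementedError, which makes missing coverage show up immediately.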

And we can implement the operator dispatch with the following:

# Example for torch_function dispatch for torch.nn.functional.linear
def _quantized_linear_op(input_tensor, weight_tensor, bias):
    if isinstance(input_tensor, MyDtypeTensor):
        input_tensor = input_tensor.dequantize()
    if isinstance(weight_tensor, MyDtypeTensor):
        weight_tensor = weight_tensor.dequantize()
    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

@implements(torch.nn.functional.linear)
def _(*args, **kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    # using try/except here so that we can have a general fallback when input_tensor/weight_tensor
    # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
    # make the branches easier to understand in `_quantized_linear_op`
    try:
        return _quantized_linear_op(input_tensor, weight_tensor, bias)
    except NotImplementedError:
        if isinstance(input_tensor, MyDtypeTensor):
            input_tensor = input_tensor.dequantize()
        if isinstance(weight_tensor, MyDtypeTensor):
            weight_tensor = weight_tensor.dequantize()
        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

# Example for aten op dispatch for aten.detach.default
from torch.utils._python_dispatch import return_and_correct_aliasing

aten = torch.ops.aten

@implements([aten.detach.default])
def _(func, *args, **kwargs):
    # `return_and_correct_aliasing` should be used by wrapper tensor ``__torch_dispatch__`` subclasses that would like to 
    # work with torch.compile. It ensures that the subclass properly implements the aliasing behavior of every op, 
    # which is needed for correctness in AOTAutograd.

    # `_apply_fn_to_data` just applies the function to the tensor data in `args[0]`, `args[0]` is a tensor subclass
    # of `my_dtype`
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
    )

Which ops do we need to override? This depends on the model we are trying to quantize. Commonly overridden ops are:
torch_function: torch.nn.functional.linear
torch_dispatch: torch.ops.aten.addmm.default, torch.ops.aten.mm.default, torch.ops.aten.detach.default, torch.ops.aten.t.default

You can also find the ops that need to be overridden in torch_function or torch_dispatch with the following code. Start with the model you want to optimize and override just the important ops, like linear. Then gradually expand the coverage until the test runs and you get the expected optimized generated code (see the Optimized Operators section for more details):

import torch
from torch.overrides import TorchFunctionMode

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) + x

m = M()
example_inputs = (torch.randn(1, 10),)

class TorchFunctionLoggingMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        print(f"TORCH_FUNC={str(func)}")
        return func(*args, **kwargs)

with TorchFunctionLoggingMode():
    m(*example_inputs)

## Example output
# TORCH_FUNC=<built-in function linear>
# TORCH_FUNC=<method 'add' of 'torch._C.TensorBase' objects>

from torch.utils._python_dispatch import TorchDispatchMode
class TorchDispatchLoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        print(f"ATEN_FUNC={str(func)}")
        return func(*args, **kwargs)

with TorchDispatchLoggingMode():
    m(*example_inputs)

## Example output
# ATEN_FUNC=aten.t.default
# ATEN_FUNC=aten.addmm.default
# ATEN_FUNC=aten.add.Tensor

# or a more polished logging for torch_dispatch (aten) ops: https://github.com/albanD/subclass_zoo/blob/main/logging_mode.py

We are still working on a table that lists, for each feature, the operators that need to be supported.

Optimized Operators

Optimized operators for cpu/cuda/mps can be implemented in https://github.com/pytorch/ao/tree/main/torchao/csrc (e.g. int4 for cuda) and are then accessible through torch.ops.my_custom_op.

To dispatch to optimized kernels on cpu/cuda/mps devices, we can check the dispatch conditions in torch_function or torch_dispatch and call the target operators, for example: https://github.com/pytorch/ao/blob/cbc74ee6a3dc0bae367db5b03bc58896fffe3ae0/torchao/dtypes/aqt.py#L348-L355.
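Dispatch conditions like these are ordinary Python checks inside the op handler. If the operands meet an optimized kernel's requirements, the handler calls that kernel. Otherwise it falls back to dequantize-then-compute. The sketch below keeps that structure but replaces real kernels with stand-in functions. The condition set (device, bit width, group size) and the names `FakeQuantWeight`, `fast_int4_kernel`, `fallback_linear`, and `quantized_dot` are hypothetical. They are not the conditions used in torchao's aqt.py.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class FakeQuantWeight:  # hypothetical quantized weight: metadata + data
    device: str
    bits: int
    groupsize: int
    q_values: List[int]
    scale: float


def fast_int4_kernel(x: List[float], w: FakeQuantWeight) -> float:
    # stand-in for an optimized torch.ops.* kernel: accumulate against the raw
    # quantized values and apply the scale once at the end
    return sum(xi * qi for xi, qi in zip(x, w.q_values)) * w.scale


def fallback_linear(x: List[float], w: FakeQuantWeight) -> float:
    # generic path: dequantize the weight, then compute in floating point
    w_float = [q * w.scale for q in w.q_values]
    return sum(xi * wi for xi, wi in zip(x, w_float))


def quantized_dot(x: List[float], w: FakeQuantWeight) -> Tuple[str, float]:
    # check the (hypothetical) requirements of the optimized kernel first
    use_fast = w.device == "cuda" and w.bits == 4 and w.groupsize in (32, 64, 128, 256)
    if use_fast:
        return "fast", fast_int4_kernel(x, w)
    return "fallback", fallback_linear(x, w)
```

Both paths compute the same value. Only the route differs, which is what lets the fallback serve as a correctness reference for the optimized kernel.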

Packing/Layout

Sometimes the quantized weights have to be packed to yield optimal performance. For this, we want to extend the "layout" concept in Tensor and introduce an indirection for tensor data storage. See https://github.com/pytorch/ao/pull/278 for more details.

Native tensors have a hardcoded list of layouts: https://github.com/pytorch/pytorch/blob/647815049ec28a72dc1bb6a977791927bba058d5/c10/core/Layout.h#L11. The most common one is the strided layout, which provides a strided, multi-dimensional view of storage. There are also sparse and mkldnn layouts.

The idea of packing the tensor into different formats fits nicely with the layout concept, which is why we want to reuse it for packing. The layout can also be extended with Python-level tensor subclasses, without modifying PyTorch's C++ core code.
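As a concrete picture of "packing": an int4 value needs only half a byte, so a packed layout can store two of them per uint8. The plain-Python sketch below packs pairs of signed int4 values (range [-8, 7]) into bytes as two's-complement nibbles, low nibble first. This nibble order is an assumption made for illustration. Real kernel layouts use their own, often tiled, formats. `pack_int4` and `unpack_int4` are hypothetical helpers.

```python
from typing import List


def pack_int4(values: List[int]) -> bytes:
    """Pack signed int4 values, two per byte, low nibble first."""
    if len(values) % 2 != 0:
        raise ValueError("need an even number of int4 values")
    out = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        for v in (lo, hi):
            if not -8 <= v <= 7:
                raise ValueError(f"{v} is outside the int4 range")
        # & 0xF keeps the 4-bit two's-complement encoding of negative numbers
        out.append((lo & 0xF) | ((hi & 0xF) << 4))
    return bytes(out)


def unpack_int4(packed: bytes) -> List[int]:
    """Inverse of pack_int4: recover the signed value stored in each nibble."""
    values: List[int] = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            # sign-extend: nibbles 8..15 represent -8..-1
            values.append(nibble - 16 if nibble >= 8 else nibble)
    return values


vals = [-8, 7, 0, -1, 3, -3]
packed = pack_int4(vals)
```

In the layout design above, a packed layout class would hold `packed` (plus scales), and its `get_plain()` would perform the unpack step.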

Here is an example (see notebook for full code):

# 1. define a base layout for your dtype
from typing import Optional, Tuple

import torch

class MyDTypeLayout(torch.Tensor):
    """
    Base class for the layout tensor for `MyDtypeTensor`
    """

    # this should be set for each layout class during registration
    extended_layout: Optional[str] = None

    # get the original unpacked Tensors
    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.int_data, self.scale

    # how to get the layout tensor from plain tensors
    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
    ):
        pass

# 2. define a registration method for layout

from torchao.dtypes.utils import _get_layout_tensor_constructor, _register_layout_cls

def register_layout_cls(extended_layout: str):
    return _register_layout_cls(MyDtypeTensor, extended_layout)

def get_layout_tensor_constructor(extended_layout: str):
    return _get_layout_tensor_constructor(MyDtypeTensor, extended_layout)

# 3. define new layout
@register_layout_cls("plain")
class MyDTypePlainLayout(MyDTypeLayout):
    def __new__(cls, ...):
        pass

    def __init__(self, ...):
        pass

    def __tensor_flatten__(self):
        pass

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride):
        pass

    @classmethod
    def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
        packed = pack(int_data, scale, zero_point, inner_k_tiles)
        return cls(packed, ...)

# 4. use the layout tensor in original tensor subclass

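class MyDtypeTensor(torch.Tensor):
    @classmethod
    def from_float(
        cls,
        input_float: torch.Tensor,
        extended_layout: str = "plain",
    ):
        # compute int_data and scale as in `from_float` in the Basic Structure example
        mapping_type = MappingType.SYMMETRIC
        block_size = input_float.shape
        dtype = torch.int8
        scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
        int_data = (input_float / scale).to(torch.int8)
        # look up the constructor registered for the requested layout and build the layout tensor
        layout_tensor_ctr = get_layout_tensor_constructor(extended_layout)
        layout_tensor = layout_tensor_ctr(int_data, scale)
        return cls(layout_tensor, input_float.shape)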
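The registration helpers above amount to a two-level registry: tensor class, then layout name, then the layout's from_plain constructor. The following is a pure-Python sketch of that idea. The names `_LAYOUT_TABLE`, `register_layout_for`, `get_layout_constructor`, `MyDtypeSketch`, `PlainLayoutSketch`, and `ReversedLayoutSketch` are all hypothetical. They are not torchao's `_register_layout_cls` / `_get_layout_tensor_constructor`, and the "reversed" layout is a toy stand-in for real packing.

```python
from collections import defaultdict
from typing import Callable, Dict, List

# _LAYOUT_TABLE[tensor_cls][layout_name] -> constructor taking plain (int_data, scale)
_LAYOUT_TABLE: Dict[type, Dict[str, Callable]] = defaultdict(dict)


def register_layout_for(tensor_cls: type, extended_layout: str):
    def decorator(layout_cls: type) -> type:
        layout_cls.extended_layout = extended_layout
        _LAYOUT_TABLE[tensor_cls][extended_layout] = layout_cls.from_plain
        return layout_cls
    return decorator


def get_layout_constructor(tensor_cls: type, extended_layout: str) -> Callable:
    try:
        return _LAYOUT_TABLE[tensor_cls][extended_layout]
    except KeyError:
        raise ValueError(f"layout {extended_layout!r} is not registered for {tensor_cls.__name__}") from None


class MyDtypeSketch:  # hypothetical stand-in for MyDtypeTensor
    pass


@register_layout_for(MyDtypeSketch, "plain")
class PlainLayoutSketch:
    def __init__(self, int_data: List[int], scale: float):
        self.int_data, self.scale = int_data, scale

    @classmethod
    def from_plain(cls, int_data: List[int], scale: float):
        return cls(int_data, scale)

    def get_plain(self):
        return self.int_data, self.scale


@register_layout_for(MyDtypeSketch, "reversed")
class ReversedLayoutSketch:
    """Toy 'packing format': stores the data in reverse order."""

    def __init__(self, stored: List[int], scale: float):
        self.stored, self.scale = stored, scale

    @classmethod
    def from_plain(cls, int_data: List[int], scale: float):
        return cls(list(reversed(int_data)), scale)

    def get_plain(self):
        return list(reversed(self.stored)), self.scale
```

Because every layout round-trips through `get_plain()`, repacking into another format can be written as `get_layout_constructor(MyDtypeSketch, new_layout)(*old_layout.get_plain())`. That resembles the `.to(extended_layout=...)` idea raised later in this thread.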

Flow

After the tensor subclass is implemented, we can also wrap that into factory functions, e.g.

# convert from floating point tensor to MyDtypeTensor
to_my_dtype = MyDtypeTensor.from_float

For the model-level API, people can reuse torchao.quantize. It applies a tensor subclass conversion to the weights of linear modules and supports a filtering function: https://github.com/pytorch/ao/blob/aeee551b15eebeaabf98ffab9a00addc675a12a9/torchao/quantization/quant_api.py (TODO: replace this with a link to the torchao doc website when it's ready)

See Modeling User API section for examples of weight only/dynamic quant/static quant model level APIs based on the factory function.
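The model-level API walks the model, selects linear layers that pass a filter, and replaces each selected layer's weight with the output of the conversion function. Below is a pure-Python sketch of that traversal. It uses minimal hypothetical `Linear`/`Sequential` stand-ins instead of torch.nn. With real modules, one would iterate `model.named_modules()` and reassign `module.weight`. The names `quantize_sketch` and `named_modules` and the `filter_fn(module, fqn)` signature are assumptions, not torchao's exact API.

```python
from typing import Callable, Dict, Iterator, List, Optional, Tuple


class Linear:  # hypothetical stand-in for torch.nn.Linear
    def __init__(self, weight: List[float]):
        self.weight = weight


class Sequential:  # hypothetical stand-in for torch.nn.Sequential
    def __init__(self, **children):
        self.children: Dict[str, object] = children


def named_modules(module, prefix: str = "") -> Iterator[Tuple[str, object]]:
    """Yield (fully qualified name, module) pairs depth first, like nn.Module.named_modules."""
    yield prefix, module
    for name, child in getattr(module, "children", {}).items():
        yield from named_modules(child, f"{prefix}.{name}" if prefix else name)


def quantize_sketch(
    model,
    apply_fn: Callable[[List[float]], object],
    filter_fn: Optional[Callable[[object, str], bool]] = None,
):
    """Replace the weight of every Linear that passes filter_fn; modifies the model in place."""
    for fqn, mod in named_modules(model):
        if isinstance(mod, Linear) and (filter_fn is None or filter_fn(mod, fqn)):
            mod.weight = apply_fn(mod.weight)
    return model


model = Sequential(proj=Linear([1.0, 2.0]), head=Sequential(out=Linear([3.0])))
# convert every linear except the output head (a common choice for accuracy reasons)
quantize_sketch(model, lambda w: ("quantized", w), filter_fn=lambda mod, fqn: fqn != "head.out")
```

Because the weights are swapped in place, the model's structure and forward code stay untouched. This is the property that the tensor-subclass approach relies on.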

Using torch.compile for Performance

Note: currently, we need to use the following to be compatible with torch.compile:

from torchao.quantization.utils import unwrap_tensor_subclass
m_unwrapped = unwrap_tensor_subclass(m)

To optimize performance, we should first run through torch.compile in fullgraph mode and remove any unnecessary graph breaks. You can set TORCH_LOGS="output_code" when running the script to see the Inductor-generated code, e.g. TORCH_LOGS="output_code" python example.py

model = torch.compile(model, mode="max-autotune", fullgraph=True)

Serialization

This test shows how we expect save/load to work for a model quantized with tensor subclass based API:

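import copy
import tempfile

import torch
from torchao.quantization import quantize

# ToyLinearModel is a small model defined in the test, with an `example_inputs` helper
m = ToyLinearModel().eval().to(torch.bfloat16)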
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16)

m = quantize(m, "int8_weight_only")
ref = m(*example_inputs)
with tempfile.NamedTemporaryFile() as f:
    torch.save(m.state_dict(), f)
    f.seek(0)
    state_dict = torch.load(f)

# we can use the original floating point model, but use assign=True to load a quantized state_dict
m_copy.load_state_dict(state_dict, assign=True)

res = m_copy(*example_inputs)
assert torch.equal(res, ref)

What's Next

gau-nernst commented 2 weeks ago

Regarding 1, apart from the feedback I gave in #384, I'm starting to think of another alternative:

quantizer = Int4WeightOnlyQuantizer(groupsize=32)
quantizer.quantize(model)

But then this feels like the old API change_linear_weights_to_int4_woqtensors(model, groupsize=32), which we have moved away from. The current quantize() does feel somewhat more convenient.

Personally I don't really like a function returning a function, like the current int4wo and int8wo. Feels like having a proper class makes it cleaner (we can also inspect the quant hyperparams after instantiation) - as discussed in #384.

Another option is to expose apply_int4wo_quant() directly and have the user call functools.partial() on it (same effect as the current int4wo() implementation):

from functools import partial

quantize(model, partial(apply_int4wo_quant, groupsize=32))

Also, since the quantization is in-place, I think it's good to use quantize_() instead to clearly signal the in-place behavior.
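The three styles discussed in this comment can be compared side by side: a function returning a function, a config class, and functools.partial over an apply function. This pure-Python sketch uses hypothetical names (`int4wo_closure`, `Int4WeightOnlyConfig`) together with the proposed `apply_int4wo_quant`, whose body here is only a placeholder. All three styles hand `quantize` a callable. Only the class and the partial expose their hyperparameters for inspection.

```python
from dataclasses import dataclass
from functools import partial
from typing import Callable, List


def apply_int4wo_quant(weight: List[float], groupsize: int = 128) -> str:
    # placeholder for the real weight conversion
    return f"int4wo(groupsize={groupsize}, n={len(weight)})"


# 1. function returning a function: hyperparameters are hidden inside the closure
def int4wo_closure(groupsize: int = 128) -> Callable[[List[float]], str]:
    def apply(weight: List[float]) -> str:
        return apply_int4wo_quant(weight, groupsize=groupsize)
    return apply


# 2. config class: hyperparameters are plain attributes and the instance is callable
@dataclass(frozen=True)
class Int4WeightOnlyConfig:
    groupsize: int = 128

    def __call__(self, weight: List[float]) -> str:
        return apply_int4wo_quant(weight, groupsize=self.groupsize)


# 3. functools.partial over the apply function
partial_fn = partial(apply_int4wo_quant, groupsize=32)

w = [0.0] * 64
results = [fn(w) for fn in (int4wo_closure(32), Int4WeightOnlyConfig(groupsize=32), partial_fn)]
```

All three produce the same result. However, `Int4WeightOnlyConfig(groupsize=32).groupsize` and `partial_fn.keywords` can be read back directly, while the closure would require digging into `__closure__`.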

drisspg commented 2 weeks ago

For the manual API, why have both a string and an int4wo(group_size)? I think it would be cleaner to have just one version of this.

jeromeku commented 2 weeks ago

Is there a tutorial or end-to-end example of how to compose these APIs to implement a non-trivial quantization method (e.g., AWQ, GPTQ, etc.) and specialized deployment layout (e.g., Marlin)? Basically a reference impl of how these tools can be used to facilitate the translation of research ideas to deployment-ready libraries.

If not, happy to work on one.

jerryzh168 commented 2 weeks ago

Regarding 1, apart from what I have feedbacked in #384, starting to think of another alternative

quantizer = Int4WeightOnlyQuantizer(groupsize=32)
quantizer.quantize(model)

But then this feels like the old API change_linear_weights_to_int4_woqtensors(model, groupsize=32), which we have moved away from. The current quantize() does feel somewhat more convenient.

Personally I don't really like a function returning a function, like the current int4wo and int8wo. Feels like having a proper class makes it cleaner (we can also inspect the quant hyperparams after instantiation) - as discussed in #384.

Another option is to expose apply_int4wo_quant() directly and have the user call functools.partial() on it (same effect as the current int4wo() implementation):

from functools import partial

quantize(model, partial(apply_int4wo_quant, groupsize=32))

Also, since the quantization is in-place, I think it's good to use quantize_() instead to clearly signal the in-place behavior.

the quantizer API is actually what I have been thinking about before as "Unified Quantization API": https://github.com/pytorch/ao/blob/main/torchao/quantization/unified.py and these two APIs will cover most of the current quant flows, it's also used by QAT prototype: https://github.com/pytorch/ao/blob/d0af9415a0a0055288be5208e05d4f494efbcfa8/torchao/quantization/prototype/qat.py#L22, personally I think we can use this so we have a unified experience for modeling users. But Christian has raised some concerns on this one since he feels introducing classes is a bit overkill I think.

the partial function idea has been raised in our meetings before as well, but that also doesn't seem very straightforward to use.

For now I'm planning to just use quantize(model, int4_weight_only(groupsize=32)), but I'm open to changing it in the future if there is more feedback on this API.

Also, in the ideal future I think we'd expect modeling users to just use autoquant and not worry about all these details.

jerryzh168 commented 2 weeks ago

For the manual API why have both a string and a int4wo(group_size), I think it would be cleaner to just have one version of this

so the motivation for string is so that people don't need to import anything to use it, it's just a simple shortcut and we'll make sure to align the names

jerryzh168 commented 2 weeks ago

Is there a tutorial or end-to-end example of how to compose these APIs to implement a non-trivial quantization method (e.g., AWQ, GPTQ, etc.) and specialized deployment layout (e.g., Marlin)? Basically a reference impl of how these tools can be used to facilitate the translation of research ideas to deployment-ready libraries.

If not, happy to work on one.

Not yet. My understanding is that this doc describes how we build the fundamental "dtype" of quantization. It can serve as a building block for more sophisticated quantization methods that use the "dtype" as their data representation.

I'm planning to put up an example of static quant (with module swap). It could help demonstrate how other techniques (e.g. ones that require calibration) can be implemented in similar ways. Please feel free to work on a tutorial that shows what a real-world, end-to-end quantization example looks like when it uses the "dtype" we build with tensor subclasses in this doc.

We also plan to build out hqq with this design: https://github.com/pytorch/ao/issues/255, cc @HDCharles. This one doesn't require calibration either, though. There is also GPTQ, which could be refactored to use tensor subclasses and compose with AffineQuantizedTensor. The main open question for GPTQ is whether people are interested in using it, but we do seem to have some feedback saying it is important (https://github.com/pytorch/ao/issues/384#issuecomment-2175707059), so maybe we could refactor it as well.

drisspg commented 2 weeks ago

so the motivation for string is so that people don't need to import anything to use it, it's just a simple shortcut and we'll make sure to align the names

But they are already importing the quantize api right? Idk I tend to be in favor of verbosity, but this was a nit anyways so carry on

jerryzh168 commented 2 weeks ago

so the motivation for string is so that people don't need to import anything to use it, it's just a simple shortcut and we'll make sure to align the names

But they are already importing the quantize api right? Idk I tend to be in favor of verbosity, but this was a nit anyways so carry on

yeah, we are thinking of just removing these for now, it would be better for people to also see the docstrings for these things, and an extra import doesn't seem to be a big issue

vadimkantorov commented 1 week ago

About subclasses: I hope there would still be way to (when needed) register custom fused kernels which do e.g. q-someop-dq in a fused way, without having a separate kernel launches for q and dq. I know this type of graph matching is possible with torch.compile, but I hope that the explicit introduction of subclasses (and seemingly mainly used for representational/expressiveness/dispatch purpose) will not make this more complicated.

Also, hoping that it will work nicely with profiling/tracing to know exactly what kernel is getting invoked and exactly where any q/dq is happening (especially for autoquant regimes).

This is kind of similar to what was originally done with quint8 dtype, right? (except now it will allow user-powered extension and dispatch is based on subclass type instead of dtype)

jerryzh168 commented 1 week ago

About subclasses: I hope there would still be way to (when needed) register custom fused kernels which do e.g. q-someop-dq in a fused way, without having a separate kernel launches for q and dq. I know this type of graph matching is possible with torch.compile, but I hope that the explicit introduction of subclasses (and seemingly mainly used for representational/expressiveness/dispatch purpose) will not make this more complicated.

yeah I think we should still be able to register inductor fusion passes, but one thing here is, q/dq ops are no longer large ops in the torch.compile path, we are planning to keep them as smaller aten ops (sub/mul etc.) so these can participate in normal inductor optimization directly, so the optimization story will be a bit different for inductor/torch.compile I think.
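Keeping q/dq as small ops means that, for affine quantization, quantize becomes a divide, round, add and clamp, and dequantize becomes a subtract and a multiply. Those primitives can participate in normal Inductor optimization. Below is a plain-Python sketch of the formulas, assuming an asymmetric uint8 scheme with scale = (max - min) / 255 and an integer zero_point. This is the textbook affine scheme, not necessarily the exact formulas in torchao's quant_primitives, and the function names are hypothetical.

```python
from typing import List, Tuple

QMIN, QMAX = 0, 255  # uint8


def choose_affine_qparams(x: List[float]) -> Tuple[float, int]:
    # include 0 in the range so that 0.0 is exactly representable
    lo, hi = min(min(x), 0.0), max(max(x), 0.0)
    scale = (hi - lo) / (QMAX - QMIN) or 1.0  # avoid a zero scale for an all-zero input
    zero_point = QMIN - round(lo / scale)
    return scale, max(QMIN, min(QMAX, zero_point))


def quantize_affine_sketch(x: List[float], scale: float, zero_point: int) -> List[int]:
    # q = clamp(round(x / scale) + zero_point, QMIN, QMAX): div, round, add, clamp
    return [max(QMIN, min(QMAX, round(v / scale) + zero_point)) for v in x]


def dequantize_affine_sketch(q: List[int], scale: float, zero_point: int) -> List[float]:
    # x_hat = (q - zero_point) * scale: a sub and a mul
    return [(v - zero_point) * scale for v in q]


x = [-1.0, 0.0, 0.5, 3.0]
scale, zp = choose_affine_qparams(x)
q = quantize_affine_sketch(x, scale, zp)
x_hat = dequantize_affine_sketch(q, scale, zp)
```

Each dequantized value lands within half a quantization step (scale / 2) of the original, and 0.0 round-trips exactly.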

However, we are preserving q/dq ops as high level ops for executorch (export path), since the current executorch backends need to work with the patterns like (dq -> fp32 op -> q), this is WIP in https://github.com/pytorch/ao/pull/434

Also, hoping that it will work nicely with profiling/tracing to know exactly what kernel is getting invoked and exactly where any q/dq is happening (especially for autoquant regimes).

yeah we can definitely provide additional information on what kernel is picked for autoquant, cc @HDCharles

This is kind of similar to what was originally done with quint8 dtype, right? (except now it will allow user-powered extension and dispatch is based on subclass type instead of dtype)

yes, this is similar to quint8, except it's built in python with tensor subclasses extension point, this allows us to stay out of core and have faster iteration speed as well. for dispatch, I feel it could also continue to use dtype as well, after we sort out the dtype story: https://github.com/pytorch/ao/issues/442

kimishpatel commented 5 days ago

However, we are preserving q/dq ops as high level ops for executorch (export path), since the current executorch backends need to work with the patterns like (dq -> fp32 op -> q), this is WIP in #434

Based on the example, it seems like it would be the property of DTypeTensor that decides whether to use q-dq or not, right?

kimishpatel commented 5 days ago

So what I understand from this proposal, as far as wrapping LayoutTensor and DTypeTensor is concerned is that,

A. Static quantization (both activations and weights are quantized)
B. Dynamic quantization (weights are quantized AOT, activations are quantized dynamically)
C. Weight-only quantization

It is not clear how the proposed API addresses A, but I presume you have ideas, so I will assume it will work.

Tensor subclass, as I understand it, does/can do two things: 1) override the representation of the tensor, e.g. linear.weight changed from torch.Tensor to DTypeTensor, and 2) change the dispatch behavior to dictate how an op with a DTypeTensor should be executed. DTypeTensor seems well suited for 1. But 2, which dictates the execution semantics of an op with a DTypeTensor in its args, conflicts with B and C. What I mean is that a 4-bit DTypeTensor, with whatever layout, can do both B and C. If so, what would be the right design? Should we introduce yet another tensor subclass like WeightOnlyQuantizedTensor(DTypeTensor), and have a DynamicQuantWeightTensor that dynamically quantizes the activation tensor? OR add more args to DTypeTensor, e.g. DTypeTensor.quant_type : Enum('dynamic', 'weight_only', 'static')? Given there aren't many varieties in between static and dynamic act quantization, I would be ok if we "suggest" the arg-based approach.
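The arg-based alternative suggested here has one quantized weight type that carries a `quant_type`, and the linear handler branches on it. It could look like the following plain-Python sketch over 1-D dot products. `QuantType`, `QuantWeightSketch`, `quantize_sym_int8`, and `linear_sketch` are hypothetical names. Static quant is left out because it needs calibrated activation qparams.

```python
from dataclasses import dataclass
from enum import Enum
from typing import List, Tuple


class QuantType(Enum):
    WEIGHT_ONLY = "weight_only"
    DYNAMIC = "dynamic"


def quantize_sym_int8(x: List[float]) -> Tuple[List[int], float]:
    scale = max(abs(v) for v in x) / 127 or 1.0  # symmetric per-tensor scale
    return [round(v / scale) for v in x], scale


@dataclass
class QuantWeightSketch:
    q_values: List[int]
    scale: float
    quant_type: QuantType


def linear_sketch(x: List[float], w: QuantWeightSketch) -> float:
    if w.quant_type is QuantType.WEIGHT_ONLY:
        # dequantize the weight and keep the activation in floating point
        return sum(xi * qi * w.scale for xi, qi in zip(x, w.q_values))
    if w.quant_type is QuantType.DYNAMIC:
        # quantize the activation on the fly, accumulate integers, rescale once
        xq, x_scale = quantize_sym_int8(x)
        acc = sum(a * b for a, b in zip(xq, w.q_values))
        return acc * x_scale * w.scale
    raise NotImplementedError(w.quant_type)
```

With this design, the same stored weight serves both B and C, and only the flag changes the execution path. The trade-off is that every handler has to branch on the flag.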

On the DTypeLayout: I feel that having each backend or kernel that has its own special layout for execution should be its own tensor subclass, however this can also result in proliferation, e.g. DTypeLayoutCUDA, DTypeLayoutCUDAMySecialPacking, DTypeLayoutMetalDefault etc. I actually liked PT2E workflow in this regard where representation was canonical and execution semantics, arising from weight packing etc, were done as a separate transform. If I were to think of the same here, then I would say for 4-bit there is DTypeTensor and DTypeDefaultLayout and subsequent transforms can replace the tensor subclass with their backend specific tensor subclass.

Separate from above: For the comment on using q-dq based dispatch vs. fused op, I think we can allow overriding behavior where users can plugin their own implementation, including custom fused ops, for a specific DTypeTensor subclass that uses a specific DTypeLayout tensor.

jerryzh168 commented 4 days ago

However, we are preserving q/dq ops as high level ops for executorch (export path), since the current executorch backends need to work with the patterns like (dq -> fp32 op -> q), this is WIP in #434

Based on the example, it seems like it would be the property of DTypeTensor that decides whether to use q-dq or not, right?

yeah this is correct

static quantization

yeah working on an example for this right now

dynamic quantization

I should probably add more docs for this one. Right now it's implemented by applying a LinearActQuantizedTensor (which stores an input_quant_func and the original weight) on top of an AffineQuantizedTensor: https://github.com/pytorch/ao/blob/a8956992191853b13f82ceb3e6929bed7691a3fa/torchao/quantization/quant_api.py#L355-L356. In LinearActQuantizedTensor, when dispatching to the linear op, we apply input_quant_func to the input and then continue the dispatch: https://github.com/pytorch/ao/blob/a8956992191853b13f82ceb3e6929bed7691a3fa/torchao/quantization/subclass.py#L657. In the AffineQuantizedTensor dispatch, the op is dispatched based on the types of input and weight, and I think this is not distinguishable from the final dispatch of static quant: https://github.com/pytorch/ao/blob/a8956992191853b13f82ceb3e6929bed7691a3fa/torchao/dtypes/affine_quantized_tensor.py#L550-L554
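The wrapper approach described here layers two objects. The outer weight stores `input_quant_func` and the inner quantized weight. Its linear handler quantizes the input and then continues dispatch, and the inner weight's handler picks a path based on the types of both operands. Below is a plain-Python sketch of that layering. The hypothetical classes `QuantizedSketch` and `ActQuantizedWeightSketch` stand in for AffineQuantizedTensor and LinearActQuantizedTensor. The symmetric int8 helper is also an assumption.

```python
from dataclasses import dataclass
from typing import Callable, List, Union


@dataclass
class QuantizedSketch:  # stand-in for an affine quantized tensor
    q_values: List[int]
    scale: float

    def dequantize(self) -> List[float]:
        return [q * self.scale for q in self.q_values]


def to_quantized(x: List[float]) -> QuantizedSketch:
    scale = max(abs(v) for v in x) / 127 or 1.0  # symmetric int8, per tensor
    return QuantizedSketch([round(v / scale) for v in x], scale)


@dataclass
class ActQuantizedWeightSketch:  # stand-in for a weight that also quantizes its input
    original_weight: QuantizedSketch
    input_quant_func: Callable[[List[float]], QuantizedSketch]


Operand = Union[List[float], QuantizedSketch]


def linear(x: Operand, w) -> float:
    if isinstance(w, ActQuantizedWeightSketch):
        # outer layer: quantize the activation, then continue dispatch with the inner weight
        return linear(w.input_quant_func(x), w.original_weight)
    if isinstance(x, QuantizedSketch) and isinstance(w, QuantizedSketch):
        # both operands quantized: integer accumulate, one rescale (dynamic-quant path)
        acc = sum(a * b for a, b in zip(x.q_values, w.q_values))
        return acc * x.scale * w.scale
    if isinstance(w, QuantizedSketch):
        # float activation with a quantized weight: weight-only path
        return sum(a * b for a, b in zip(x, w.dequantize()))
    return sum(a * b for a, b in zip(x, w))
```

The inner branch for two quantized operands does not know whether the activation was quantized dynamically or statically. That is the indistinguishability point made above.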

Also, I want to highlight that dynamic quant and static quant are not purely dtype problems, since they also involve flows (how do I convert my model to use these quantized tensors?). I'm also working on giving more details/examples of how to do that.

DTypeLayout

  1. I feel we could still produce a canonical representation for executorch, e.g. we can introduce a "canonical" (name TBD) layout that will use q/dq etc. without any optimizations and rely on backend lowering to do weight packing
  2. For "DTypeDefaultLayout and subsequent transforms can replace the tensor subclass with their backend specific tensor subclass": yes, I think this should be implemented under this API: `default_layout_tensor.to(extended_layout="my_optimized_packing_format")`. We haven't implemented this part yet, but it's something we can do following the current design.

For the comment on using q-dq based dispatch vs. fused op, I think we can allow overriding behavior where users can plugin their own implementation, including custom fused ops,

yeah I think so, user should be able to customize what they would like to say by implementing a new LayoutTensor type I think, although I guess the difference here is user has to reason through different dispatch layers to figure out what is the final representation they will see in the end, like the dynamic quant example.

kimishpatel commented 4 days ago

I feel we could still produce a canonical representation for executorch, e.g. we can introduce a "canonical" (name TBD) layout that will use q/dq etc. without any optimizations and rely on backend lowering to do weight packing

@jerryzh168 please note that my questions/responses are not motivated by whether it works for executorch or not. My comment on canonical representation was to borrow the same concept from PT2E, where quantization and execution of quantized ops are separated. In the currently proposed APIs that is not the case, and that's what I was highlighting.

And I mean this for eager mode, not for export. Basically, in the exported graph there is a) quant and b) lowering. What is the equivalent of that in the eager-mode subclass-based API, and is it useful to have?