pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

AO can also be a dtype library #1052

Open msaroufim opened 2 weeks ago

msaroufim commented 2 weeks ago

Right now, as I go through the docs, it's clear that AO is a quantization and sparsity library, but dtypes aren't really a first-class citizen. A dtype library is closer in vision to what @cpuhrsch had pitched me when we first started working on the project.

For example we do quantize_(m, int4_weight_only()) and not m.to(int4)

We opted for this because supporting extra dtypes with optional arguments like scale in the PyTorch .to() function was hard and the need wasn't really there, but I would argue we could make our APIs look more like .to() without making any changes to PyTorch.
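
To make the contrast concrete, here is a minimal sketch of the two spellings side by side. It assumes a recent torchao (the documented quantize_ / int4_weight_only entry points) and a CUDA + bf16 setup; the .to() form is purely hypothetical.

import torch
from torchao.quantization import quantize_, int4_weight_only

m = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Today: an in-place module transform that swaps weights for quantized tensor subclasses
quantize_(m, int4_weight_only())

# The dtype-library spelling discussed above (not implemented today):
# m.to(torchao.dtypes.int4)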

I was recently made aware of https://github.com/jax-ml/ml_dtypes and that very much looks and feels like a dtype library. For example

>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> np.zeros(4, dtype=bfloat16)
array([0, 0, 0, 0], dtype=bfloat16)

What's interesting about the above snippet is that you import a dtype, that dtype can be registered with either jax or numpy, and the printed dtype is sensible. There are no special constructors either: bfloat16 feels like a native numpy dtype.

The other cool thing about the library is that it has a very clear specification of what the dtypes mean. @vkuzo already has a nice script he used to do this for the mx formats (https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats#floating-point-format-convenience-functions), but notably we don't do this for our int dtypes.
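
For the int dtypes, the analogous "what does this dtype mean" summary could be as simple as printing value ranges. A tiny sketch in the spirit of the mx_formats convenience functions; intx_spec is a made-up helper, not an existing torchao function:

def intx_spec(bits: int, signed: bool = True) -> dict:
    """Value range and count for an intX dtype."""
    if signed:
        lo, hi = -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    else:
        lo, hi = 0, (1 << bits) - 1
    return {"bits": bits, "min": lo, "max": hi, "num_values": 1 << bits}

for b in (1, 2, 3, 4, 8):
    print(f"int{b}:", intx_spec(b))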

So it feels like we could do a few things to improve our perception as a dtype library:

  1. Print the output of the mx spec and check it into the repo in some streamlined way
  2. For intX and floatX, make it clearer whether these are padded or bit-packed representations. Specifically for int4, TensorCoreLayout was never particularly clear to me
  3. We should think about implementing a torchao.to(). I'm not clear on the right way to do this; part of me is also thinking about monkey patching the PyTorch .to() when torchao gets imported so we can override it with extra arguments like scaling factors, but I'm opening this proposal in case people have ideas on the right way to do this (a rough sketch follows this list)
  4. We should experiment with (3) for a while; after that maybe we can revamp what .to() means in PyTorch, but we should really treat this as a last resort
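
A rough sketch of one possible shape for the torchao.to() from item 3, just to make the discussion concrete. Every name below is made up, and the int4 rounding is only a stand-in for a real quantized tensor subclass constructor.

import torch

class Int4Descriptor:
    """Hypothetical dtype descriptor that knows how to build its subclass."""
    def construct(self, t: torch.Tensor) -> torch.Tensor:
        # Stand-in for returning a real quantized tensor subclass: symmetric
        # per-tensor int4 round-trip, just to show where the extra arguments
        # (scale, block size, ...) would live.
        scale = t.abs().amax() / 7
        q = torch.clamp(torch.round(t / scale), -8, 7)
        return q * scale

def to(t: torch.Tensor, dtype, **kwargs) -> torch.Tensor:
    """torchao.to(): like Tensor.to(), but dtype can carry quantization params."""
    return dtype.construct(t, **kwargs)

w = torch.randn(16, 16)
wq = to(w, Int4Descriptor())
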
jerryzh168 commented 2 weeks ago

Makes sense. It feels like https://github.com/jax-ml/ml_dtypes is closer to a base/fundamental dtype library, where every dtype makes sense on its own. I think we can define tensor subclasses for these and then, on top of them, define derived dtypes like float8/mx/aqt_intx/aqt_floatx as well.
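
A rough sketch of that layering, using plain dataclasses for brevity (real versions would be tensor subclasses, and all names here are invented):

import torch
from dataclasses import dataclass

@dataclass
class UInt4Storage:
    """'Base' dtype: raw 4-bit payload, two values packed per uint8 byte."""
    packed: torch.Tensor  # uint8 storage
    shape: torch.Size

@dataclass
class AffineQuantizedIntx:
    """'Derived' dtype: base storage plus affine params, value ≈ scale * (q - zero_point)."""
    storage: UInt4Storage
    scale: torch.Tensor
    zero_point: torch.Tensor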

cpuhrsch commented 2 weeks ago

Looking at quantize_(m, int4_weight_only()), could we have a torchao.to(model.weight, dtype=torchao.dtypes.aq4([...]))?

The issue is that types such as aqt aren't really individual types like bfloat16 but rather type factories or families of types. aqt_intx means affine quantized, i.e., any type that can be expressed as an affine transform of a set of integers (with scales and offsets, i.e. a = mx + b). But there are parameters such as the bit-width of these integers, whether scales or offsets are present, etc.

int4_weight_only encodes a lot of these choices, or makes them for the user. A dtype library tasks the user with instantiating these types explicitly and then converting to them. A dtype library can then be used to implement a higher-level quantize_(m, int4_weight_only()) or autoquant.
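
To illustrate the "family of types" point, here's a hypothetical parameterization (names invented); int4_weight_only effectively picks one member of such a family for the user:

from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class AffineQuantizedIntxDtype:
    """One member of the aqt_intx family is fixed only once these are chosen."""
    bits: int                  # e.g. 4
    group_size: Optional[int]  # None = per-tensor scale
    has_zero_point: bool       # asymmetric (a = mx + b) vs symmetric (a = mx)

# What a dtype library would ask the user to spell out explicitly, e.g.:
# torchao.to(model.weight, dtype=AffineQuantizedIntxDtype(4, 128, True))
int4_family_member = AffineQuantizedIntxDtype(bits=4, group_size=128, has_zero_point=True)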

HDCharles commented 1 week ago

One issue this runs into is that our tensor subclasses are hardcoded in PyTorch to carry several pieces of metadata that need to be maintained in order for everything to work. If you replace an MxN fp16 weight with a tensor subclass, you need the tensor subclass to have the same shape and dtype.

So the subclass already has a notion of dtype, and when you call .to(int4_subclass_dtype) on a bf16 weight, if you then check a.dtype you'd expect it to give you int4_subclass_dtype, but it outputs torch.bfloat16, which will be confusing.
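
A small sketch of that mismatch; the class is made up, but the pattern is the standard wrapper-subclass one. The subclass has to advertise the outer dtype so shape/dtype-dependent code keeps working, so .dtype reports bfloat16 rather than anything int4-like.

import torch

class FakeInt4Weight(torch.Tensor):
    """Illustrative wrapper subclass holding packed int4 data."""
    @staticmethod
    def __new__(cls, int_data, scale, shape):
        # Must mimic the original weight's shape and dtype for the rest of
        # the stack (autograd, optimizers, shape checks) to keep working.
        return torch.Tensor._make_wrapper_subclass(
            cls, shape, dtype=torch.bfloat16, device=int_data.device
        )

    def __init__(self, int_data, scale, shape):
        self.int_data = int_data  # packed int4 payload (uint8 storage)
        self.scale = scale

w = FakeInt4Weight(torch.zeros(16, 8, dtype=torch.uint8), torch.ones(16), (16, 16))
print(w.dtype)  # torch.bfloat16 -- not an "int4" dtype, hence the confusion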

cpuhrsch commented 1 week ago

@HDCharles - Yes, I think we'd need to extend the dtype capabilities of PyTorch itself; it's not enough to "emulate" a dtype to other systems.