msaroufim opened this issue 2 weeks ago
Makes sense. It feels like https://github.com/jax-ml/ml_dtypes is closer to a base/fundamental dtypes library, where everything makes sense by itself. I think we can define tensor subclasses for these, and then on top of them define derived dtypes like float8/mx/aqt_intx/aqt_floatx as well.
Looking at `quantize_(m, int4_weight_only())`, could we have a `torchao.to(model.weight, dtype=torchao.dtypes.aq4([...]))`?
The issue is that types such as `aqt` aren't really individual types like bfloat16, but rather type factories, or families of types. `aqt_intx` means affine quantized: any type that can be expressed as an affine transform of a set of integers, with scales and offsets (i.e. a = mx + b). But there are parameters such as the bit-width of those integers, whether scales or offsets are present, etc.
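As an illustration, here is a minimal numpy sketch of one member of that family; the bit-width and the symmetric/asymmetric choice are exactly the parameters that make it a family rather than one dtype (function names here are made up for illustration):

```python
import numpy as np

def affine_quantize(a, bits=4, symmetric=False):
    # affine family a = m*x + b: float values `a` are represented by
    # integers `x` plus a scale `m` and offset `b`
    qmin, qmax = -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    if symmetric:
        scale = np.max(np.abs(a)) / qmax
        offset = 0.0
    else:
        scale = (a.max() - a.min()) / (qmax - qmin)
        offset = a.min() - qmin * scale
    x = np.clip(np.round((a - offset) / scale), qmin, qmax).astype(np.int8)
    return x, scale, offset

def affine_dequantize(x, scale, offset):
    # invert the affine map: a = m*x + b
    return x.astype(np.float32) * scale + offset

a = np.array([0.0, 0.5, 1.0, 1.5], dtype=np.float32)
x, m, b = affine_quantize(a, bits=4)
print(affine_dequantize(x, m, b))  # approximately recovers `a`
```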
`int4_weight_only` encodes a lot of these choices, or makes them for the user. A dtype library tasks the user with instantiating these types explicitly and then converting to them. A dtype library can be used to implement a higher-level `quantize_(m, int4_weight_only())` or `autoquant`.
One issue this runs into is that our tensor subclasses are hardcoded in PyTorch to carry several pieces of metadata that need to be maintained for everything to work. If you replace an MxN fp16 weight with a tensor subclass, the tensor subclass needs to have the same shape and dtype.
So the subclass already has a notion of dtype, and now when you call `.to(int4_subclass_dtype)` on a bf16 weight (or whatever), if you check `a.dtype` you'd expect it to give you `int4_subclass_dtype`, but it outputs `torch.bfloat16`, which will be confusing.
@HDCharles - Yes, I think we'd need to extend the dtype capabilities of PyTorch itself; it's not enough to "emulate" a dtype for other systems.
Right now, as I go through the docs, it's clear that AO is a quantization and sparsity library, but dtypes aren't really a first-class citizen. A dtype library was closer in vision to what @cpuhrsch had pitched me when we first started working on the project.
For example we do `quantize_(m, int4_weight_only())` and not `m.to(int4)`.
We've opted to do this because supporting extra dtypes, with optional arguments like scale, in the PyTorch `.to()` function was hard and the need wasn't really there, but I would argue we could make our APIs look more like `.to()` without making any changes to PyTorch. I was recently made aware of https://github.com/jax-ml/ml_dtypes and that very much looks and feels like a dtype library. For example:
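The snippet itself isn't reproduced here; a minimal sketch of that look and feel, following the ml_dtypes README:

```python
import numpy as np
import ml_dtypes  # pip install ml_dtypes

# bfloat16 plugs in as if it were a native numpy dtype: no special
# constructor, and the printed dtype is sensible
x = np.zeros(4, dtype=ml_dtypes.bfloat16)
print(x.dtype)  # bfloat16
```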
What's interesting about the above snippet is that you import a dtype, that dtype can be registered with either jax or numpy, and the printed dtype is sensible. There are no special constructors either; bfloat16 feels like a native numpy dtype.
The other cool thing about the library is that it has a very clear specification for what the dtypes mean. @vkuzo already has a nice script he used to do this for the mx formats (https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats#floating-point-format-convenience-functions), but we notably don't do this for our int dtypes.
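For instance, a tiny sketch of what an analogous spec for the int dtypes could print (two's-complement ranges; this is not an existing torchao utility, just an illustration):

```python
def int_range(bits: int) -> tuple[int, int]:
    """Representable range of a signed two's-complement intN."""
    return -(1 << (bits - 1)), (1 << (bits - 1)) - 1

# one line of "specification" per bit-width
for bits in (2, 3, 4, 8):
    lo, hi = int_range(bits)
    print(f"int{bits}: [{lo}, {hi}] ({1 << bits} values)")
```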
So it feels like we could do a few things to improve our perception as a dtype library, starting with a `torchao.to()`. I'm not clear on the right way to do this; part of me is also thinking about monkey patching the PyTorch `.to()` when `torchao` gets imported, so we can override it with extra arguments like scaling factors, but I'm opening this proposal in case people have ideas on the right way to do it. Monkey patching changes what `.to()` means in PyTorch, so we should really treat it as a last resort.
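A hypothetical sketch of the monkey-patching idea, only to show the shape of that "last resort" (the `scale` argument and the toy rounding are made up; a real version would dispatch to a torchao conversion):

```python
import torch

_original_to = torch.Tensor.to

def _to_with_scale(self, *args, scale=None, **kwargs):
    # intercept an extra keyword argument; everything else is passed
    # through to the original `.to()`
    out = _original_to(self, *args, **kwargs)
    if scale is not None:
        out = torch.round(out / scale)  # toy stand-in for quantization
    return out

torch.Tensor.to = _to_with_scale  # every tensor now sees the new signature
y = torch.ones(2).to(torch.float64, scale=0.5)  # extra kwarg accepted
torch.Tensor.to = _original_to  # undo the patch
print(y)
```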