pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

The next tutorials #426

Open msaroufim opened 5 months ago

msaroufim commented 5 months ago

From our README.md

torchao is a library to create and integrate high-performance custom data types and layouts into your PyTorch workflows

So far we've done a good job building out the primitive data types along with their corresponding transformed Linear layers: given a new ExoticDtype(), we have a playbook for creating an ExoticDtypeLinear(). For weight-only transformations this is a perfectly fine workflow, and it's how the majority of quantization libraries operate.

For example

from torchao.quantization import quantize_, int4_weight_only

m = DownloadModelFromHuggingFace()  # placeholder for loading any pretrained model
quantize_(m, int4_weight_only())    # swaps out every torch.nn.Linear for a 4-bit weight-only Linear
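To make the "playbook" concrete, here is a minimal sketch (my own simplification, not torchao's actual implementation) of what a weight-only int4 transform does conceptually: quantize each row of a Linear weight to signed 4-bit integers plus a float scale, and dequantize at matmul time.

```python
# Sketch of symmetric per-row int4 weight-only quantization (illustration only).
import numpy as np

def quantize_int4_rowwise(w: np.ndarray):
    """Quantize each row of w to ints in [-8, 7] plus a per-row scale."""
    # 7 is the largest positive magnitude representable in signed 4-bit
    scale = np.abs(w).max(axis=1, keepdims=True) / 7.0
    q = np.clip(np.round(w / scale), -8, 7).astype(np.int8)
    return q, scale

def dequantize(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    return q.astype(np.float32) * scale

w = np.random.randn(4, 16).astype(np.float32)
q, scale = quantize_int4_rowwise(w)
w_hat = dequantize(q, scale)
# Round-to-nearest error is bounded by half a quantization step per element
assert np.all(np.abs(w - w_hat) <= scale / 2 + 1e-6)
```

A real kernel would pack two 4-bit values per byte and fuse the dequantize into the matmul; the point here is only the dtype-plus-scale representation.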

We can make the above shine with more accessible blogs, performance benchmarks, and integrations with more partners.

However, this does somewhat of a disservice to explaining the ao value proposition: we're a dtype library, not a dtype-Linear library, so given a dtype it should be easy for us to do a lot more. Some examples I'd like to see next are

None of the above is "research"; this is very much the direction inference engineering is moving: https://blog.character.ai/optimizing-ai-inference-at-character-ai/

Also, given an exotic quantization scheme, I'd like to be more proactive in helping people benchmark their models, so this should include

gau-nernst commented 5 months ago

8-bit Adam from bitsandbytes. Resources for reference
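The core idea behind 8-bit optimizer states can be sketched as follows (this is my own simplified illustration, not bitsandbytes code, which uses a nonlinear dynamic code rather than the linear one shown here): store Adam's moment buffers as int8 with a per-block absmax scale, and dequantize only when the update is applied.

```python
# Rough sketch of blockwise 8-bit quantization of an optimizer state buffer.
import numpy as np

BLOCK = 64  # block size; blockwise scaling contains the effect of outliers

def quantize_blockwise(x: np.ndarray):
    blocks = x.reshape(-1, BLOCK)
    absmax = np.abs(blocks).max(axis=1, keepdims=True)
    absmax = np.where(absmax == 0, 1.0, absmax)          # avoid div-by-zero
    q = np.round(blocks / absmax * 127).astype(np.int8)  # linear int8 codes (simplification)
    return q, absmax

def dequantize_blockwise(q, absmax, shape):
    return (q.astype(np.float32) / 127 * absmax).reshape(shape)

state = np.random.randn(2, 128).astype(np.float32)  # stand-in for Adam's exp_avg
q, absmax = quantize_blockwise(state)
restored = dequantize_blockwise(q, absmax, state.shape)
# int8 storage is 4x smaller than fp32 while staying close to the original
assert np.abs(state - restored).max() <= absmax.max() / 127 + 1e-6
```

The memory win is what matters for training: the two Adam moments drop from 8 bytes per parameter to roughly 2, plus a small per-block scale overhead.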

jeromeku commented 5 months ago

@msaroufim

Would love to work on this.

msaroufim commented 5 months ago
jeromeku commented 5 months ago

@msaroufim

RE: profiling

msaroufim commented 5 months ago

For metrics, the most important ones are memory bandwidth and FLOP utilization. A good representative workload for now is probably Llama 2 and Llama 3: https://github.com/pytorch/ao/blob/main/torchao/_models/llama/generate.py. That script already has good metric instrumentation, so extending it feels natural.
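The back-of-envelope arithmetic for those two metrics is simple enough to sketch; all of the numbers below are assumptions for illustration, not measurements:

```python
# Achieved memory bandwidth and model FLOP utilization (MFU) for decoding.
params = 7e9          # a Llama-2-7B-sized model (assumed)
bytes_per_param = 2   # bf16 weights
tokens_per_sec = 100  # measured decoding throughput (assumed)

# Batch-1 decoding reads every weight once per generated token, so:
achieved_bw_gb_s = params * bytes_per_param * tokens_per_sec / 1e9

# Each generated token costs ~2 FLOPs per parameter (one multiply-add)
peak_flops = 989e12   # e.g. an H100's vendor-quoted bf16 peak (assumed)
mfu = 2 * params * tokens_per_sec / peak_flops

print(f"achieved bandwidth: {achieved_bw_gb_s:.0f} GB/s, MFU: {mfu:.2%}")
```

The low MFU is the expected signature of memory-bound decoding, which is exactly why weight quantization (fewer bytes read per token) translates directly into tokens/sec.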

And for specific algorithms to test out I'd be most curious about testing out