pytorch / ao

Native PyTorch library for quantization and sparsity
https://pytorch.org/ao
BSD 3-Clause "New" or "Revised" License

FloatQuantization subclass #228

Open msaroufim opened 1 month ago

msaroufim commented 1 month ago

As I was reviewing https://github.com/pytorch/ao/pull/223

I was reminded of this PR https://github.com/pytorch/ao/pull/214

And I'd be curious what range of floating-point formats we can express just using subclasses.

gau-nernst commented 1 month ago

At some point we will probably need to port float_quantize() from https://github.com/Tiiiger/QPyTorch (FP6-LLM uses it for FP16->FP6). The main logic is handling correct rounding (we cannot just "erase" the unwanted bits).

The code for that is actually quite simple (https://github.com/Tiiiger/QPyTorch/blob/f58bba72113e696099ef3e15e06cf421a06ff289/qtorch/quant/quant_cpu/quant_cpu.cpp#L267-L300). I tried it out locally and it seems we can implement it in pure PyTorch; torch.compile() can potentially generate efficient CPU/GPU kernels for it, since everything is an elementwise op.

(NOTE: float_quantize() from qtorch does not handle bit-packing. The output stays in the original dtype, so it is effectively fake quantization.)
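
For reference, here is a minimal pure-PyTorch sketch of that kind of fake quantization with round-to-nearest-even. It assumes a target format described by `ebits`/`mbits` (names introduced here for illustration), normal values only, and no inf/NaN codes; qtorch's float_quantize() additionally handles subnormals and other rounding modes.

```python
import torch

def fake_quantize_fp(x: torch.Tensor, ebits: int = 3, mbits: int = 2) -> torch.Tensor:
    # Round an FP32 tensor to a lower-precision float format (e.g. FP6 E3M2),
    # keeping the result in the original dtype ("fake" quantization, no bit-packing).
    # Sketch only: target-format subnormals and NaN/inf inputs are not handled.
    bits = x.float().view(torch.int32)
    drop = 23 - mbits                        # FP32 has 23 mantissa bits
    # round-to-nearest-even on the dropped mantissa bits, done on the raw bit pattern
    lsb = (bits >> drop) & 1
    bits = bits + (1 << (drop - 1)) - 1 + lsb
    bits = bits & ~((1 << drop) - 1)         # clear the dropped bits
    out = bits.view(torch.float32)
    # clamp to the max normal value of the target format
    # (assumes all exponent codes are usable, i.e. no inf/NaN, as in FP6 E3M2)
    max_val = (2.0 - 2.0 ** (-mbits)) * 2.0 ** (1 << (ebits - 1))
    return out.clamp(-max_val, max_val).to(x.dtype)
```

Every op here is elementwise, so torch.compile should be able to turn it into a single kernel.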

vadimkantorov commented 1 week ago

(PyTorch now has uint32, I think, but still no ops for it; maybe the needed bit ops can simply be enabled on uint32 and the other unsigned dtypes.)

gau-nernst commented 6 days ago

@vkuzo has excellent FP32<->FPx conversion functions for the MX dtypes (only FP4_E2M1, FP6_E3M2, and FP6_E2M3, but they should be straightforward to extend to other custom FPx formats). For quantization-aware training/fine-tuning this should be enough, since we probably don't need bit-packing. For inference, @vayuda is working on generic bit-packing, which should be useful (although personally I would prefer explicit bit-packing for each FPx dtype; see the sketch below). I'm not sure whether the current torch compiler can fuse FPx dequant into a Triton matmul kernel. If not, we probably need to write a custom fused dequant+matmul Triton kernel (which can be fun to do 😃).
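
For illustration, explicit bit-packing for a 4-bit format is just a couple of elementwise ops. This is a hypothetical sketch (not @vayuda's generic packer) that packs two 4-bit codes per uint8 along the last dimension:

```python
import torch

def pack_uint4(codes: torch.Tensor) -> torch.Tensor:
    # codes: uint8 tensor of 4-bit values (0..15); last dim must be even
    assert codes.dtype == torch.uint8 and codes.shape[-1] % 2 == 0
    lo, hi = codes[..., ::2], codes[..., 1::2]
    return (hi << 4) | (lo & 0xF)

def unpack_uint4(packed: torch.Tensor) -> torch.Tensor:
    # inverse of pack_uint4: restores the original last dimension
    lo, hi = packed & 0xF, (packed >> 4) & 0xF
    return torch.stack([lo, hi], dim=-1).flatten(-2)
```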

gau-nernst commented 2 days ago

Prototype FP8 quant using native PyTorch fp8 dtypes: https://github.com/gau-nernst/ao/blob/fp8wo/torchao/prototype/fp8wo/__init__.py
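
For context, weight-only FP8 quantization boils down to storing the weight in FP8 with a per-output-channel scale and dequantizing at matmul time. A minimal sketch (not the prototype's exact code; helper names here are illustrative):

```python
import torch

@torch.no_grad()
def quantize_weight_fp8(w: torch.Tensor, dtype=torch.float8_e4m3fn):
    # per-output-channel absmax scale so each row maps onto the FP8 range
    fp8_max = torch.finfo(dtype).max
    scales = w.abs().amax(dim=1).clamp(min=1e-12) / fp8_max
    w_fp8 = (w / scales[:, None]).to(dtype)
    return w_fp8, scales.to(w.dtype)

def fp8_weight_only_linear(x: torch.Tensor, w_fp8: torch.Tensor, scales: torch.Tensor):
    # dequantize the FP8 weight to the activation dtype, matmul, then rescale per output channel
    return (x @ w_fp8.to(x.dtype).t()) * scales
```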

Llama2-7B-chat on a 4070 Ti SUPER. Tokens/s is measured with torchao/_models/llama/generate.py. PyTorch 2.4.0.dev20240610+cu124.

| Quant type (weight only) | Tokens/s | Bandwidth (GB/s) |
|---|---|---|
| BF16 | 46.74 | 617.64 |
| INT8 | 88.22 | 584.05 |
| FP8 E4M3 FN | 79.33 | 1048.31 |
| FP8 E4M3 FNUZ | 78.73 (output is gibberish) | 1040.44 |
| FP8 E5M2 | 82.53 | 1090.65 |
| FP8 E5M2 FNUZ | Error (see below) | — |

Error with FP8 E5M2 FNUZ: `Unsupported conversion from f16 to f8E5M2FNUZ with rounding mode rtne`

The speed degradation compared to INT8 seems to be because torch.compile cannot fuse `act_bf16 @ weight_fp8.to(torch.bfloat16) * scales` into a single kernel (which is also why the measured memory bandwidth is so high).

```python
import torch
import torch._inductor.config

# core dump if either of these flags is enabled
# torch._inductor.config.force_mixed_mm = True
# torch._inductor.config.use_mixed_mm = True

def f(a, b, s):
    return torch.mm(a, b.to(a.dtype)) * s

fp16_act = torch.randn(1, 32).to(torch.bfloat16).cuda()
fp8_weight = torch.randn(32, 32).to(torch.float8_e5m2).cuda()
scales = torch.randn(32).to(torch.bfloat16).cuda()
torch.compile(f, mode="max-autotune", fullgraph=True)(fp16_act, fp8_weight, scales)
```

Codegen output: https://gist.github.com/gau-nernst/4afd0f5b97368ecf26d54b5f3415b004

When either the use_mixed_mm or force_mixed_mm flag is set, I get a core dump (I already opened an issue against PyTorch core: https://github.com/pytorch/pytorch/issues/128381):

loc("/tmp/torchinductor_thien/h3/ch3inellku5joxa2jz4iwhfnrcquf7fbmuq53uw4vr6kuoctvtzo.py":76:21): error:  size mismatch when packing elements for LLVM struct expected 4 but got 8
python: /root/.triton/llvm/llvm-6f44bb77-centos-x64/include/llvm/ADT/ArrayRef.h:257: const T& llvm::ArrayRef<T>::operator[](size_t) const [with T = mlir::Type; size_t = long unsigned int]: Assertion `Index < Length && "Invalid index!"' failed.

One reason FP8 quant is not so performant is probably that there is no optimized FP8->BF16 dtype conversion that can be fused in Triton (either at the Triton level or at the torch.compile level; this needs further investigation...).
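
One quick way to look at the conversion cost in isolation is a rough micro-benchmark of just the FP8->BF16 cast (a sketch, not part of the prototype):

```python
import torch
from torch.utils.benchmark import Timer

# time only the FP8 -> BF16 cast of a weight-sized tensor, compiled on its own
w_fp8 = torch.randn(4096, 4096).to(torch.float8_e5m2).cuda()
cast = torch.compile(lambda w: w.to(torch.bfloat16))
cast(w_fp8)  # warm up / trigger compilation
print(Timer("cast(w_fp8)", globals={"cast": cast, "w_fp8": w_fp8}).blocked_autorange())
```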