msaroufim opened this issue 1 month ago
At some point we probably need to port `float_quantize()` from https://github.com/Tiiiger/QPyTorch (FP6-LLM uses it for FP16->FP6). The main logic is handling correct rounding (we cannot just "erase" the unwanted bits).
The code for that is actually quite simple (https://github.com/Tiiiger/QPyTorch/blob/f58bba72113e696099ef3e15e06cf421a06ff289/qtorch/quant/quant_cpu/quant_cpu.cpp#L267-L300). I tried it out locally and it seems we can implement it in pure PyTorch, and `torch.compile()` can potentially generate efficient CPU/GPU kernels for it (everything is an elementwise op).
(See also `float2half()` and `__internal_float2half()` in `cuda_fp16.hpp`; they can probably be an inspiration for writing our own FP32/16 -> FPx conversion, together with qtorch. NOTE: `float_quantize()` from qtorch does not handle bit-packing. The output stays in the original dtype, so it's fake quantization.)
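For reference, here is a minimal pure-PyTorch sketch of that rounding logic (my own sketch, not qtorch's actual code): fake-quantize FP32 to a generic ExMy format by truncating mantissa bits, with round-to-nearest-even done via the usual bit trick. It ignores target-format subnormals and assumes an IEEE-like bias with a reserved all-ones exponent for the saturation bound, so treat it as illustrative only.

```python
import torch

def float_quantize_rne(x: torch.Tensor, exp_bits: int, man_bits: int) -> torch.Tensor:
    """Fake-quantize FP32 -> ExMy with round-to-nearest-even.

    Output stays FP32 (no bit-packing). All ops are elementwise, so
    torch.compile should be able to fuse them. Target-format subnormals
    are NOT handled (they round like normals here).
    """
    bits = x.contiguous().view(torch.int32)
    drop = 23 - man_bits                    # FP32 mantissa bits to discard
    half = 1 << (drop - 1)
    # classic RNE bit trick: add (half - 1) + lsb-of-kept-mantissa,
    # then clear the dropped bits (a carry correctly bumps the exponent)
    lsb = (bits >> drop) & 1
    rounded = (bits + (half - 1) + lsb) & ~((1 << drop) - 1)
    out = rounded.view(torch.float32)
    # saturate to the target's max normal, assuming IEEE-like bias and a
    # reserved all-ones exponent; formats like FP6 E3M2 may differ here
    bias = (1 << (exp_bits - 1)) - 1
    max_val = (2.0 - 2.0 ** (-man_bits)) * 2.0 ** bias
    return out.clamp(-max_val, max_val)
```

Under these assumptions, E3M2 maps 1.3 down to 1.25 (dropped bits below half), and ties such as 1.375 round to the even neighbor 1.5.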
(PyTorch now has uint32 I think, but still no ops for it; maybe the needed bit ops can simply be enabled for uint32 and the other unsigned dtypes.)
@vkuzo has excellent FP32<->FPx conversion functions for MX dtypes (only FP4_E2M1, FP6_E3M2, and FP6_E2M3 for now, but they should be straightforward to extend to other custom FPx formats). For quantization-aware training/fine-tuning this should be enough, since we probably don't need bit-packing. For inference, @vayuda is working on generic bit-packing, which should be useful (although personally I would prefer explicit bit-packing for each FPx dtype). I'm not sure whether the current state of the torch compiler can fuse FPx dequant into a triton matmul kernel. If not, we probably need to write a custom fused dequant+matmul triton kernel (which could be fun to do 😃)
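To make the explicit-bit-packing idea concrete, here is a hypothetical sketch (the names `pack_fp6`/`unpack_fp6` are mine, not from any torchao PR) that packs four 6-bit FP6 codes into three bytes using only elementwise int ops:

```python
import torch

def pack_fp6(codes: torch.Tensor) -> torch.Tensor:
    """Pack 6-bit codes (uint8 values in [0, 63], numel divisible by 4)
    into bytes: 4 codes -> one 24-bit word -> 3 bytes."""
    c = codes.reshape(-1, 4).to(torch.int32)
    word = (c[:, 0] << 18) | (c[:, 1] << 12) | (c[:, 2] << 6) | c[:, 3]
    out = torch.stack([(word >> 16) & 0xFF, (word >> 8) & 0xFF, word & 0xFF], dim=1)
    return out.reshape(-1).to(torch.uint8)

def unpack_fp6(packed: torch.Tensor) -> torch.Tensor:
    """Inverse of pack_fp6: 3 bytes -> four 6-bit codes."""
    b = packed.reshape(-1, 3).to(torch.int32)
    word = (b[:, 0] << 16) | (b[:, 1] << 8) | b[:, 2]
    c = torch.stack([(word >> 18) & 0x3F, (word >> 12) & 0x3F,
                     (word >> 6) & 0x3F, word & 0x3F], dim=1)
    return c.reshape(-1).to(torch.uint8)
```

A real kernel would pick a layout friendly to the matmul's memory access pattern (which is exactly why per-dtype explicit packing may beat a generic scheme), but the round trip above shows the dtype-specific bit layout is only a few ops.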
Prototype FP8 quant using native PyTorch fp8 dtypes: https://github.com/gau-nernst/ao/blob/fp8wo/torchao/prototype/fp8wo/__init__.py
Llama2-7B-chat on a 4070Ti SUPER. tokens/s measured with `torchao/_models/llama/generate.py`. PyTorch 2.4.0.dev20240610+cu124.
| Quant type (weight only) | tokens/s | Bandwidth (GB/s) |
|---|---|---|
| BF16 | 46.74 | 617.64 |
| INT8 | 88.22 | 584.05 |
| FP8 E4M3 FN | 79.33 | 1048.31 |
| FP8 E4M3 FNUZ | 78.73 (output is gibberish) | 1040.44 |
| FP8 E5M2 | 82.53 | 1090.65 |
| FP8 E5M2 FNUZ | Error (see below) | - |
Error with FP8 E5M2 FNUZ: `Unsupported conversion from f16 to f8E5M2FNUZ with rounding mode rtne`
The speed degradation compared to INT8 seems to be because torch.compile cannot fuse `act_bf16 @ weight_fp8.to(torch.bfloat16) * scales` into a single kernel (which is also why the measured memory bandwidth is so high).
```python
import torch
import torch._inductor.config

# core dump if either of these flags is enabled
# torch._inductor.config.force_mixed_mm = True
# torch._inductor.config.use_mixed_mm = True

def f(a, b, s):
    return torch.mm(a, b.to(a.dtype)) * s

fp16_act = torch.randn(1, 32).to(torch.bfloat16).cuda()
fp8_weight = torch.randn(32, 32).to(torch.float8_e5m2).cuda()
scales = torch.randn(32).to(torch.bfloat16).cuda()

torch.compile(f, mode="max-autotune", fullgraph=True)(fp16_act, fp8_weight, scales)
```
Codegen output: https://gist.github.com/gau-nernst/4afd0f5b97368ecf26d54b5f3415b004
When either the `use_mixed_mm` or `force_mixed_mm` flag is set, I get a core dump (issue already opened at PyTorch core: https://github.com/pytorch/pytorch/issues/128381):
```
loc("/tmp/torchinductor_thien/h3/ch3inellku5joxa2jz4iwhfnrcquf7fbmuq53uw4vr6kuoctvtzo.py":76:21): error: size mismatch when packing elements for LLVM struct expected 4 but got 8
python: /root/.triton/llvm/llvm-6f44bb77-centos-x64/include/llvm/ADT/ArrayRef.h:257: const T& llvm::ArrayRef<T>::operator[](size_t) const [with T = mlir::Type; size_t = long unsigned int]: Assertion `Index < Length && "Invalid index!"' failed.
```
One reason why FP8 quant is not so performant is probably that there is no optimized FP8->BF16 dtype conversion that gets fused (either at the triton level or the torch.compile level; need to investigate further...)
As I was reviewing https://github.com/pytorch/ao/pull/223, I was reminded of this PR: https://github.com/pytorch/ao/pull/214. I'd be curious what range of floating-point numbers we can express using just subclasses.
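As a back-of-the-envelope answer to the range question, the representable range of a generic 1-sign ExMy format is easy to compute. A sketch, assuming an IEEE-like bias and a reserved all-ones exponent; real FP8 variants bend these rules (e.g. torch's `float8_e4m3fn` has max 448, not the 240 this formula gives, because it only reserves one NaN encoding):

```python
def fpx_range(exp_bits: int, man_bits: int) -> tuple:
    """(max normal, min positive subnormal) of an IEEE-like ExMy format.

    Assumes bias = 2^(e-1) - 1 and an all-ones exponent reserved for
    inf/nan, so the max exponent equals the bias.
    """
    bias = 2 ** (exp_bits - 1) - 1
    max_normal = (2 - 2.0 ** (-man_bits)) * 2.0 ** bias
    min_subnormal = 2.0 ** (1 - bias - man_bits)
    return max_normal, min_subnormal

print(fpx_range(5, 10))  # FP16: max normal 65504
print(fpx_range(5, 2))   # FP8 E5M2: max normal 57344
```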