t-vi opened this issue 2 months ago
@t-vi Should we reshape the meta result into a two-dimensional torch.uint8 tensor, matching the GPU result? Should we run PyTorch's own 8-bit quantization for CPU?
Should we implement a dedicated QuantState class for meta and cpu, so they return a tensor along with its corresponding quantization state, as the GPU path does?
Hi @tombawor , thank you for your interest.
I don't think we need a new class, just functions to complement bitsandbytes.functional.quantize_4bit(w, quant_type="nf4")
for meta and cpu inputs (returning a tensor on w.device and a quant state whose tensors are also on w.device).
Ideally, this quantize function should have exactly the same inputs and outputs as bitsandbytes.functional.quantize_4bit, except that all tensors stay on the device they are already on, so the shapes and the quant state match what we would get on GPU.
We could also offer it to bitsandbytes if they're interested.
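For the meta case, a shape-only shim along the lines t-vi describes might look like the sketch below. The function name `quantize_4bit_meta` is hypothetical, and the layout assumptions (a `((n + 1) // 2, 1)` uint8 packed tensor and one float32 absmax value per 64-element block) mirror what `quantize_4bit` is generally understood to produce but are not verified against a specific bitsandbytes version:

```python
import torch

def quantize_4bit_meta(w: torch.Tensor, blocksize: int = 64, quant_type: str = "nf4"):
    """Sketch: produce the shapes bitsandbytes.functional.quantize_4bit
    would produce, without touching data, for a meta-device weight.

    Assumptions (hypothetical, not checked against bitsandbytes):
      * packed output is a ((n + 1) // 2, 1) uint8 tensor,
      * absmax holds one float32 value per block of `blocksize` elements.
    """
    n = w.numel()
    # Two 4-bit values are packed into each uint8 byte.
    packed = torch.empty(((n + 1) // 2, 1), dtype=torch.uint8, device=w.device)
    # One scale per quantization block.
    absmax = torch.empty((n + blocksize - 1) // blocksize,
                         dtype=torch.float32, device=w.device)
    # Stand-in for QuantState: just the metadata it would carry.
    state = {
        "absmax": absmax,
        "shape": w.shape,
        "dtype": w.dtype,
        "blocksize": blocksize,
        "quant_type": quant_type,
    }
    return packed, state

# Only shapes are computed; nothing leaves the meta device.
w = torch.empty(128, 64, device="meta")
packed, state = quantize_4bit_meta(w)
```

Because everything stays on `w.device`, the same code path also hands back correctly shaped buffers for any device where we only need metadata.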
There’s a multi-backend effort under way for bitsandbytes, currently in an alpha release, which includes a CPU implementation.
Currently, BitsAndBytesLinearQuant4bit always calls bitsandbytes.functional.quantize_4bit for the submodule. This is somewhat touchy for CPU tensors, because quantize_4bit only works on GPU tensors, and it is outright not so nice for meta tensors, where we would only need to get the right shapes.
https://github.com/Lightning-AI/lightning-thunder/blob/e64d347def39bb47101efafe4177adf9f77a63ec/thunder/transforms/quantization.py#L93-L103
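For the CPU case, a pure-PyTorch reference could cover blockwise NF4 quantization without any CUDA kernel. This is a hedged sketch, not a bit-exact reimplementation of bitsandbytes (the nibble packing order here, high nibble first, is an assumption); the 16 NF4 codebook values are the published NormalFloat4 levels from the QLoRA work:

```python
import torch

# NF4 codebook: the 16 NormalFloat4 levels published with QLoRA.
NF4_CODE = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
])

def quantize_nf4_cpu(w: torch.Tensor, blocksize: int = 64):
    """Sketch of blockwise NF4 quantization in plain PyTorch (CPU-friendly).
    Returns a packed uint8 tensor and per-block absmax scales."""
    flat = w.float().flatten()
    n = flat.numel()
    pad = (-n) % blocksize                      # pad to a whole block
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    blocks = flat.view(-1, blocksize)
    absmax = blocks.abs().amax(dim=1)           # one scale per block
    scaled = blocks / absmax.clamp(min=1e-12).unsqueeze(1)
    # Nearest codebook index for each scaled value.
    idx = (scaled.unsqueeze(-1) - NF4_CODE).abs().argmin(dim=-1)
    idx = idx.flatten().to(torch.uint8)
    # Pack two 4-bit indices per byte (high nibble first: an assumption).
    packed = (idx[0::2] << 4) | idx[1::2]
    return packed.view(-1, 1), absmax

def dequantize_nf4_cpu(packed: torch.Tensor, absmax: torch.Tensor,
                       shape, blocksize: int = 64):
    """Inverse of the sketch above: unpack nibbles, look up code values,
    rescale by the per-block absmax, and restore the original shape."""
    flat = packed.flatten()
    idx = torch.stack([flat >> 4, flat & 0xF], dim=1).flatten()
    vals = NF4_CODE[idx.long()].view(-1, blocksize) * absmax.unsqueeze(1)
    return vals.flatten()[:torch.Size(shape).numel()].view(shape)
```

Even a simple reference like this would give the transform a working fallback on CPU with the same tensor-plus-state interface, which could later be swapped for the bitsandbytes multi-backend CPU kernels.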