bitsandbytes-foundation / bitsandbytes

Accessible large language models via k-bit quantization for PyTorch.
https://huggingface.co/docs/bitsandbytes/main/en/index

4bit quantized model.dequantize() fails on CPU #1311

Open · npbool opened this issue 2 months ago

npbool commented 2 months ago

System Info

Ubuntu 22.04, Python 3.10.4, Intel CPU; bitsandbytes==0.43.3, transformers==4.43.3

Reproduction

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",
    torch_dtype=torch.float16,
    device_map="cpu",
    quantization_config=quantization_config,
)
base_model.dequantize()

Error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[24], line 1
----> 1 base_model.dequantize()

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/modeling_utils.py:1394, in PreTrainedModel.dequantize(self)
   1391 if hf_quantizer is None:
   1392     raise ValueError("You need to first quantize your model in order to dequantize it")
-> 1394 return hf_quantizer.dequantize(self)

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/quantizers/base.py:202, in HfQuantizer.dequantize(self, model)
    197 def dequantize(self, model):
    198     """
    199     Potentially dequantize the model to retrive the original model, with some loss in accuracy / performance.
    200     Note not all quantization schemes support this.
    201     """
--> 202     model = self._dequantize(model)
    204     # Delete quantizer and quantization config
    205     del model.hf_quantizer

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/quantizers/quantizer_bnb_4bit.py:320, in Bnb4BitHfQuantizer._dequantize(self, model)
    317 def _dequantize(self, model):
    318     from ..integrations import dequantize_and_replace
--> 320     model = dequantize_and_replace(
    321         model, self.modules_to_not_convert, quantization_config=self.quantization_config
    322     )
    323     return model

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py:458, in dequantize_and_replace(model, modules_to_not_convert, quantization_config)
    453 def dequantize_and_replace(
    454     model,
    455     modules_to_not_convert=None,
    456     quantization_config=None,
    457 ):
--> 458     model, has_been_replaced = _dequantize_and_replace(
    459         model,
    460         modules_to_not_convert=modules_to_not_convert,
    461         quantization_config=quantization_config,
    462     )
    464     if not has_been_replaced:
    465         logger.warning(
    466             "For some reason the model has not been properly dequantized. You might see unexpected behavior."
    467         )

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py:441, in _dequantize_and_replace(model, modules_to_not_convert, current_key_name, quantization_config, has_been_replaced)
    439         model._modules[name] = new_module
    440 if len(list(module.children())) > 0:
--> 441     _, has_been_replaced = _dequantize_and_replace(
    442         module,
    443         modules_to_not_convert,
    444         current_key_name,
    445         quantization_config,
    446         has_been_replaced=has_been_replaced,
    447     )
    448 # Remove the last key for recursion
    449 current_key_name.pop(-1)

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py:441, in _dequantize_and_replace(model, modules_to_not_convert, current_key_name, quantization_config, has_been_replaced)
    439         model._modules[name] = new_module
    440 if len(list(module.children())) > 0:
--> 441     _, has_been_replaced = _dequantize_and_replace(
    442         module,
    443         modules_to_not_convert,
    444         current_key_name,
    445         quantization_config,
    446         has_been_replaced=has_been_replaced,
    447     )
    448 # Remove the last key for recursion
    449 current_key_name.pop(-1)

    [... skipping similar frames: _dequantize_and_replace at line 441 (1 times)]

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py:441, in _dequantize_and_replace(model, modules_to_not_convert, current_key_name, quantization_config, has_been_replaced)
    439         model._modules[name] = new_module
    440 if len(list(module.children())) > 0:
--> 441     _, has_been_replaced = _dequantize_and_replace(
    442         module,
    443         modules_to_not_convert,
    444         current_key_name,
    445         quantization_config,
    446         has_been_replaced=has_been_replaced,
    447     )
    448 # Remove the last key for recursion
    449 current_key_name.pop(-1)

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py:425, in _dequantize_and_replace(model, modules_to_not_convert, current_key_name, quantization_config, has_been_replaced)
    422 else:
    423     state = None
--> 425 new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
    427 if bias is not None:
    428     new_module.bias = bias

File ~/projects/ml/venv/lib/python3.10/site-packages/transformers/integrations/bitsandbytes.py:349, in dequantize_bnb_weight(weight, state)
    346     return weight
    348 if cls_name == "Params4bit":
--> 349     output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
    350     logger.warning_once(
    351         f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
    352     )
    353     return output_tensor

File ~/projects/ml/venv/lib/python3.10/site-packages/bitsandbytes/functional.py:1333, in dequantize_4bit(A, quant_state, absmax, out, blocksize, quant_type)
   1330     raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.")
   1332 if quant_state is None:
-> 1333     assert absmax is not None and out is not None
   1335     quant_state = QuantState(
   1336         absmax=absmax,
   1337         shape=out.shape,
   (...)
   1340         quant_type=quant_type,
   1341     )
   1343 else:

AssertionError: 
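The assertion fires inside the `quant_state is None` branch of `bnb.functional.dequantize_4bit`, so the `Params4bit` weights loaded on CPU never had a `quant_state` attached, and the caller supplied no `absmax`/`out` fallback. A minimal sketch to confirm this on the loaded model (assuming `base_model` from the reproduction above; `bnb.nn.Linear4bit` is the bitsandbytes module whose weights are `Params4bit`):

import bitsandbytes as bnb

# Inspect the first 4-bit linear layer; the traceback implies its
# quant_state was never populated when loading with device_map="cpu".
for name, module in base_model.named_modules():
    if isinstance(module, bnb.nn.Linear4bit):
        print(name, module.weight.quant_state)  # expected: None on CPU
        break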

Expected behavior

This code should work on CPU just as it does on an NVIDIA GPU.
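Until the CPU path populates (or `dequantize_4bit` tolerates) a missing `quant_state`, a possible workaround is to dequantize on GPU, where loading typically attaches a `quant_state`, and then move the result to CPU. This is an untested sketch that assumes a CUDA device is available:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Same config as in the reproduction above.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",
    torch_dtype=torch.float16,
    device_map="cuda:0",  # quantize on GPU so quant_state is set
    quantization_config=quantization_config,
)
base_model.dequantize()            # succeeds when quant_state is present
base_model = base_model.to("cpu")  # move the dequantized fp16 model to CPU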