huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

model.dequantize() does not remove quantization_config from model.config, which results in an error when loading the model #34847

Closed: konradkalita closed this issue 3 days ago

konradkalita commented 4 days ago

System Info

Who can help?

quantization: @SunMarc @MekkCyber

Information

Tasks

Reproduction

from transformers import AutoModel, BitsAndBytesConfig

# Quantize the model to 4-bit NF4 with bitsandbytes on load.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype="bfloat16",
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

model = AutoModel.from_pretrained("mistralai/Mistral-7B-v0.1", quantization_config=bnb_config)

# Dequantize, save, then try to reload the now full-precision checkpoint.
model = model.dequantize()
model.save_pretrained("mistral_test_dequantized")
model = AutoModel.from_pretrained("mistral_test_dequantized")  # raises ValueError

Stacktrace:

Traceback (most recent call last):
  File "./test_model.py", line 40, in <module>
    main()
  File "./test_model.py", line 37, in main
    model = AutoModel.from_pretrained("mistral_test_dequantized")
  File "./venv/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
  File "./venv/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4226, in from_pretrained
    ) = cls._load_pretrained_model(
  File "./venv/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4729, in _load_pretrained_model
    new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
  File "./venv/lib/python3.10/site-packages/transformers/modeling_utils.py", line 994, in _load_state_dict_into_meta_model
    hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
  File "./venv/lib/python3.10/site-packages/transformers/quantizers/quantizer_bnb_4bit.py", line 207, in create_quantized_param
    raise ValueError(
ValueError: Supplied state dict for layers.0.mlp.down_proj.weight does not contain `bitsandbytes__*` and possibly other `quantized_stats` components.

Expected behavior

A dequantized model can be saved with save_pretrained and then loaded back correctly with from_pretrained.
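
Until this is fixed in the library, a possible workaround (an untested sketch, continuing from the reproduction above; it assumes the leftover quantization_config attribute on the in-memory config is the only thing that breaks reloading) is to delete that attribute by hand before saving:

model = model.dequantize()
# Drop the stale quantization_config so the saved config no longer claims
# the weights are 4-bit quantized.
if hasattr(model.config, "quantization_config"):
    del model.config.quantization_config
model.save_pretrained("mistral_test_dequantized")
model = AutoModel.from_pretrained("mistral_test_dequantized")  # loads as a plain model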

SunMarc commented 4 days ago

Thanks for the report. This happens because the config still has the quantization_config. Would you like to open a PR to fix it? I think all you need to do is add the following in the dequantize() method: del self.config.quantization_config
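
A minimal sketch of what that change could look like (the surrounding method body is only an approximation of dequantize() in modeling_utils.py, not the verified source; the del line is the suggested fix):

# Hypothetical patched version of PreTrainedModel.dequantize().
def dequantize(self):
    hf_quantizer = getattr(self, "hf_quantizer", None)
    if hf_quantizer is None:
        raise ValueError("You need to first quantize your model in order to dequantize it")
    model = hf_quantizer.dequantize(self)
    # Suggested fix: the weights are no longer quantized, so remove the
    # quantization_config before save_pretrained serializes the config.
    del self.config.quantization_config
    return model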