bitsandbytes-foundation / bitsandbytes

Accessible large language models via k-bit quantization for PyTorch.
https://huggingface.co/docs/bitsandbytes/main/en/index
MIT License

Loading a 4-bit quantized model directly #1123

Closed: ByungKwanLee closed this issue 6 months ago

ByungKwanLee commented 6 months ago

System Info

I saved a 4-bit quantized model.

How can I load that 4-bit quantized model directly with from_pretrained()?

Large models are normally saved in float16, float32, or bfloat16.

In my case, however, I saved the 4-bit weights directly and want to load them back as a 4-bit quantized model.

Reproduction

1. Save a 4-bit quantized model with save_pretrained().
2. Load it with from_pretrained(), passing a 4-bit quantization config.

Expected behavior

The stored uint8 values (the packed 4-bit weights) are loaded directly.
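For context on what "uint8 values" means here: a 4-bit code needs only half a byte, so two codes can share one uint8. The plain-Python sketch below packs and unpacks such codes. The nibble order (first code in the high four bits) and the helper names pack_4bit/unpack_4bit are illustrative assumptions, not necessarily the layout bitsandbytes uses, and a real bitsandbytes checkpoint also stores per-block scaling state alongside the packed weights, which this sketch omits.

```python
def pack_4bit(codes: list[int]) -> bytes:
    """Pack 4-bit codes (0..15) two per byte; pads with a zero code if the count is odd."""
    if any(not 0 <= c <= 15 for c in codes):
        raise ValueError("4-bit codes must be in range 0..15")
    padded = codes + [0] * (len(codes) % 2)
    # First code of each pair goes in the high nibble (illustrative choice).
    return bytes((hi << 4) | lo for hi, lo in zip(padded[0::2], padded[1::2]))


def unpack_4bit(packed: bytes, count: int) -> list[int]:
    """Recover the first `count` 4-bit codes from packed bytes."""
    codes = []
    for byte in packed:
        codes.append(byte >> 4)    # high nibble
        codes.append(byte & 0x0F)  # low nibble
    return codes[:count]


codes = [3, 15, 0, 7, 9]
packed = pack_4bit(codes)
print(len(packed), list(packed))        # 3 [63, 7, 144]
print(unpack_4bit(packed, len(codes)))  # [3, 15, 0, 7, 9]
```

Five codes fit in three bytes, which is why a saved 4-bit checkpoint holds roughly a quarter of the bytes of the same weights in float16.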

younesbelkada commented 6 months ago

Hi @ByungKwanLee, yes, this should work. Are you getting an error?

thepowerfuldeez commented 2 months ago

Hi! The HF documentation says that 4-bit weights currently cannot be loaded after saving them with save_pretrained(): https://huggingface.co/docs/transformers/v4.35.1/en/main_classes/quantization#load-a-large-model-in-4bit

thepowerfuldeez commented 2 months ago

When used with FSDP, loading the saved 4-bit weights currently results in the following error:

[rank5]: Traceback (most recent call last):
[rank5]:   File "/app/train.py", line 382, in <module>
[rank5]:     main(script_args, training_args)
[rank5]:   File "/app/train.py", line 306, in main
[rank5]:     trainer.train(resume_from_checkpoint=checkpoint)
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1932, in train
[rank5]:     return inner_training_loop(
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2079, in _inner_training_loop
[rank5]:     self.model = self.accelerator.prepare(self.model)
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 1304, in prepare
[rank5]:     result = tuple(
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 1305, in <genexpr>
[rank5]:     self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 1181, in _prepare_one
[rank5]:     return self.prepare_model(obj, device_placement=device_placement)
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 1477, in prepare_model
[rank5]:     model = FSDP(model, **kwargs)
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 469, in __init__
[rank5]:     _auto_wrap(
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
[rank5]:     _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank5]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank5]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank5]:     wrapped_child, num_wrapped_params = _recursive_wrap(
[rank5]:   [Previous line repeated 2 more times]
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap
[rank5]:     return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap
[rank5]:     return wrapper_cls(module, **kwargs)
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 495, in __init__
[rank5]:     _init_param_handle_from_module(
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_init_utils.py", line 599, in _init_param_handle_from_module
[rank5]:     _init_param_handle_from_params(state, managed_params, fully_sharded_module)
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_init_utils.py", line 611, in _init_param_handle_from_params
[rank5]:     handle = FlatParamHandle(
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_flat_param.py", line 582, in __init__
[rank5]:     self._init_flat_param_and_metadata(
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_flat_param.py", line 632, in _init_flat_param_and_metadata
[rank5]:     ) = self._validate_tensors_to_flatten(params)
[rank5]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_flat_param.py", line 768, in _validate_tensors_to_flatten
[rank5]:     raise ValueError("Cannot flatten integer dtype tensors")
[rank5]: ValueError: Cannot flatten integer dtype tensors

However, loading from bf16 weights works fine. Both bnb_4bit_compute_dtype and bnb_4bit_quant_storage are set to torch.bfloat16, as required for FSDP.
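Why the storage dtype matters: judging from the traceback, FSDP refuses to flatten parameters whose dtype is not floating point, and setting bnb_4bit_quant_storage=torch.bfloat16 lets the packed bytes live in a floating-point container instead of uint8. The plain-PyTorch sketch below only illustrates that dtype requirement: it shows that reinterpreting a uint8 buffer as bfloat16 with Tensor.view(dtype) is lossless. The function fsdp_would_accept is a hypothetical stand-in for FSDP's check, not its actual code, and the sketch does not diagnose why the reloaded checkpoint still contains integer tensors.

```python
import torch


def fsdp_would_accept(t: torch.Tensor) -> bool:
    # Hypothetical stand-in for the check behind
    # "ValueError: Cannot flatten integer dtype tensors".
    return t.is_floating_point()


# Eight packed bytes, i.e. sixteen 4-bit codes (random stand-in data).
packed_u8 = torch.randint(0, 256, (8,), dtype=torch.uint8)

# Reinterpret the same bytes as bfloat16: 2 bytes per element, so 4 elements.
# No numeric conversion happens; the view shares the underlying memory.
as_bf16 = packed_u8.view(torch.bfloat16)

print(fsdp_would_accept(packed_u8))  # False
print(fsdp_would_accept(as_bf16))    # True
print(as_bf16.shape)                 # torch.Size([4])

# Viewing back as uint8 recovers the original bytes exactly.
assert torch.equal(as_bf16.view(torch.uint8), packed_u8)
```

Because the view is a bit-level reinterpretation, the bfloat16 values themselves are meaningless as numbers (some may even be NaN); only the byte pattern matters.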

thepowerfuldeez commented 2 months ago

I also suspected the problem might be that I load with device_map='auto', since the HF docs suggest moving the model to CPU before saving. However, that does not work either:

ValueError: `.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.