Closed ByungKwanLee closed 6 months ago
Hi @ByungKwanLee Yes this should work properly, are you getting an error?
Hi! HF documentation says that we currently cannot load 4bit weights after we saved them with save_pretrained() https://huggingface.co/docs/transformers/v4.35.1/en/main_classes/quantization#load-a-large-model-in-4bit
When used with FSDP, this currently results in the following error:
[rank5]: Traceback (most recent call last):
[rank5]: File "/app/train.py", line 382, in <module>
[rank5]: main(script_args, training_args)
[rank5]: File "/app/train.py", line 306, in main
[rank5]: trainer.train(resume_from_checkpoint=checkpoint)
[rank5]: File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1932, in train
[rank5]: return inner_training_loop(
[rank5]: File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2079, in _inner_training_loop
[rank5]: self.model = self.accelerator.prepare(self.model)
[rank5]: File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 1304, in prepare
[rank5]: result = tuple(
[rank5]: File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 1305, in <genexpr>
[rank5]: self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[rank5]: File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 1181, in _prepare_one
[rank5]: return self.prepare_model(obj, device_placement=device_placement)
[rank5]: File "/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py", line 1477, in prepare_model
[rank5]: model = FSDP(model, **kwargs)
[rank5]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 469, in __init__
[rank5]: _auto_wrap(
[rank5]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_wrap_utils.py", line 101, in _auto_wrap
[rank5]: _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type]
[rank5]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank5]: wrapped_child, num_wrapped_params = _recursive_wrap(
[rank5]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank5]: wrapped_child, num_wrapped_params = _recursive_wrap(
[rank5]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 543, in _recursive_wrap
[rank5]: wrapped_child, num_wrapped_params = _recursive_wrap(
[rank5]: [Previous line repeated 2 more times]
[rank5]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 561, in _recursive_wrap
[rank5]: return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
[rank5]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/wrap.py", line 490, in _wrap
[rank5]: return wrapper_cls(module, **kwargs)
[rank5]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 495, in __init__
[rank5]: _init_param_handle_from_module(
[rank5]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_init_utils.py", line 599, in _init_param_handle_from_module
[rank5]: _init_param_handle_from_params(state, managed_params, fully_sharded_module)
[rank5]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_init_utils.py", line 611, in _init_param_handle_from_params
[rank5]: handle = FlatParamHandle(
[rank5]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_flat_param.py", line 582, in __init__
[rank5]: self._init_flat_param_and_metadata(
[rank5]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_flat_param.py", line 632, in _init_flat_param_and_metadata
[rank5]: ) = self._validate_tensors_to_flatten(params)
[rank5]: File "/usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/_flat_param.py", line 768, in _validate_tensors_to_flatten
[rank5]: raise ValueError("Cannot flatten integer dtype tensors")
[rank5]: ValueError: Cannot flatten integer dtype tensors
However, loading from bf16 weights works fine. `bnb_4bit_compute_dtype` and `bnb_4bit_quant_storage` are set to `torch.bfloat16`, as required for FSDP.
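For reference, a quantization config matching that setup would look roughly like the sketch below (the variable name and `nf4` quant type are assumptions; the two dtype settings are the ones mentioned above):

```python
import torch
from transformers import BitsAndBytesConfig

# 4-bit config intended for FSDP: the quant storage dtype is set to
# bfloat16 so FSDP flattens float tensors rather than integer ones.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",               # assumed quant type
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,
)
```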
I also thought the problem might be that I load with `device_map='auto'`, and the HF docs suggest moving the model to CPU before saving. However, that does not work either:
ValueError: `.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
System Info
I saved a 4-bit quantized model with `save_pretrained()`. How can I load the 4-bit quantized model directly with `from_pretrained()`? Large models are normally saved in float16, float32, or bfloat16, but in my case I saved the weights in 4-bit directly and want to load the 4-bit quantized model back.
Reproduction
1. Save a 4-bit quantized model with `save_pretrained()`.
2. Load it with `from_pretrained()` and a 4-bit quantization config.
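The two steps above can be sketched as follows (a minimal, untested repro sketch: the model name, save path, and a GPU environment with `bitsandbytes` installed are assumptions, not part of the original report):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Step 1: load in 4-bit and save the quantized weights
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",              # assumed placeholder model
    quantization_config=quant_config,
    device_map="auto",
)
model.save_pretrained("opt-125m-4bit")

# Step 2: try to load the saved 4-bit checkpoint directly
reloaded = AutoModelForCausalLM.from_pretrained(
    "opt-125m-4bit",
    quantization_config=quant_config,
    device_map="auto",
)
```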
Expected behavior
The uint8 quantized weights should be loaded directly.