Closed: matthewdouglas closed this 2 months ago
I was just talking to @matthewdouglas about this via PM. I think this probably still needs another iteration. My understanding is that the quant_storage dtype is actually supported for serialization in BNB, so we gotta take this into account:
In [1]: import torch
   ...: import bitsandbytes as bnb
   ...: import io
   ...:
   ...: def save_and_load_model(model):
   ...:     buffer = io.BytesIO()
   ...:     torch.save(model.state_dict(), buffer)
   ...:     buffer.seek(0)
   ...:     loaded_state_dict = torch.load(buffer)
   ...:     return loaded_state_dict
   ...:
   ...: def create_linear4bit(in_features, out_features, quant_storage):
   ...:     layer = bnb.nn.Linear4bit(in_features, out_features, quant_storage=quant_storage)
   ...:     layer.weight.data.normal_(0, 1)
   ...:     return layer.cuda()  # Move to CUDA to trigger quantization
   ...:
   ...: def check_quant_storage(state_dict):
   ...:     weight_state = {k: v for k, v in state_dict.items() if k.startswith('weight')}
   ...:     print("Quantization state keys:", weight_state.keys())
   ...:     for k, v in weight_state.items():
   ...:         if isinstance(v, torch.Tensor):
   ...:             print(f"{k} dtype: {v.dtype}")
   ...:
In [2]: print("Testing with float16 quant_storage:")
   ...: model_float16 = create_linear4bit(10, 20, quant_storage=torch.float16)
   ...: loaded_state_dict = save_and_load_model(model_float16)
   ...: check_quant_storage(loaded_state_dict)
   ...: print()
   ...:
   ...: print("Testing with bfloat16 quant_storage:")
   ...: model_bfloat16 = create_linear4bit(10, 20, quant_storage=torch.bfloat16)
   ...: loaded_state_dict = save_and_load_model(model_bfloat16)
   ...: check_quant_storage(loaded_state_dict)
   ...: print()
   ...:
   ...: print("Testing with float32 quant_storage:")
   ...: model_float32 = create_linear4bit(10, 20, quant_storage=torch.float32)
   ...: loaded_state_dict = save_and_load_model(model_float32)
   ...: check_quant_storage(loaded_state_dict)
   ...: print()
   ...:
   ...: print("Testing with uint8 quant_storage (default):")
   ...: model_uint8 = create_linear4bit(10, 20, quant_storage=torch.uint8)
   ...: loaded_state_dict = save_and_load_model(model_uint8)
   ...: check_quant_storage(loaded_state_dict)
Testing with float16 quant_storage:
Quantization state keys: dict_keys(['weight', 'weight.absmax', 'weight.quant_map', 'weight.nested_absmax', 'weight.nested_quant_map', 'weight.quant_state.bitsandbytes__fp4'])
weight dtype: torch.float16
weight.absmax dtype: torch.uint8
weight.quant_map dtype: torch.float32
weight.nested_absmax dtype: torch.float32
weight.nested_quant_map dtype: torch.float32
weight.quant_state.bitsandbytes__fp4 dtype: torch.uint8
Testing with bfloat16 quant_storage:
Quantization state keys: dict_keys(['weight', 'weight.absmax', 'weight.quant_map', 'weight.nested_absmax', 'weight.nested_quant_map', 'weight.quant_state.bitsandbytes__fp4'])
weight dtype: torch.bfloat16
weight.absmax dtype: torch.uint8
weight.quant_map dtype: torch.float32
weight.nested_absmax dtype: torch.float32
weight.nested_quant_map dtype: torch.float32
weight.quant_state.bitsandbytes__fp4 dtype: torch.uint8
Testing with float32 quant_storage:
Quantization state keys: dict_keys(['weight', 'weight.absmax', 'weight.quant_map', 'weight.nested_absmax', 'weight.nested_quant_map', 'weight.quant_state.bitsandbytes__fp4'])
weight dtype: torch.float32
weight.absmax dtype: torch.uint8
weight.quant_map dtype: torch.float32
weight.nested_absmax dtype: torch.float32
weight.nested_quant_map dtype: torch.float32
weight.quant_state.bitsandbytes__fp4 dtype: torch.uint8
Testing with uint8 quant_storage (default):
Quantization state keys: dict_keys(['weight', 'weight.absmax', 'weight.quant_map', 'weight.nested_absmax', 'weight.nested_quant_map', 'weight.quant_state.bitsandbytes__fp4'])
weight dtype: torch.uint8
weight.absmax dtype: torch.uint8
weight.quant_map dtype: torch.float32
weight.nested_absmax dtype: torch.float32
weight.nested_quant_map dtype: torch.float32
weight.quant_state.bitsandbytes__fp4 dtype: torch.uint8
In this example, the weight tensor (the one holding the packed quantized weights) retains the correct dtype across serialization. quant_state.bitsandbytes__fp4 is just the packed representation of the non-tensor quantization state and shouldn't cause any issues in this context, imo.
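This can be sanity-checked without bitsandbytes at all: torch.save / torch.load preserve a tensor's dtype, which is why the packed weight comes back with whatever quant_storage dtype it was stored in. A minimal CPU-only sketch:

```python
import io

import torch


def roundtrip(t: torch.Tensor) -> torch.Tensor:
    # Serialize to an in-memory buffer and load back, mirroring the
    # save_and_load_model helper in the session above.
    buffer = io.BytesIO()
    torch.save(t, buffer)
    buffer.seek(0)
    return torch.load(buffer)


# The dtype survives serialization for every storage type tested above.
for dtype in (torch.uint8, torch.float16, torch.bfloat16, torch.float32):
    assert roundtrip(torch.zeros(4, dtype=dtype)).dtype == dtype
```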
Let me know if I misunderstood anything or if you have any further concerns / questions.
I'll be sure to dig into this more tmr :)
> I was just talking to @matthewdouglas about this via PM. I think this probably still needs another iteration. My understanding is that the quant_storage dtype is actually supported for serialization in BNB, so we gotta take this into account:
You're right, so my comment about that is incorrect. What I notice is that uint8 is the default for Params4bit.__new__, so I think we should keep that the same.
Ok, after a very thorough review I have to say that this is great work. Thanks for cleaning this part of the code up with this more correct and complete logic.
Regarding serialization, it just needs this small fix to pick up the quant_storage dtype from the serialized tensor.
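A hedged sketch of what such a fix could look like (resolve_quant_storage is a hypothetical helper for illustration, not the actual patch): read the dtype off the serialized packed weight, falling back to the uint8 default from Params4bit.__new__ when it can't be determined.

```python
import torch


def resolve_quant_storage(state_dict: dict, key: str = "weight",
                          default: torch.dtype = torch.uint8) -> torch.dtype:
    # Hypothetical helper (not the actual patch): recover the quant_storage
    # dtype from a serialized Linear4bit state dict by inspecting the packed
    # weight tensor, falling back to the uint8 default used by Params4bit.
    weight = state_dict.get(key)
    if isinstance(weight, torch.Tensor):
        return weight.dtype
    return default
```

On load, the recovered dtype would then be passed along so the deserialized parameter views its packed data with the same storage type it was saved with.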
The test suite is all green despite the usual flakiness (which I double-checked by hand for everything).
Will merge this, then trigger the HF integration tests on main one more time and do the release right after. Great team work :D
Really helpful and good work. Thanks @matthewdouglas ❤️ 🤗
This change ensures we don't lose track of a non-default quant_storage option or the quantization state when moving between CPU and GPU.
cc: @Titus-von-Koeller @SunMarc
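The CPU/GPU concern boils down to a Parameter subclass carrying extra, non-tensor state through .to() and .cuda(), which a plain torch.nn.Parameter move would drop. A minimal sketch of that pattern (StatefulParam is an illustrative stand-in, not the real bnb.nn.Params4bit):

```python
import torch


class StatefulParam(torch.nn.Parameter):
    # Illustrative stand-in for a quantized parameter: a Parameter that
    # carries extra non-tensor state and preserves it across device moves.
    def __new__(cls, data, quant_state=None, requires_grad=False):
        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.quant_state = quant_state
        return self

    def to(self, *args, **kwargs):
        # A plain tensor .to() would return a tensor without quant_state;
        # re-wrap the moved data so the state travels with the parameter.
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
        moved = self.data.to(device=device, dtype=dtype, non_blocking=non_blocking)
        return StatefulParam(moved, quant_state=self.quant_state,
                             requires_grad=self.requires_grad)


# State survives a device move (CPU-to-CPU here for portability).
p = StatefulParam(torch.zeros(4, dtype=torch.uint8), quant_state={"absmax": 1.0})
moved = p.to("cpu")
assert moved.quant_state == {"absmax": 1.0}
```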