bitsandbytes-foundation / bitsandbytes

Accessible large language models via k-bit quantization for PyTorch.
https://huggingface.co/docs/bitsandbytes/main/en/index
MIT License

Fixes for quant_storage and CPU offloading #1279

Closed: matthewdouglas closed this pull request 2 months ago

matthewdouglas commented 2 months ago

This change ensures we don't lose track of a non-default quant_storage option or quantization state when moving between CPU and GPU.

cc: @Titus-von-Koeller @SunMarc
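The failure mode this PR targets can be illustrated with a minimal pure-Python sketch. ToyParams4bit and naive_to are hypothetical stand-ins, not bitsandbytes' actual Params4bit. The idea is that a device move which rebuilds the parameter from constructor defaults silently resets quant_storage to uint8 and drops the quantization state, while a move that forwards both keeps them.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyParams4bit:
    """Hypothetical stand-in for a 4-bit parameter (NOT bitsandbytes' Params4bit)."""
    data: bytes
    device: str = "cpu"
    quant_storage: str = "uint8"        # uint8 default, as discussed in the thread
    quant_state: Optional[dict] = None  # e.g. blocksize, quant_type, absmax

    def to(self, device: str) -> "ToyParams4bit":
        # Forward the storage dtype and a copy of the quantization state,
        # so neither is reset by the move.
        state = None if self.quant_state is None else dict(self.quant_state)
        return ToyParams4bit(self.data, device, self.quant_storage, state)


def naive_to(p: ToyParams4bit, device: str) -> ToyParams4bit:
    # Failure mode: only data and device are forwarded, so quant_storage
    # falls back to the default and quant_state is lost.
    return ToyParams4bit(p.data, device)


p = ToyParams4bit(b"\x12\x34", quant_storage="bfloat16",
                  quant_state={"blocksize": 64, "quant_type": "fp4"})
moved = p.to("cuda").to("cpu")
print(moved.quant_storage, moved.quant_state)  # bfloat16 {'blocksize': 64, 'quant_type': 'fp4'}
lost = naive_to(p, "cuda")
print(lost.quant_storage, lost.quant_state)    # uint8 None
```

Forwarding the state explicitly keeps a CPU to GPU to CPU round trip lossless no matter how many moves are chained.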

Titus-von-Koeller commented 2 months ago

I was just talking to @matthewdouglas about this via PM. I think this probably still needs another iteration. My understanding is that the quant_storage dtype is actually supported for serialization in BNB, so we gotta take this into account:

import torch
import bitsandbytes as bnb
import io

def save_and_load_model(model):
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    buffer.seek(0)
    loaded_state_dict = torch.load(buffer)
    return loaded_state_dict

def create_linear4bit(in_features, out_features, quant_storage):
    layer = bnb.nn.Linear4bit(in_features, out_features, quant_storage=quant_storage)
    layer.weight.data.normal_(0, 1)
    return layer.cuda()  # Move to CUDA to trigger quantization

def check_quant_storage(state_dict):
    weight_state = {k: v for k, v in state_dict.items() if k.startswith('weight')}
    print("Quantization state keys:", weight_state.keys())
    for k, v in weight_state.items():
        if isinstance(v, torch.Tensor):
            print(f"{k} dtype: {v.dtype}")

print("Testing with float16 quant_storage:")
model_float16 = create_linear4bit(10, 20, quant_storage=torch.float16)
loaded_state_dict = save_and_load_model(model_float16)
check_quant_storage(loaded_state_dict)
print()

print("Testing with bfloat16 quant_storage:")
model_bfloat16 = create_linear4bit(10, 20, quant_storage=torch.bfloat16)
loaded_state_dict = save_and_load_model(model_bfloat16)
check_quant_storage(loaded_state_dict)
print()

print("Testing with float32 quant_storage:")
model_float32 = create_linear4bit(10, 20, quant_storage=torch.float32)
loaded_state_dict = save_and_load_model(model_float32)
check_quant_storage(loaded_state_dict)
print()

print("Testing with uint8 quant_storage (default):")
model_uint8 = create_linear4bit(10, 20, quant_storage=torch.uint8)
loaded_state_dict = save_and_load_model(model_uint8)
check_quant_storage(loaded_state_dict)

Output:
Testing with float16 quant_storage:
Quantization state keys: dict_keys(['weight', 'weight.absmax', 'weight.quant_map', 'weight.nested_absmax', 'weight.nested_quant_map', 'weight.quant_state.bitsandbytes__fp4'])
weight dtype: torch.float16
weight.absmax dtype: torch.uint8
weight.quant_map dtype: torch.float32
weight.nested_absmax dtype: torch.float32
weight.nested_quant_map dtype: torch.float32
weight.quant_state.bitsandbytes__fp4 dtype: torch.uint8

Testing with bfloat16 quant_storage:
Quantization state keys: dict_keys(['weight', 'weight.absmax', 'weight.quant_map', 'weight.nested_absmax', 'weight.nested_quant_map', 'weight.quant_state.bitsandbytes__fp4'])
weight dtype: torch.bfloat16
weight.absmax dtype: torch.uint8
weight.quant_map dtype: torch.float32
weight.nested_absmax dtype: torch.float32
weight.nested_quant_map dtype: torch.float32
weight.quant_state.bitsandbytes__fp4 dtype: torch.uint8

Testing with float32 quant_storage:
Quantization state keys: dict_keys(['weight', 'weight.absmax', 'weight.quant_map', 'weight.nested_absmax', 'weight.nested_quant_map', 'weight.quant_state.bitsandbytes__fp4'])
weight dtype: torch.float32
weight.absmax dtype: torch.uint8
weight.quant_map dtype: torch.float32
weight.nested_absmax dtype: torch.float32
weight.nested_quant_map dtype: torch.float32
weight.quant_state.bitsandbytes__fp4 dtype: torch.uint8

Testing with uint8 quant_storage (default):
Quantization state keys: dict_keys(['weight', 'weight.absmax', 'weight.quant_map', 'weight.nested_absmax', 'weight.nested_quant_map', 'weight.quant_state.bitsandbytes__fp4'])
weight dtype: torch.uint8
weight.absmax dtype: torch.uint8
weight.quant_map dtype: torch.float32
weight.nested_absmax dtype: torch.float32
weight.nested_quant_map dtype: torch.float32
weight.quant_state.bitsandbytes__fp4 dtype: torch.uint8

In this example, the weights keep the correct dtype (the quant_storage dtype that holds the packed quantized weights) through serialization. The weight.quant_state.bitsandbytes__fp4 entry is just the packed representation of the non-tensor quantization state, so it shouldn't cause any issues in this context, IMO.
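The point about the weight dtype can be illustrated with a small stdlib-only sketch: the quant_storage dtype acts only as a container for packed 4-bit codes, so viewing the same bytes as uint8, a 16-bit type, or float32 changes the element count but not the data. The nibble order, the zero padding, and the use of Python's array module are assumptions for illustration, not bitsandbytes' actual layout. The array module has no float16/bfloat16 typecode, so "H" stands in for the 16-bit case.

```python
from array import array


def pack_nibbles(codes):
    """Pack 4-bit codes (0..15) two per byte, high nibble first (assumed order)."""
    codes = list(codes)
    if len(codes) % 2:
        codes.append(0)  # pad odd-length input with a zero code
    return bytes((codes[i] << 4) | codes[i + 1] for i in range(0, len(codes), 2))


def unpack_nibbles(packed, n):
    """Inverse of pack_nibbles: recover the first n 4-bit codes."""
    out = []
    for b in packed:
        out.extend((b >> 4, b & 0x0F))
    return out[:n]


codes = [i % 16 for i in range(10 * 20)]  # 200 weights, like the 10x20 layer in the snippet
packed = pack_nibbles(codes)              # 100 bytes

# The same 100 bytes viewed through containers of different widths.
as_uint8 = array("B", packed)    # 1-byte elements -> 100 elements
as_uint16 = array("H", packed)   # 2-byte elements -> 50 (stand-in for float16/bfloat16)
as_float32 = array("f", packed)  # 4-byte elements -> 25 elements

print(len(as_uint8), len(as_uint16), len(as_float32))  # 100 50 25
# No numeric conversion happens, so the packed codes survive any container.
assert unpack_nibbles(as_float32.tobytes(), len(codes)) == codes
```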

Let me know if I misunderstood anything or if you have any further concerns or questions.

I'll be sure to dig into this more tomorrow :)

matthewdouglas commented 2 months ago

You're right, so my comment about that was incorrect. What I did notice is that uint8 is the default quant_storage in Params4bit.__new__, so I think we should keep that default as it is.

Titus-von-Koeller commented 2 months ago

OK, after a very thorough review I have to say that this is great work. Thanks for cleaning up this part of the code with more correct and complete logic.

Regarding serialization, it just needed this small fix to pick up the quant_storage dtype from the serialized weight tensor.
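The idea behind this fix can be sketched in miniature: when loading, read the storage dtype off the serialized weight rather than falling back to the default. This stdlib-only sketch uses pickle as a loose analogue of torch.save/torch.load (which are pickle-based) and Python arrays in place of tensors. save_state, load_with_storage, and the typecode mapping are hypothetical names, not the code changed in this PR.

```python
import io
import pickle
from array import array

# Hypothetical mapping from array typecodes to quant_storage names.
TYPECODE_TO_STORAGE = {"B": "uint8", "f": "float32"}


def save_state(weight: array) -> bytes:
    # Rough analogue of torch.save(state_dict): pickle a dict holding the packed weight.
    buf = io.BytesIO()
    pickle.dump({"weight": weight}, buf)
    return buf.getvalue()


def load_with_storage(blob: bytes):
    state = pickle.loads(blob)
    weight = state["weight"]
    # Derive quant_storage from the serialized tensor itself instead of
    # assuming the uint8 constructor default; unknown typecodes raise KeyError.
    return weight, TYPECODE_TO_STORAGE[weight.typecode]


packed = bytes(range(16))  # 16 bytes of packed 4-bit codes
weight, storage = load_with_storage(save_state(array("f", packed)))
print(storage)  # float32
assert weight.tobytes() == packed
```

Raising on an unknown container, rather than silently defaulting to uint8, makes a mismatch visible at load time.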

The test suite is all green apart from the usual flaky tests, which I double-checked by hand.

I'll merge this, trigger the HF integration tests on main one more time, and do the release right after. Great teamwork :D

Really helpful and good work. Thanks @matthewdouglas ❤️ 🤗