bitsandbytes-foundation / bitsandbytes

Accessible large language models via k-bit quantization for PyTorch.
https://huggingface.co/docs/bitsandbytes/main/en/index
MIT License

Pretrained Causal LM cannot be loaded in 4bit/8bit #1331

Open adrienchaton opened 3 weeks ago

adrienchaton commented 3 weeks ago

System Info

Name                 Version        Build                          Channel
transformers         4.41.0.dev0    pypi_0                         pypi
pytorch              2.1.2          py3.9_cuda12.1_cudnn8.9.2_0    pytorch
pytorch-cuda         12.1           ha16c6d3_5                     pytorch
pytorch-lightning    2.3.0          pyhd8ed1ab_0                   conda-forge
pytorch-mutex        1.0            cuda                           pytorch
bitsandbytes         0.43.3         pypi_0                         pypi

Reproduction

import torch
from transformers import AutoConfig, AutoModelForCausalLM

evo_ckpt="evo-1-8k-base"
config = AutoConfig.from_pretrained(f"togethercomputer/{evo_ckpt}", trust_remote_code=True, revision="1.1_fix")
config.use_cache = False

model = AutoModelForCausalLM.from_pretrained(f"togethercomputer/{evo_ckpt}", config=config, trust_remote_code=True, revision="1.1_fix", torch_dtype=torch.bfloat16)
# the model loads successfully here

model = AutoModelForCausalLM.from_pretrained(f"togethercomputer/{evo_ckpt}", config=config, trust_remote_code=True, revision="1.1_fix", torch_dtype=torch.bfloat16, load_in_8bit=True)
# this fails with the traceback below

model = AutoModelForCausalLM.from_pretrained(f"togethercomputer/{evo_ckpt}", config=config, trust_remote_code=True, revision="1.1_fix", torch_dtype=torch.bfloat16, load_in_4bit=True)
# this fails with the same traceback

For both 4-bit and 8-bit we get the same error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 558, in from_pretrained
    return model_class.from_pretrained(
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/site-packages/transformers/modeling_utils.py", line 3572, in from_pretrained
    hf_quantizer.preprocess_model(
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/site-packages/transformers/quantizers/base.py", line 182, in preprocess_model
    return self._process_model_before_weight_loading(model, **kwargs)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/site-packages/transformers/quantizers/quantizer_bnb_8bit.py", line 237, in _process_model_before_weight_loading
    self.modules_to_not_convert = get_keys_to_not_convert(model)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/site-packages/transformers/integrations/bitsandbytes.py", line 290, in get_keys_to_not_convert
    tied_model = deepcopy(model)  # this has 0 cost since it is done inside `init_empty_weights` context manager`
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 296, in _reconstruct
    value = deepcopy(value, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 296, in _reconstruct
    value = deepcopy(value, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 296, in _reconstruct
    value = deepcopy(value, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 296, in _reconstruct
    value = deepcopy(value, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/wuppertal/gnlzm/mambaforge/envs/bm-llms-minimal/lib/python3.9/copy.py", line 161, in deepcopy
    rv = reductor(4)
TypeError: 'NoneType' object is not callable

Expected behavior

Hello and thanks for providing this great library for quantizing HF LLMs!

I am successfully finetuning the model with LoRA; now I am looking into QLoRA in order to reduce the memory footprint and allow finetuning with a larger context length. From the documentation, one must first load and quantize the model, and then set up PEFT and train.

However, when I try to load this model with quantization, it throws the error whose trace I attached above. It is the same whether I try 4-bit or 8-bit, and whether I configure quantization directly in the from_pretrained call or pass a quantization_config.

Do you have any idea whether it is possible to quantize this model? And if so, how can I get around this issue when loading in quantized mode?

Thanks in advance

matthewdouglas commented 3 weeks ago

I would say this issue most likely belongs in huggingface/transformers, or perhaps in the model code itself, which is where dotdict is used and is what breaks the deepcopy here.

With that said, I would suggest using a quantization_config here and specifying llm_int8_skip_modules by hand. In particular, you will probably want at least the following:

llm_int8_skip_modules=["embedding_layer", "poles", "residues"]
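
A minimal sketch of what that could look like (8-bit shown; the 4-bit case is analogous, and the exact module list may need tuning):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# keep the listed StripedHyena modules un-quantized
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["embedding_layer", "poles", "residues"],
)
model = AutoModelForCausalLM.from_pretrained(
    "togethercomputer/evo-1-8k-base",
    trust_remote_code=True,
    revision="1.1_fix",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)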

cc: @SunMarc in case he has any other advice here!

Hope this helps!

adrienchaton commented 3 weeks ago

Hi @matthewdouglas thanks a lot for your quick help!

I was omitting the quantization_config to keep the code snippet as simple as possible, but the issue was the same. Now that you have advised configuring llm_int8_skip_modules, the model loads, to my surprise in both 8-bit and 4-bit. I still have to run some trainings to see whether things run through and behave as expected, but this is already a great step forward; I will keep you posted on how it goes.

Right now, this is how I load in 8-bit and 4-bit:

import torch
from transformers import AutoConfig, AutoModelForCausalLM
from transformers import BitsAndBytesConfig

evo_ckpt="evo-1-8k-base"
config = AutoConfig.from_pretrained(f"togethercomputer/{evo_ckpt}", trust_remote_code=True, revision="1.1_fix")
config.use_cache = False

# loading in 8bit
bnb_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["embedding_layer", "poles", "residues"])
model = AutoModelForCausalLM.from_pretrained(f"togethercomputer/{evo_ckpt}", config=config, trust_remote_code=True, revision="1.1_fix", quantization_config=bnb_config, torch_dtype=torch.bfloat16)
# warning: "Some weights of StripedHyenaModelForCausalLM were not initialized from the model checkpoint at togethercomputer/evo-1-8k-base and are newly initialized: ['backbone.unembed.weight']
# You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
# ... since model.config.tie_embeddings is True, we could override this init with model.backbone.unembed = model.backbone.embedding_layer
# https://github.com/togethercomputer/stripedhyena/blob/main/stripedhyena/model.py#L342
model.backbone.unembed = model.backbone.embedding_layer

# loading in 4bit with parameters recommended for QLoRA
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_storage=torch.bfloat16, llm_int8_skip_modules=["embedding_layer", "poles", "residues"])
model = AutoModelForCausalLM.from_pretrained(f"togethercomputer/{evo_ckpt}", config=config, trust_remote_code=True, revision="1.1_fix", quantization_config=bnb_config, torch_dtype=torch.bfloat16)
model.backbone.unembed = model.backbone.embedding_layer

I tried to add "unembed" to llm_int8_skip_modules, but the init would still alter the pretrained weights from which I want to finetune ... I am not yet familiar with your library; is there a better way to work around the re-init of the unembed weights?

matthewdouglas commented 3 weeks ago

I would have expected tie_weights() to take care of that, but maybe it's not implemented in this model? I'm not very familiar with StripedHyena; this is probably a good question to ask on their repo.

PreTrainedModel does have a _tie_or_clone_weights(output_embeddings, input_embeddings) that you can try too, but in general I think you have the right idea.
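
For instance, a minimal sketch of that manual tie, reusing the module names from the snippet earlier in this thread:

# re-tie the output projection to the input embeddings by hand (sketch)
model._tie_or_clone_weights(model.backbone.unembed, model.backbone.embedding_layer)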

SunMarc commented 3 weeks ago

Hey @adrienchaton, can you share the model architecture after loading with bnb? It's strange that, despite adding "unembed" to llm_int8_skip_modules, it still gets altered. Also, this is remote code, so we don't have control over it. I see that they are missing a get_output_embeddings function in their code, which is required if one sets tie_word_embeddings = True. This is probably the reason why the weights are being re-initialized for unembed.

def tie_weights(self):
    if getattr(self.config, "tie_word_embeddings", True):
        # right now, this is returning `None`
        output_embeddings = self.get_output_embeddings()
        if output_embeddings is not None:
            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
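
A hypothetical sketch of the missing accessor, as it might be added to the remote StripedHyenaModelForCausalLM class (assuming backbone.unembed is the output projection):

# hypothetical addition to the remote modeling code, not part of the current checkpoint
def get_output_embeddings(self):
    return self.backbone.unembed
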
matthewdouglas commented 2 weeks ago

@SunMarc raises a good point about get_output_embeddings lacking an implementation. The other thing I find notable in this model's configuration is that the attribute is named tie_embeddings and not tie_word_embeddings, which would also cause the behavior you're seeing.

adrienchaton commented 2 weeks ago

Hi @SunMarc @matthewdouglas and thanks a lot for looking into this.

Working with a custom HF model that doesn't implement all the methods and conventions the bitsandbytes integration expects is indeed making things quite difficult ...

To answer your question, if I set llm_int8_skip_modules=["embedding_layer", "poles", "residues", "unembed"], then the 4-bit model loads as:

StripedHyenaModelForCausalLM(
  (backbone): StripedHyena(
    (embedding_layer): VocabParallelEmbedding(512, 4096)
    (norm): RMSNorm()
    (unembed): VocabParallelEmbedding(512, 4096)
    (blocks): ModuleList(
      (0-7): 8 x ParallelGatedConvBlock(
        (pre_norm): RMSNorm()
        (post_norm): RMSNorm()
        (filter): ParallelHyenaFilter()
        (projections): Linear4bit(in_features=4096, out_features=12288, bias=True)
        (out_filter_dense): Linear4bit(in_features=4096, out_features=4096, bias=True)
        (mlp): ParallelGatedMLP(
          (l1): Linear4bit(in_features=4096, out_features=10928, bias=False)
          (l2): Linear4bit(in_features=4096, out_features=10928, bias=False)
          (l3): Linear4bit(in_features=10928, out_features=4096, bias=False)
        )
      )
      (8): AttentionBlock(
        (pre_norm): RMSNorm()
        (post_norm): RMSNorm()
        (inner_mha_cls): MHA(
          (rotary_emb): RotaryEmbedding()
          (Wqkv): Linear4bit(in_features=4096, out_features=12288, bias=True)
          (inner_attn): FlashSelfAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
          (inner_cross_attn): FlashCrossAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
          (out_proj): Linear4bit(in_features=4096, out_features=4096, bias=True)
        )
        (mlp): ParallelGatedMLP(
          (l1): Linear4bit(in_features=4096, out_features=10928, bias=False)
          (l2): Linear4bit(in_features=4096, out_features=10928, bias=False)
          (l3): Linear4bit(in_features=10928, out_features=4096, bias=False)
        )
      )
      (9-15): 7 x ParallelGatedConvBlock(
        (pre_norm): RMSNorm()
        (post_norm): RMSNorm()
        (filter): ParallelHyenaFilter()
        (projections): Linear4bit(in_features=4096, out_features=12288, bias=True)
        (out_filter_dense): Linear4bit(in_features=4096, out_features=4096, bias=True)
        (mlp): ParallelGatedMLP(
          (l1): Linear4bit(in_features=4096, out_features=10928, bias=False)
          (l2): Linear4bit(in_features=4096, out_features=10928, bias=False)
          (l3): Linear4bit(in_features=10928, out_features=4096, bias=False)
        )
      )
      (16): AttentionBlock(
        (pre_norm): RMSNorm()
        (post_norm): RMSNorm()
        (inner_mha_cls): MHA(
          (rotary_emb): RotaryEmbedding()
          (Wqkv): Linear4bit(in_features=4096, out_features=12288, bias=True)
          (inner_attn): FlashSelfAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
          (inner_cross_attn): FlashCrossAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
          (out_proj): Linear4bit(in_features=4096, out_features=4096, bias=True)
        )
        (mlp): ParallelGatedMLP(
          (l1): Linear4bit(in_features=4096, out_features=10928, bias=False)
          (l2): Linear4bit(in_features=4096, out_features=10928, bias=False)
          (l3): Linear4bit(in_features=10928, out_features=4096, bias=False)
        )
      )
      (17-23): 7 x ParallelGatedConvBlock(
        (pre_norm): RMSNorm()
        (post_norm): RMSNorm()
        (filter): ParallelHyenaFilter()
        (projections): Linear4bit(in_features=4096, out_features=12288, bias=True)
        (out_filter_dense): Linear4bit(in_features=4096, out_features=4096, bias=True)
        (mlp): ParallelGatedMLP(
          (l1): Linear4bit(in_features=4096, out_features=10928, bias=False)
          (l2): Linear4bit(in_features=4096, out_features=10928, bias=False)
          (l3): Linear4bit(in_features=10928, out_features=4096, bias=False)
        )
      )
      (24): AttentionBlock(
        (pre_norm): RMSNorm()
        (post_norm): RMSNorm()
        (inner_mha_cls): MHA(
          (rotary_emb): RotaryEmbedding()
          (Wqkv): Linear4bit(in_features=4096, out_features=12288, bias=True)
          (inner_attn): FlashSelfAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
          (inner_cross_attn): FlashCrossAttention(
            (drop): Dropout(p=0.0, inplace=False)
          )
          (out_proj): Linear4bit(in_features=4096, out_features=4096, bias=True)
        )
        (mlp): ParallelGatedMLP(
          (l1): Linear4bit(in_features=4096, out_features=10928, bias=False)
          (l2): Linear4bit(in_features=4096, out_features=10928, bias=False)
          (l3): Linear4bit(in_features=10928, out_features=4096, bias=False)
        )
      )
      (25-31): 7 x ParallelGatedConvBlock(
        (pre_norm): RMSNorm()
        (post_norm): RMSNorm()
        (filter): ParallelHyenaFilter()
        (projections): Linear4bit(in_features=4096, out_features=12288, bias=True)
        (out_filter_dense): Linear4bit(in_features=4096, out_features=4096, bias=True)
        (mlp): ParallelGatedMLP(
          (l1): Linear4bit(in_features=4096, out_features=10928, bias=False)
          (l2): Linear4bit(in_features=4096, out_features=10928, bias=False)
          (l3): Linear4bit(in_features=10928, out_features=4096, bias=False)
        )
      )
    )
  )
)

and, as said before, it gives the warning about the unembed parameter being newly initialized.

However, during the forward pass I get an error from the FlashAttention library:

flash_attn/modules/mha.py", line 100, in forward
    assert qkv.dtype in [torch.float16, torch.bfloat16]

I can fix this by adding Wqkv to llm_int8_skip_modules, which then loads the attention block as:

(8): AttentionBlock(
  (pre_norm): RMSNorm()
  (post_norm): RMSNorm()
  (inner_mha_cls): MHA(
    (rotary_emb): RotaryEmbedding()
    (Wqkv): Linear(in_features=4096, out_features=12288, bias=True)
    (inner_attn): FlashSelfAttention(
      (drop): Dropout(p=0.0, inplace=False)
    )
    (inner_cross_attn): FlashCrossAttention(
      (drop): Dropout(p=0.0, inplace=False)
    )
    (out_proj): Linear4bit(in_features=4096, out_features=4096, bias=True)
  )
  (mlp): ParallelGatedMLP(
    (l1): Linear4bit(in_features=4096, out_features=10928, bias=False)
    (l2): Linear4bit(in_features=4096, out_features=10928, bias=False)
    (l3): Linear4bit(in_features=10928, out_features=4096, bias=False)
  )
)

With this setting the 4bit model can run both forward and backward.
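
For completeness, this is roughly the 4-bit config I am using now, i.e. the one from above with Wqkv added to the skip list:

# 4-bit QLoRA config; the listed modules are kept in bf16 instead of being quantized
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,
    llm_int8_skip_modules=["embedding_layer", "poles", "residues", "unembed", "Wqkv"],
)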

If the model is loaded in 8-bit, I instead get an error about mismatched sizes during the backward pass; FYI it says:

bitsandbytes/autograd/_functions.py", line 481, in backward
    grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
RuntimeError: shape '[1, 1000, 10928]' is invalid for input of size 10944000

Here 1000 is the sequence length used for debugging and 10928 is an output size in the MLPs; note that 10,944,000 / 1000 = 10,944, i.e. 16 columns more than the expected 10,928. So it seems the gradient dimensions don't fit the model architecture when it is loaded in 8-bit.

Ideally I was more interested in running in 4-bit than in 8-bit (i.e. trying for the largest reduction in memory footprint), so I didn't try to solve this backward issue in 8-bit. But with batch size 1 on an A100-80GB GPU, I get OOM above MSL=3000 tokens for both the unquantized LoRA and the 4-bit LoRA ... I guess the limited compatibility of this custom architecture with BnB reduces the gains we can get from quantization ...

Also, when I run PEFT LoRA without quantization, I merge the adapters into the base model before saving, so I can reload the finetuned model the same way I load the original model. After a test run with LoRA and 4-bit quantization, I got an error saying that the dequantize method isn't implemented. And if I don't dequantize, the saved parameters do not match the sizes expected by the base model (and if I load with ignore_mismatched_sizes=True, many parameters are left randomly initialized).

Regardless of the open questions about the memory gains in 4-bit and about saving/loading, I will try a longer finetuning to see whether LoRA in BF16 and LoRA in 4-bit converge similarly. But for my understanding, I would like to ask what the recommended way is to handle checkpointing with PEFT 4-bit models.

Is the flow below correct, or should I avoid trying to dequantize the model?
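
Roughly, the flow I mean is the following sketch (placeholder paths, with the base model loaded in 4-bit as above):

from peft import PeftModel

# load the trained adapters on top of the 4-bit base model (placeholder path)
peft_model = PeftModel.from_pretrained(model, "path/to/lora_adapters")
# merging back into the base weights is presumably where the missing `dequantize` error comes from
merged = peft_model.merge_and_unload()
merged.save_pretrained("path/to/merged_checkpoint")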

Thanks for looking into this issue, even though most of it is due to working with a custom model that is beyond the scope of your library ...