huggingface / peft

🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning.
https://huggingface.co/docs/peft
Apache License 2.0

Lora initialisation with olora and pissa not working with quantisation. #1999

Closed Haakooto closed 1 week ago

Haakooto commented 1 month ago

System Info

transformers version: 4.43.2
peft version: 0.12.0
Platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.35
Python version: 3.12.4
Huggingface_hub version: 0.24.2
Safetensors version: 0.4.3
Accelerate version: 0.33.0
Accelerate config: not found
PyTorch version (GPU?): 2.4.0+cu121 (True)
Tensorflow version (GPU?): not installed (NA)
Flax version (CPU?/GPU?/TPU?): not installed (NA)
Jax version: not installed
JaxLib version: not installed
Using distributed or parallel set-up in script?: no
Using GPU in script?: yes
GPU type: NVIDIA H100 PCIe

Who can help?

@BenjaminBossan @sayakpaul

Information

Tasks

Reproduction

Problem arises when using "olora" or "pissa" as the adapter initialisation while loading the model in 4-bit.

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from peft import LoraConfig, TaskType, get_peft_model

def main(quantize: bool, lora_init: str):
    device = torch.device("cuda:0")
    model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False)
    tokenizer.pad_token = tokenizer.eos_token

    # 4-bit NF4 quantisation config; note bnb_4bit_quant_storage=torch.bfloat16
    q4bit_conf = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            quantization_config=q4bit_conf if quantize else None,
            device_map=device
        )

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules="all-linear",
        init_lora_weights=lora_init,
    )

    # the olora/pissa initialisation of the LoRA weights runs here
    model = get_peft_model(model, peft_config)

    d = model.device
    x = tokenizer("hello world", return_tensors="pt").to(d)["input_ids"]
    model(x)  # * here comes the crash

if __name__ == "__main__":
    quantize = True
    lora_init = "olora"
    main(quantize, lora_init)

Expected behavior

RuntimeError: mat1 and mat2 shapes cannot be multiplied (3x2048 and 1x1)

Traceback (not full traceback):

    File "/home/local/NTU/hator/miniconda3/envs/py12/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 438, in forward
      query_states = self.q_proj(hidden_states)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/local/NTU/hator/miniconda3/envs/py12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/local/NTU/hator/miniconda3/envs/py12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/local/NTU/hator/miniconda3/envs/py12/lib/python3.12/site-packages/peft/tuners/lora/bnb.py", line 489, in forward
      output = lora_B(lora_A(dropout(x))) * scaling
                      ^^^^^^^^^^^^^^^^^^
    File "/home/local/NTU/hator/miniconda3/envs/py12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/local/NTU/hator/miniconda3/envs/py12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
      return forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/local/NTU/hator/miniconda3/envs/py12/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 117, in forward
      return F.linear(input, self.weight, self.bias)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x2048 and 1x1)

While I have not read the code to make sure, I suspect the issue can be explained as follows: the quantisation happens before the LoRA adapter is applied, and it flattens the original weight matrices. When the LoRA adapter is initialised with random entries, this flattening causes no problem, as (I suspect) only the original shape is used to create it. However, olora and pissa initialisation compute the QR and SVD decompositions of the original matrix, and the flattening does not seem to be respected, so the decomposition is calculated from the (N x 1) tensor. The resulting lora_A matrix is a (1x1) tensor, while the base tensor is filled with NaN. This causes no issues at initialisation; the code only crashes once an input is passed to the model.
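
To make the suspicion concrete, here is a rough torch-only sketch (not PEFT's actual code; the shapes and r=8, the LoraConfig default, are just for illustration) of what I think happens when the QR decomposition is computed on the packed storage tensor instead of the dequantised 2-D weight:

import torch

out_features, in_features, r = 2048, 2048, 8

# What the olora init should see: the dequantised 2-D weight matrix.
W = torch.randn(out_features, in_features)
Q, R = torch.linalg.qr(W)
print(Q[:, :r].shape, R[:r, :].shape)  # (2048, 8) and (8, 2048)

# What it presumably gets instead: the flat packed storage tensor
# (assuming bfloat16 storage packs four 4-bit values per element).
W_packed = torch.randn(out_features * in_features // 4, 1)
Qp, Rp = torch.linalg.qr(W_packed)
print(Qp[:, :r].shape, Rp[:r, :].shape)  # (1048576, 1) and (1, 1), consistent with the (1x1) lora_A in my traceback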

BenjaminBossan commented 1 month ago

Thanks a lot for reporting. OLoRA should work with quantization, but as your reproducer shows, there is something amiss.

Interestingly, when using the OLoRA example script (which is not all that different), it can work with quantization, but apparently not with all models. E.g.:

  File "/home/name/work/forks/peft/examples/olora_finetuning/olora_finetuning.py", line 166, in <module>
    train(
  File "/home/name/work/forks/peft/examples/olora_finetuning/olora_finetuning.py", line 130, in train
    trainer.train()
  File "/home/name/work/clones/transformers/src/transformers/trainer.py", line 1945, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/name/work/clones/transformers/src/transformers/trainer.py", line 2286, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[...]

  File "/home/name/work/forks/peft/src/peft/tuners/lora/bnb.py", line 467, in forward
    result = self.base_layer(x, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/accelerate/hooks.py", line 169, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/bitsandbytes/nn/modules.py", line 477, in forward
    out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py", line 579, in matmul_4bit
    return MatMul4Bit.apply(A, B, out, bias, quant_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/name/anaconda3/envs/peft/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py", line 509, in forward
    output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x2048 and 256x2048)

I haven't had time to investigate yet; I will probably have time next week. Pinging @tokenizer-decode

tokenizer-decode commented 1 month ago

Weird. We don't do anything model-specific; we rely on PEFT's quantization. It should be related to that.

BenjaminBossan commented 1 month ago

I did some further research, though unfortunately I still haven't found the solution.

One difference I spotted between the scripts is the use of bnb_4bit_quant_storage, but even with it removed, there is still an error.

Digging deeper, it appears that something goes wrong during the dequantization step:

https://github.com/huggingface/peft/blob/41c274ecac6247a63e3b10f8536904f3ead82213/src/peft/utils/integrations.py#L89

The output is a 2048x2048 shaped float tensor. However, I think it should be a float tensor of shape 2097152x1 because bnb uses flat tensors (and 2097152 == 2048**2 / 2). For reference, I checked the LoftQ implementation and there I get exactly this flat shape. It is unclear to me why the shapes are different, even though both code paths call bnb.functional.dequantize_4bit(weight.data, weight.quant_state). I checked the quant_state too but AFAICT it's the same between the two.

Anyway, because of the wrong shape, we assign an incorrectly shaped (and all zeros) tensor as the new bnb weight here:

https://github.com/huggingface/peft/blob/41c274ecac6247a63e3b10f8536904f3ead82213/src/peft/tuners/lora/layer.py#L189

which is most likely the reason for the shape error later during forward.
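
For reference, the shape check I did looks roughly like this (a sketch, assuming the 4-bit TinyLlama from the reproducer above is loaded as model; the layer path is just an example):

import bitsandbytes as bnb

layer = model.model.layers[0].self_attn.q_proj  # any bnb Linear4bit layer
weight = layer.weight
print(weight.data.shape)         # packed 4-bit storage, flat, e.g. (2097152, 1) with uint8 storage
print(weight.quant_state.shape)  # the original 2-D shape recorded in the quant state
dequant = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
print(dequant.shape)             # the dequantised output, (2048, 2048) here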

When I have time to further investigate, I'll let you know. If any of you have an idea what I'm missing, please let me know.

Haakooto commented 4 weeks ago

I poked around in the olora_init function in peft/src/peft/tuners/lora/layer.py. As @BenjaminBossan mentions, the difference between my script and the OLoRA example script is the bnb_4bit_quant_storage flag in the BitsAndBytesConfig. Both scripts crash because of a shape mismatch, both with bnb_4bit_quant_storage=None (the default) and with bnb_4bit_quant_storage=torch.bfloat16 (which I used because of this page).

However, the relevant shapes differ. As can be seen in the tracebacks posted by me and @BenjaminBossan, with bnb_4bit_quant_storage=torch.bfloat16, mat2 has size (1x1), while with bnb_4bit_quant_storage=None it has size (NxM). This is because olora_init does not identify the Linear4bit instance as quantised: the only check for this is the weight dtype, which here is bfloat16, so the original weights are not dequantised before the QR decomposition is computed.

While this does not shed light on the ultimate issue, this is another bug.
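
To illustrate what I mean (this is not PEFT's actual code, just a sketch of the two kinds of check): a dtype-based test misses Linear4bit when the quant storage is bfloat16, while a class-based test does not.

import torch
import bitsandbytes as bnb

def looks_quantised_by_dtype(base_layer: torch.nn.Module) -> bool:
    # Returns False for Linear4bit when bnb_4bit_quant_storage=torch.bfloat16,
    # because the packed weight is then stored as bfloat16.
    return base_layer.weight.dtype in (torch.uint8, torch.int8)

def looks_quantised_by_type(base_layer: torch.nn.Module) -> bool:
    # Independent of the storage dtype.
    return isinstance(base_layer, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))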

BenjaminBossan commented 4 weeks ago

Thanks for the further investigation, @Haakooto. I have a tentative solution in #2011; it would be great if you could test it. I consider this a WIP, as I'm uncertain whether the way the PR fixes the issue is optimal.

Haakooto commented 3 weeks ago

I finally had a chance to look at it now. This does indeed work. I noticed the solution only covers olora, not pissa, but it was straightforward to copy over the code that fixes the issue. Though if further improvements are coming, as you say, this is not urgent. Thank you!

BenjaminBossan commented 3 weeks ago

Thanks @Haakooto for testing, this is good to know. Let's focus on OLoRA for now and deal with PiSSA later.

I discovered a bug in the PR as it didn't correctly deal with 8bit bnb, but that should now also be fixed.

44670 commented 2 weeks ago

I have also run into this issue with bnb + olora.

BenjaminBossan commented 2 weeks ago

@44670 The PR is still not merged, but you could install from the https://github.com/BenjaminBossan/peft/tree/fix-olora-bnb branch if you want to test it right now.
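
E.g. something along these lines should do it (assuming pip and git are available):

pip install git+https://github.com/BenjaminBossan/peft.git@fix-olora-bnb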