artidoro / qlora

QLoRA: Efficient Finetuning of Quantized LLMs
https://arxiv.org/abs/2305.14314
MIT License

LORA Merge fails in 4-bit mode #28

Open KKcorps opened 1 year ago

KKcorps commented 1 year ago

Trained a vicuna-13b-1.1 LoRA in 4-bit mode.

Now I'm trying to merge it for running generations, but it fails with the following error:

python3.11/site-packages/peft/tuners/lora.py", line 352, in merge_and_unload
    raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit mode")
ValueError: Cannot merge LORA layers when the model is loaded in 8-bit mode
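
For reference, a minimal sketch of the failing flow (the model id and adapter path below are assumptions, not from the original report):

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "lmsys/vicuna-13b-v1.1",     # assumed base model id
    load_in_4bit=True,           # quantize with bitsandbytes, as during QLoRA training
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base, "output/checkpoint-1000/adapter_model")  # assumed path
model = model.merge_and_unload()  # raises the ValueError above when loaded in 4-bit (or 8-bit) mode
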
KKcorps commented 1 year ago

Strangely, the checkpoints in the output directory are also very small (adapter_model.bin is just 400 bytes).

However, the optimizer.pth is around 1 GB.
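
For what it's worth, a quick way to inspect what actually got saved (the checkpoint path is an assumption):

import torch

state = torch.load("output/checkpoint-1000/adapter_model.bin", map_location="cpu")
print(len(state), "tensors saved")
for name, tensor in list(state.items())[:5]:
    print(name, tuple(tensor.shape), tensor.dtype)

A file of only a few hundred bytes cannot hold any LoRA weight tensors, so the adapter was most likely saved without its weights.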

taishan1994 commented 1 year ago

> Strangely, the checkpoints in the output directory are also very small (adapter_model.bin is just 400 bytes).
>
> However, the optimizer.pth is around 1 GB.

I also encountered this problem. My adapter_model.bin is also only 443 KB.

magneter commented 1 year ago

Has the problem been solved?

santapo commented 1 year ago

+1

kaisawind commented 1 year ago

+1. But I think it has since been addressed in peft: https://github.com/huggingface/peft/blob/0e33ac1efe1564143bffd248906d9e17864017dc/src/peft/tuners/lora.py#L461-L463

eugene-yh commented 1 year ago

I have looked into the source code of HF and bitsandbytes. It seems to me that the ultimate reason this is not supported is that the under-the-hood bnb.nn.Linear4bit module is not designed to be mergeable by adding the LoRA weights: bnb.nn.Linear4bit only contains a weight matrix (which, unfortunately, is not the usual weight matrix of a torch.nn.Linear) of type torch.uint8, and it does not provide an addition operator.
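
To illustrate (a sketch; the module path assumes a LLaMA-style model wrapped in a PeftModel):

# peek at one quantized projection layer
layer = model.base_model.model.model.layers[0].self_attn.q_proj
print(type(layer))          # peft.tuners.lora.Linear4bit wrapping bnb's 4-bit storage
print(layer.weight.dtype)   # torch.uint8 -- packed 4-bit blocks, not a regular fp16 matrix
print(layer.weight.shape)   # a flattened, packed shape rather than (out_features, in_features)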

However, for those who are curious and seeking a workaround, here is my trick, and it seems to work fine. The idea comes from the fact that although there is no addition operator, bnb does provide a forward method (basically a matmul). So, using simple math, W = W·I (where I is the identity matrix): we can feed an identity matrix to bnb.nn.Linear4bit's forward method and get back the mathematically equivalent weight matrix of the linear layer.
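
As a quick sanity check of the W = W·I idea on a plain torch.nn.Linear (no quantization involved):

import torch

lin = torch.nn.Linear(8, 4, bias=False)
# forward computes y = x @ W^T, so feeding the identity matrix returns W^T
recovered = lin(torch.eye(8)).transpose(0, 1)
assert torch.allclose(recovered, lin.weight)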

Note that the above only covers a 4-bit linear layer without a bias. However, we can extend the idea to layers that do have one: use b = W·O + b (where O is a zero matrix) to get the bias b, and then W = (W·I + b) - b to recover W. For most LLMs today the bias is None, at least for LLaMA, so we usually don't need this; a rough sketch is below.
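
Here is a hedged sketch of that bias case (the main snippet below deliberately leaves it as a TODO; dequantize_linear4bit_with_bias is a hypothetical helper name):

import torch

def dequantize_linear4bit_with_bias(module, dtype=torch.float16):
    # Recover W and b from a 4-bit linear layer that has a bias, using only its forward pass.
    device = module.weight.device
    with torch.no_grad():
        # 0 @ W^T + b: a single zero input row returns the bias alone
        bias = module(torch.zeros(1, module.in_features, dtype=dtype, device=device)).squeeze(0)
        # I @ W^T + b = W^T + b: feed the identity, subtract the bias, transpose back
        eye = torch.eye(module.in_features, dtype=dtype, device=device)
        weight = (module(eye) - bias).transpose(0, 1)
    new_module = torch.nn.Linear(module.in_features, module.out_features, bias=True)
    new_module.weight = torch.nn.Parameter(weight.to("cpu"))
    new_module.bias = torch.nn.Parameter(bias.to("cpu"))
    return new_module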

Now, my code snippet is as follows (NOTE: no handling of bias). You can use it to get a mathematically equivalent non-quantized model. Then you load the generated model and use HF's merge_and_unload() to merge normally. Hope it helps!

import json
import os
import shutil

import peft
import torch
from peft.utils import _get_submodules


def dequantize_model(model, tokenizer, to='./dequantized_model', dtype=torch.float16):
    """
    'model': the peft model you loaded with qlora.
    'tokenizer': the model's corresponding HF tokenizer.
    'to': directory where the dequantized model is saved.
    """
    if os.path.exists(to):
        shutil.rmtree(to)
    os.makedirs(to, exist_ok=True)

    cls = peft.tuners.lora.Linear4bit
    base_model = model.base_model.model

    with torch.no_grad():
        for name, module in base_model.named_modules():
            if isinstance(module, cls):
                print(f"Dequantizing `{name}`...")
                if module.bias is None:
                    # so peft.tuners.lora.Linear4bit.forward behaves like bnb.nn.Linear4bit
                    module.disable_adapters = True
                    # I @ W^T gives W^T; transpose back to the usual (out_features, in_features) layout
                    dequantized_weight = module(torch.eye(module.in_features, dtype=dtype).to(module.weight.device))
                    dequantized_weight = torch.transpose(dequantized_weight, 0, 1).to("cpu")
                    new_module = torch.nn.Linear(module.in_features, module.out_features, bias=False)
                    new_module.weight = torch.nn.Parameter(dequantized_weight)
                else:
                    # TODO: handle when bias is not None
                    raise NotImplementedError

                parent, target, target_name = _get_submodules(base_model, name)
                setattr(parent, target_name, new_module)

        # a hack: set this flag to avoid HF's saving error, because HF itself
        # does not support saving a model that is registered as loaded in 4-bit.
        base_model.is_loaded_in_4bit = False

        print("Saving dequantized model...")
        base_model.save_pretrained(to)
        tokenizer.save_pretrained(to)
        with open(os.path.join(to, 'config.json'), 'r') as config:
            config_data = json.loads(config.read())
        config_data.pop("quantization_config", None)
        config_data.pop("pretraining_tp", None)
        with open(os.path.join(to, 'config.json'), 'w') as config:
            config.write(json.dumps(config_data, indent=2))
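
A usage sketch of the whole flow described above (paths and checkpoint names are assumptions):

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# 1) dequantize the 4-bit base model to plain fp16 and save it
dequantize_model(model, tokenizer, to="./dequantized_model")

# 2) reload the dequantized base model and merge the LoRA adapter normally
base = AutoModelForCausalLM.from_pretrained("./dequantized_model", torch_dtype=torch.float16)
merged = PeftModel.from_pretrained(base, "output/checkpoint-1000/adapter_model").merge_and_unload()
merged.save_pretrained("./merged_model")
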
eugene-yh commented 1 year ago

A follow-up to my previous post: interestingly, if we delete the line module.disable_adapters = True, the function becomes equivalent to dequantizing + merging the LoRA weights. This should resolve the issue completely. :)

webpolis commented 1 year ago

> A follow-up to my previous post: interestingly, if we delete the line module.disable_adapters = True, the function becomes equivalent to dequantizing + merging the LoRA weights. This should resolve the issue completely. :)

What would be the implementation with bias?