yxli2123 / LoftQ

MIT License

fake and true quantization don't match #7

Closed: BaohaoLiao closed this issue 7 months ago

BaohaoLiao commented 7 months ago

Hi,

As a debugging step, I want to check whether the fake-quantized and truly quantized models' weights have the same values. Here is how I implement it:

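import torch
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig

# Fake-quantized checkpoint, loaded without any quantization
config = AutoConfig.from_pretrained("LoftQ/Llama-2-7b-hf-bit4-rank64", trust_remote_code=False)
loftq_fp16 = AutoModelForCausalLM.from_pretrained(
    "LoftQ/Llama-2-7b-hf-bit4-rank64",
    trust_remote_code=False,
    config=config,
    token="xxx"
)

# Same checkpoint, truly quantized to 4-bit NF4 by bitsandbytes at load time.
# load_in_4bit is set only inside quantization_config: recent transformers
# versions reject passing it as a separate kwarg at the same time.
loftq_fp4 = AutoModelForCausalLM.from_pretrained(
    "LoftQ/Llama-2-7b-hf-bit4-rank64",
    config=config,
    low_cpu_mem_usage=True,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False,
        bnb_4bit_quant_type='nf4',
    ),
    token="xxx"
)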

Then I print some of the weight values with print(loftq_fp16.state_dict()['model.layers.0.self_attn.q_proj.weight']). The output is:

tensor([[-0.0062, -0.0148, -0.0022,  ...,  0.0045,  0.0017, -0.0036],
        [ 0.0142, -0.0043,  0.0028,  ..., -0.0093, -0.0114,  0.0076],
        [-0.0146,  0.0126,  0.0005,  ...,  0.0063,  0.0188, -0.0031],
        ...,
        [ 0.0013,  0.0109, -0.0003,  ...,  0.0098, -0.0298,  0.0097],
        [ 0.0256,  0.0102,  0.0032,  ..., -0.0334, -0.0156, -0.0123],
        [-0.0134, -0.0066,  0.0018,  ...,  0.0181,  0.0166, -0.0082]])

For loftq_fp4, I dequantize the 4-bit weight of the same projection:

import copy
import torch
from bitsandbytes.functional import dequantize_4bit

with torch.no_grad():
    for name, module in loftq_fp4.named_modules():
        # the frozen 4-bit backbone weight of the first layer's q_proj
        if name == "model.layers.0.self_attn.q_proj.base_layer":
            quant_state = copy.deepcopy(module.weight.quant_state)
            dtype = torch.float16
            # unpack the NF4 codes and rescale them with the absmax values stored in quant_state
            weights = dequantize_4bit(module.weight.data, quant_state=quant_state, quant_type="nf4").to(dtype)
            print(weights)

The output is:

tensor([[-0.0072, -0.0153, -0.0035,  ...,  0.0047,  0.0000, -0.0054],
        [ 0.0116,  0.0000,  0.0000,  ..., -0.0108, -0.0108,  0.0061],
        [-0.0228,  0.0199,  0.0000,  ...,  0.0096,  0.0195,  0.0000],
        ...,
        [ 0.0000,  0.0141,  0.0000,  ...,  0.0124, -0.0305,  0.0124],
        [ 0.0251,  0.0092,  0.0045,  ..., -0.0317, -0.0172, -0.0111],
        [-0.0153, -0.0072,  0.0031,  ...,  0.0188,  0.0144, -0.0079]],
       device='cuda:0', dtype=torch.float16)

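We can see that the two tensors are quite different (several entries even become exactly 0.0000 after true quantization), which means the fake quantization does not faithfully reflect the performance of the truly quantized model.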
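To quantify the gap instead of eyeballing the printed corners of each tensor, a small summary helper can be used. This is a minimal sketch; `compare_weights` and its `atol` threshold are hypothetical names introduced here, not part of LoftQ or bitsandbytes. It expects plain arrays, so torch tensors would first be converted with `.float().cpu().numpy()`.

```python
import numpy as np


def compare_weights(fake, true, atol=1e-3):
    """Summarize how far a dequantized 4-bit weight is from its full-precision counterpart."""
    fake = np.asarray(fake, dtype=np.float32)
    true = np.asarray(true, dtype=np.float32)
    diff = np.abs(fake - true)
    return {
        "max_abs_diff": float(diff.max()),
        "mean_abs_diff": float(diff.mean()),
        # fraction of entries whose difference exceeds the (hypothetical) tolerance
        "frac_mismatch": float((diff > atol).mean()),
    }
```

For the tensors above, one would call it as `compare_weights(loftq_fp16.state_dict()['model.layers.0.self_attn.q_proj.weight'].numpy(), weights.float().cpu().numpy())`; a faithful fake quantization should give differences near zero.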

yxli2123 commented 7 months ago

Thanks for pointing this out. This happens because LoftQ/Llama-2-7b-hf-bit4-rank64 was built with our self-implemented NF4 quantization, which is not exactly the same as the NF4 quantization in bitsandbytes. To fix it, please try LoftQ/Llama-2-7b-hf-4bit-64rank.

Moreover, our method is not tied to a specific quantization method. Either the self-implemented one or the one in bitsandbytes achieves on-par results.

BaohaoLiao commented 7 months ago

Thank you for this clarification.

I understand your method is not limited to any particular quantization function. However, you still use bitsandbytes as the backend for memory-efficient fine-tuning. If you use a custom quantization (like the self-implemented NF4), doesn't that introduce a mismatch, since the quantization function used for the LoRA initialization differs from the one used during fine-tuning?

Say you obtain a perfect LoRA initialization W = Q + AB, where Q = self_implemented_nf4(W). When you fine-tune with bitsandbytes, the backbone is quantized again as Q_new = bitsandbytes_nf4(Q), so W is no longer equal to Q_new + AB.
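The mismatch described above can be reproduced in a small NumPy sketch. Everything here is an illustrative assumption rather than LoftQ or bitsandbytes code: `nf4_like_codebook` builds 16 levels from normal quantiles in the spirit of NF4 (not bit-exact), and the "self-implemented" and "bitsandbytes" quantizers are modeled as the same hypothetical `blockwise_quantize` with different block sizes (64 vs. 32). By construction W = Q + AB holds exactly.

```python
import numpy as np
from statistics import NormalDist


def nf4_like_codebook(offset=0.9677):
    # 16 levels from standard-normal quantiles plus an exact zero,
    # normalized to [-1, 1]; illustrative only, not the bitsandbytes table
    ppf = NormalDist().inv_cdf
    pos = [ppf(p) for p in np.linspace(offset, 0.5, 9)[:-1]]   # 8 positive levels
    neg = [-ppf(p) for p in np.linspace(offset, 0.5, 8)[:-1]]  # 7 negative levels
    levels = np.array(sorted(pos + [0.0] + neg))
    return levels / np.abs(levels).max()


def blockwise_quantize(w, codebook, blocksize):
    # absmax-scaled blockwise quantization, returned already dequantized ("fake" quantization)
    flat = w.reshape(-1, blocksize)
    absmax = np.abs(flat).max(axis=1, keepdims=True)
    absmax[absmax == 0] = 1.0  # avoid dividing all-zero blocks by zero
    idx = np.abs(flat[..., None] / absmax[..., None] - codebook).argmin(axis=-1)
    return (codebook[idx] * absmax).reshape(w.shape)


rng = np.random.default_rng(0)
codebook = nf4_like_codebook()
W0 = rng.normal(0, 0.02, size=(64, 64))
Q = blockwise_quantize(W0, codebook, blocksize=64)  # hypothetical "self-implemented" quantizer
A = rng.normal(0, 0.01, size=(64, 8))
B = rng.normal(0, 0.01, size=(8, 64))
W = Q + A @ B  # a "perfect" initialization W = Q + AB

Q_same = blockwise_quantize(Q, codebook, blocksize=64)  # re-quantize with the same quantizer
Q_new = blockwise_quantize(Q, codebook, blocksize=32)   # re-quantize with a different one

err_same = np.linalg.norm(W - (Q_same + A @ B))
err_new = np.linalg.norm(W - (Q_new + A @ B))
print(f"same quantizer: {err_same:.2e}, different quantizer: {err_new:.2e}")
```

In this sketch, re-quantizing Q with the same quantizer reproduces it exactly (the block absmax survives because the codebook contains ±1), while a different quantizer gives Q_new != Q and a nonzero error in W - (Q_new + AB).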

BaohaoLiao commented 7 months ago

In addition, may I ask what the default number of alternating steps T is for Llama?

yxli2123 commented 7 months ago

LoftQ/Llama-2-7b-hf-4bit-64rank is quantized with the bitsandbytes method, so it has no discrepancy between true and fake quantization. The default number of alternating steps T for Llama-2 is 5.
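The alternating procedure that T counts can be sketched as follows. This is a minimal NumPy illustration, not the LoftQ implementation: `fake_quantize` is a hypothetical symmetric 4-bit absmax quantizer standing in for NF4, and `loftq_init` follows our reading of the LoftQ loop (quantize W - AB, then refit AB as the best rank-r approximation of W - Q via SVD). `num_iter=5` mirrors the default T for Llama-2 stated in the reply.

```python
import numpy as np


def fake_quantize(w, blocksize=64, levels=7):
    # illustrative stand-in for NF4: symmetric absmax quantization to integers in [-levels, levels]
    flat = w.reshape(-1, blocksize)
    scale = np.abs(flat).max(axis=1, keepdims=True) / levels
    scale[scale == 0] = 1.0
    return (np.round(flat / scale) * scale).reshape(w.shape)


def loftq_init(W, rank, num_iter=5):
    # start from A B = 0, so the first step quantizes W itself
    A = np.zeros((W.shape[0], rank))
    B = np.zeros((rank, W.shape[1]))
    errors = []
    for _ in range(num_iter):
        Q = fake_quantize(W - A @ B)  # quantize the part LoRA does not cover
        U, S, Vt = np.linalg.svd(W - Q, full_matrices=False)
        sqrt_s = np.sqrt(S[:rank])
        A = U[:, :rank] * sqrt_s         # split singular values evenly between A and B
        B = sqrt_s[:, None] * Vt[:rank]  # so that A @ B is the rank-r SVD of W - Q
        errors.append(float(np.linalg.norm(W - (Q + A @ B))))
    return Q, A, B, errors


rng = np.random.default_rng(0)
W = rng.normal(0, 0.02, size=(128, 128))
Q, A, B, errors = loftq_init(W, rank=8, num_iter=5)
baseline = float(np.linalg.norm(W - fake_quantize(W)))
print(f"quantization only: {baseline:.4f}")
print("LoftQ-style error per step:", [round(e, 4) for e in errors])
```

Already after the first step the low-rank term absorbs part of the quantization error, so the reconstruction error falls below that of plain quantization; further steps re-quantize the residual with the current AB.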