yuanzhoulvpi2017 / zero_nlp

Chinese NLP solutions (large models, data, models, training, inference)

Questions about 4/8-bit quantization and about reading the source code #172

Open wangq326 opened 3 months ago

wangq326 commented 3 months ago

I read through the transformers source code; the version in my environment is 4.39.3. In modeling_utils.py I could not find any call to replace_with_bnb_linear; what I see instead is the following:

if load_in_4bit or load_in_8bit:
    if quantization_config is not None:
        raise ValueError(
            ...
        )

    config_dict = {k: v for k, v in kwargs.items() if k in inspect.signature(BitsAndBytesConfig).parameters}
    config_dict = {**config_dict, "load_in_4bit": load_in_4bit, "load_in_8bit": load_in_8bit}
    quantization_config, kwargs = BitsAndBytesConfig.from_dict(
        config_dict=config_dict, return_unused_kwargs=True, **kwargs
    )

rather than the call to replace_with_bnb_linear described in your article. If it's convenient, could you explain whether this is a version difference or whether I misread something?
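
For reference, my understanding is that the snippet above only converts the legacy load_in_4bit / load_in_8bit arguments into a BitsAndBytesConfig, so the explicit form should be roughly the following (the model id is just a placeholder):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Passing the config explicitly should be equivalent to the legacy
# load_in_4bit=True shortcut that the snippet above converts internally.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-7b-model",  # placeholder model id
    quantization_config=bnb_config,
    device_map="auto",
)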

Second, I want to use replace_with_bnb_linear to swap some linear layers in an LLM for a structure of my own design, and to quantize those parts in 4/8-bit, consistent with load_in_4/8bit, so that the quantization strategy of the replaced module matches the original module as closely as possible. However, after calling replace_with_bnb_linear directly, the forward pass raises the following error:

FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  ...
  in forward
    x = self.linear(x)
  File "myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "myenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "myenv/lib/python3.11/site-packages/bitsandbytes/nn/modules.py", line 256, in forward
    out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
  File "myenv/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py", line 566, in matmul_4bit
    assert quant_state is not None
AssertionError

Since I could not find how transformers itself calls replace_with_bnb_linear, I don't know whether quant_state needs to be set manually or how it should be supplied. I'd appreciate it if you could take a look when you have time; thanks in advance.
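
From the error message and my reading of bitsandbytes, quant_state seems to be filled in automatically when a Linear4bit that still holds fp16 data is moved onto the GPU. A minimal sketch of that path as I understand it (shapes are placeholders; fp4 matches the LinearFP4 in the traceback):

import torch
import bitsandbytes as bnb

# A plain fp16 linear standing in for one layer inside my custom module.
fp16_linear = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.float16)

qlinear = bnb.nn.Linear4bit(
    fp16_linear.in_features,
    fp16_linear.out_features,
    bias=False,
    compute_dtype=torch.float16,
    quant_type="fp4",
)
# Wrap the fp16 weight in Params4bit; nothing is quantized yet, so
# weight.quant_state is still None at this point.
qlinear.weight = bnb.nn.Params4bit(
    fp16_linear.weight.data, requires_grad=False, quant_type="fp4"
)
# Moving the layer to CUDA triggers the 4-bit quantization and populates
# weight.quant_state, so matmul_4bit no longer hits the assertion.
qlinear = qlinear.cuda()
print(qlinear.weight.quant_state is not None)  # True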

wangq326 commented 3 months ago

The above is the error with 4-bit quantization. I tried to work around it with the code below, but ran into some problems:

from transformers.integrations.bitsandbytes import replace_with_bnb_linear
from bitsandbytes.nn import Linear4bit

# Reuse the model's quantization config and borrow the quant_state of an
# existing 4-bit weight (this borrowing is the questionable part, see below).
quantization_config = llm_model.config.quantization_config
quantization_config.quant_state = llm_model.layers[0].mlp.up_proj.weight.quant_state

# MyLayer(...) stands for my own module containing the linear layers to replace.
mylayer = replace_with_bnb_linear(
    MyLayer(...),
    quantization_config=quantization_config,
)

mylayer.to_empty(device=llm_model.device)
for name, layer in mylayer.named_modules():
    if isinstance(layer, Linear4bit):
        layer.weight.quant_state = quantization_config.quant_state

mylayer(x)

The error here is most likely because I directly reused the quant_state from the up_proj weight, which causes some shape mismatches. At this stage I only cared about getting it to run, not yet about whether the computation is correct: directly reusing parameters such as absmax from another weight's quant_state is not reasonable, since absmax appears to be closely tied to the 4-bit quantization of that particular weight, so borrowing it from a different weight probably doesn't make sense.
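
To illustrate, a minimal sketch (shapes are only illustrative, and I am assuming a recent bitsandbytes that exposes quantize_4bit / dequantize_4bit in bitsandbytes.functional): the absmax inside a quant_state holds one scale per block of the specific tensor that was quantized, so its size depends on that tensor and cannot be borrowed from another weight.

import torch
import bitsandbytes.functional as F

# Two weights of different shapes (sizes are only illustrative).
w_up = torch.randn(11008, 4096, dtype=torch.float16, device="cuda")
w_mine = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# quantize_4bit packs each tensor into 4-bit blocks and returns a quant_state
# whose absmax has one scale per block of *that* tensor.
packed_up, state_up = F.quantize_4bit(w_up, blocksize=64, quant_type="fp4")
packed_mine, state_mine = F.quantize_4bit(w_mine, blocksize=64, quant_type="fp4")

print(state_up.absmax.shape, state_mine.absmax.shape)  # different block counts

# Dequantizing with the matching state recovers an approximation of the
# original weight; pairing packed data with another tensor's state is meaningless.
w_back = F.dequantize_4bit(packed_mine, state_mine)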