OpenGVLab / OmniQuant

[ICLR2024 spotlight] OmniQuant is a simple and powerful quantization technique for LLMs.

Why is the loss NaN when quantizing OPT-1.3B or LLaMA-7B with a W8A8 config? #8

Closed: MeJerry215 closed this issue 11 months ago

MeJerry215 commented 12 months ago

[screenshot: training log showing NaN loss]

The quantization command is:

python main.py --model llama-7B --epochs 20 --calib_dataset c4 --nsamples 128 --batch_size 1 \
--seed 0 --eval_ppl --wbits 4 --abits 8 --lwc --let \
--output_dir quant_models/llama-7b-w4a8 --save_dir fake_models/llama-7b-w4a8

MeJerry215 commented 12 months ago

@ChenMnZ

ChenMnZ commented 12 months ago

Training with 8-bit quantization leads to NaN loss in mixed-precision training. Please refer to the updated code: we've implemented full-precision training when the quantization bit-width is 8 or more, which enables successful W4A8 training.
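The general pattern this fix suggests can be sketched as follows. This is a hypothetical simplification, not the repository's actual logic; the function name, the wbits/abits arguments, and the bit-width threshold are assumptions.

import torch

# Hypothetical sketch of the fallback described above, not OmniQuant's actual code:
# use float16 autocast only when both bit-widths are low; otherwise run the
# training step in full precision so quantization factors cannot overflow float16.
def training_autocast(wbits, abits):
    use_fp16 = wbits < 8 and abits < 8   # assumed threshold: 8-bit or wider -> full precision
    return torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_fp16)

# usage sketch:
# with training_autocast(args.wbits, args.abits):
#     loss = block_training_step(...)    # hypothetical training-step function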

MeJerry215 commented 11 months ago

> Training with 8-bit quantization leads to NaN loss in mixed-precision training. Please refer to the updated code: we've implemented full-precision training when the quantization bit-width is 8 or more, which enables successful W4A8 training.

Thanks a lot.

brisker commented 9 months ago

> Training with 8-bit quantization leads to NaN loss in mixed-precision training.

Why does this happen? @ChenMnZ

ChenMnZ commented 9 months ago

@brisker During the backward pass of mixed-precision training, gradients are computed in float16, which has a much narrower representable range.

For 8-bit quantization, the quantization grid has 2^8 - 1 = 255 steps; such a large factor in the backward pass can overflow float16.
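As a minimal illustration of this overflow (a toy PyTorch example, not OmniQuant code): float16 tops out around 65504, so an upstream gradient of a few hundred multiplied by a factor of 2^8 - 1 = 255 already becomes inf, which then propagates into a NaN loss.

import torch

# Toy demonstration of the float16 overflow described above (not OmniQuant code).
levels = 2**8 - 1                                        # 255 for 8-bit quantization
x = torch.tensor([1.0], dtype=torch.float16, requires_grad=True)
y = x * levels                                           # forward pass stays in range
upstream = torch.tensor([300.0], dtype=torch.float16)    # hypothetical large incoming gradient
y.backward(upstream)                                     # dL/dx = 255 * 300 = 76500 > 65504
print(x.grad)                                            # tensor([inf], dtype=torch.float16)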

brisker commented 8 months ago

@ChenMnZ here: https://github.com/OpenGVLab/OmniQuant/blob/main/quantize/omniquant.py#L192

Why are shifts not used in the equivalent transformation for the LLaMA models?

xingchensong commented 8 months ago

> @ChenMnZ here: https://github.com/OpenGVLab/OmniQuant/blob/main/quantize/omniquant.py#L192
>
> Why are shifts not used in the equivalent transformation for the LLaMA models?

I'm also curious about the reasoning behind this.

ChenMnZ commented 8 months ago

The original LLaMA models do not contain bias parameters, and we find that the shift operation does not benefit weight-only quantization. Therefore, we only activate the shift equivalent transformation in weight-activation quantization for the LLaMA models.

As for initialization, the initialization of the shifts is crucial for OPT but negligible for LLaMA, so we simply initialize them to 1 for LLaMA.
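For readers new to the equivalent transformation, here is a minimal numerical sketch (simplified, not the repository's implementation) of how a per-channel shift and scale on the activations can be folded into the next linear layer. Note that folding the shift away exactly requires a bias term, which the original LLaMA linear layers lack.

import torch

# Simplified shift-and-scale equivalent transformation (not OmniQuant's actual code):
# X @ W.T + b  ==  ((X - shift) / scale) @ (W * scale).T + (b + shift @ W.T),
# so the quantizer sees the flatter activation (X - shift) / scale instead of X.
torch.manual_seed(0)
X = torch.randn(4, 8)            # activations: (tokens, channels)
W = torch.randn(16, 8)           # linear weight: (out_features, in_features)
b = torch.randn(16)              # bias; LLaMA layers have no such term by default

shift = X.mean(dim=0)            # per-channel shift (the part that needs a bias to absorb it)
scale = X.abs().amax(dim=0)      # per-channel scale

X_eq = (X - shift) / scale
W_eq = W * scale                 # fold the scale into the weight
b_eq = b + shift @ W.T           # fold the shift into the bias

print(torch.allclose(X @ W.T + b, X_eq @ W_eq.T + b_eq, atol=1e-5))   # True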

Xingrun-Xing commented 6 months ago

We tried W4A4 on LLaMA-2. Why is the loss still NaN when half-precision training is enabled, even though the bit-width is below 8? Also, when we use full-precision training, the WikiText-2 ppl of W4A16 is 11.x rather than the 5.x reported in the paper. We would like to confirm whether the current version of the code can run LLaMA-2 correctly. Thanks!

ChenMnZ commented 6 months ago

@Xingrun-Xing I just ran LLaMA-2-7B W4A16 again; the current version of the code runs LLaMA-2 correctly. You can share your training command so I can look into it further. Training command:

CUDA_VISIBLE_DEVICES=1 python main.py \
--model path/to/llama-7b  \
--output_dir ./log/test \
--epochs 5 --nsamples 128 \
--wbits 4 --abits 16 --lwc --eval_ppl

Results: (screenshot of the evaluation results)

Also note that OmniQuant loads the saved cache datasets. Make sure these cached datasets are intact; if they might be corrupted, delete the files under the cache folder and rerun the OmniQuant code.

Xingrun-Xing commented 6 months ago

> @Xingrun-Xing I just ran LLaMA-2-7B W4A16 again; the current version of the code runs LLaMA-2 correctly. You can share your training command so I can look into it further.

Thanks for the reply. I just pulled the current code and ran it again with the same command as yours, and the results are now normal. One small question: our results differ slightly from yours; what could be the reason? Which transformers version are you using? (screenshot of results)

Xingrun-Xing commented 6 months ago

Hi, we found another issue. We tried the scripts under /script and found that W6A6 cannot train LLaMA-2 properly, while the other scripts all work fine. (screenshot)

ChenMnZ commented 6 months ago

@Xingrun-Xing

  1. The transformers version is 4.36.0. For the difference in results, you can try downloading the released quantization parameters and evaluating with them to check whether your environment has a problem. If the environment is fine, the released quantization parameters should reproduce results close to those in the paper.
  2. For the NaN with W6A6, adding --deactive_amp should solve it (see the example command below).
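For reference, a hypothetical W6A6 command adapted from the W4A16 command earlier in this thread (the exact arguments used by the scripts under /script may differ):

CUDA_VISIBLE_DEVICES=0 python main.py \
--model path/to/llama-2-7b \
--output_dir ./log/llama-2-7b-w6a6 \
--epochs 20 --nsamples 128 \
--wbits 6 --abits 6 --lwc --let --eval_ppl \
--deactive_amp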
haoming-codes commented 4 months ago

> For the NaN with W6A6, adding --deactive_amp should solve it.

May I ask if you know why torch.amp does not work with W6A6? Thank you for the great work.