DayDayupupupup opened this issue 3 months ago. Status: Open.
We will verify it locally and update you on any progress.
It seems it cannot be reproduced with the llama2 (llama2_13B_chat) model, but it can be reproduced with llama3 and internlm2. So I guess this issue is specific to GQA models.
A workaround is to change bfloat16 to float16 in config.json. @DayDayupupupup cc @lvhan028
Judging from the root cause, this bug was not introduced by prefix caching. @lzhangzz may take a look.
A workaround is to change bfloat16 to float16 in config.json: https://huggingface.co/internlm/internlm2-chat-7b/blob/main/config.json#L28
It seems it cannot be reproduced with the llama2 (llama2_13B_chat) model, but it can be reproduced with llama3 and internlm2. So I guess this issue is specific to GQA models.
The torch_dtype for llama3 and internlm2 is bfloat16, but for llama2 it is float16. When changing the torch_dtype to float16 for internlm2 and llama3, the results are the same whether prefix caching is enabled or disabled.
We print part of the KV cache values in each block to debug it: https://github.com/InternLM/lmdeploy/blob/9fd9c8c8b753db161f186b614f9d0e5688dd64ef/src/turbomind/models/llama/LlamaBatch.cc#L577
// Debug: dump the first 20 half values of every KV cache block owned by the sequence.
for (size_t i = 0; i < seq.blocks.size(); i++) {
    std::vector<half> v(20);
    Copy(static_cast<half*>(sequence_manager_->GetBlockPtr(seq.blocks[i])), 20, v.data());
    for (int k = 0; k < 20; k++) {
        std::cout << __half2float(v[k]) << " ";
    }
    std::cout << ", ";  // block separator
}
We found some small value differences (probably caused by precision conversion) in the block right after the cached blocks.
Discussion about float16 and bfloat16 can be found at https://github.com/InternLM/lmdeploy/pull/1140#issuecomment-1931703194. The issue is caused by precision problems: the KV cache values in the reused blocks are consistent, but when the type is bfloat16, precision inconsistencies show up in the tokens generated afterwards.
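For a concrete (toy) sense of the precision gap, not lmdeploy code: bfloat16 stores only 7 mantissa bits versus float16's 10, so small contributions that float16 can still represent get rounded away in bfloat16. A minimal Python sketch, emulating bfloat16 by truncating a float32:

import struct

def to_bf16(x: float) -> float:
    # Emulate bfloat16 by keeping only the upper 16 bits of a float32
    # (1 sign + 8 exponent + 7 mantissa bits).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

v = 1.0 + 2**-10             # exactly representable in float16
print(f"{v:.10f}")           # 1.0009765625
print(f"{to_bf16(v):.10f}")  # 1.0000000000 -- the 2**-10 term is lost in bfloat16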
https://huggingface.co/internlm/internlm2-chat-7b/commit/5b50661e5ba16c9ded1047a51e394280b3b9bda1 I have confirmed that the internlm2-chat-7b I am using is not the latest version, but the version above. Its default torch_dtype is float16.
A workaround is to change bfloat16 to float16 in config.json: https://huggingface.co/internlm/internlm2-chat-7b/blob/main/config.json#L28
@DayDayupupupup Could you try the latest version in this way?
I have also got a different answer for internlm-xcomposer. When using the official internlm-xcomposer code, the model behaves correctly. However, lmdeploy does not output the same answer. Changing bfloat16 to float16 doesn't help, BTW.
I have also got a different answer
Are you referring to https://github.com/InternLM/lmdeploy/issues/1688?
A workaround is to change bfloat16 to float16 in config.json: https://huggingface.co/internlm/internlm2-chat-7b/blob/main/config.json#L28
@DayDayupupupup Could you try the latest version in this way?
Using the latest version (commit [3e6b81c]) and changing bf16 to f16, the request results are consistent with prefix caching enabled.
So why does the old version with fp16 weights not work?
When prefix caching is enabled, the cached part of the prompt is not prefilled again. This leads to a different GEMM problem size, and the dispatched kernel may be different.
When doing GEMM, a different level of parallelism along the k dimension leads to a different accumulation order and thus a different floating-point result.
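As a toy illustration of the accumulation-order effect (not lmdeploy code; float32 loops stand in for the kernel's accumulators): the same dot product reduced serially versus in split-k style partial sums can end up differing in the last bits.

import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal(4096, dtype=np.float32)
b = rng.standard_normal(4096, dtype=np.float32)

# Serial accumulation over the whole k range.
serial = np.float32(0)
for x, y in zip(a, b):
    serial += x * y

# Split-k style: two partial sums, reduced at the end.
half = len(a) // 2
part0 = np.float32(0)
part1 = np.float32(0)
for x, y in zip(a[:half], b[:half]):
    part0 += x * y
for x, y in zip(a[half:], b[half:]):
    part1 += x * y
split_k = part0 + part1

print(serial, split_k, serial == split_k)  # the two sums typically differ in the last bits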
This leads to a different GEMM problem size, and the dispatched kernel may be different.
This https://github.com/InternLM/lmdeploy/issues/1719#issuecomment-2152540734 may not be explained by that, though.
To summarize, there are several scenarios where using temperature 0 results in output differences:
Is there currently a plan to address this issue? In some scenarios, such as generative search, the temperature is usually set very low or even to 0. When it is 0 and algorithm engineers find that the results are inconsistent with those from transformers, it can be quite perplexing. @lvhan028 @lzhangzz
Describe the bug
Model: internlm2-chat-7b, GPU: A30, version: 0.4.2
When enable_prefix_caching=True, the generated content is different from when enable_prefix_caching=False.
Reproduction
test script: internlm.py
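The script itself is not attached here; below is a minimal hypothetical sketch of what such a test could look like with lmdeploy's Python API (the prompt, token budget, and flag handling are assumptions, not the author's actual internlm.py):

# Hypothetical sketch, not the author's internlm.py: send the same prompt
# three times with temperature 0 and check whether the outputs match.
import argparse

from lmdeploy import GenerationConfig, TurbomindEngineConfig, pipeline

parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", required=True)
parser.add_argument("--enable_prefix_caching", action="store_true")
args = parser.parse_args()

pipe = pipeline(
    args.model,
    backend_config=TurbomindEngineConfig(enable_prefix_caching=args.enable_prefix_caching),
)
gen_config = GenerationConfig(temperature=0.0, max_new_tokens=256)

prompt = "Introduce Shanghai in about 100 words."  # placeholder prompt
outputs = [pipe([prompt], gen_config=gen_config)[0].text for _ in range(3)]
for i, text in enumerate(outputs):
    print(f"--- request {i + 1} ---\n{text}")
print("all identical:", len(set(outputs)) == 1)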
TEST 1 Disable prefix caching
python internlm.py -m internlm2-chat-7b
All three requests generate exactly the same content.
TEST 2 Enable prefix caching
python internlm.py -m internlm2-chat-7b --enable_prefix_caching
The generated content of the 2nd and 3rd requests is different from that of the first request.
Environment
Error traceback
No response