InternLM / lmdeploy

LMDeploy is a toolkit for compressing, deploying, and serving LLMs.
https://lmdeploy.readthedocs.io/en/latest/
Apache License 2.0
4.16k stars 376 forks source link

[Bug] Why does prefix caching change the generated content #1719

Open DayDayupupupup opened 3 months ago

DayDayupupupup commented 3 months ago

Checklist

Describe the bug

Model: internlm2-chat-7b GPU: A30 VERSION:0.4.2

When enable_prefix_caching=True, the generated content is different from enable_prefix_caching=False

Reproduction

test script: internlm.py

import argparse
import time
from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
from lmdeploy.model import ChatTemplateConfig
# Sample prompts.
prompts = """你是一个名为"EVA"人工智能助手,正在与人类用户进行交谈。你的目标是以最有帮助和最逻辑的方式回答问题,同时确保内容的安全性。你的回答中不应包含任何有害、政治化、宗教化、不道德、种族主义、非法的内容。请确保你的回答不带有社会偏见,符合社会主义价值观。如果遇到的问题无意义或事实上不连贯,请不要回答错误的内容,而是解释问题为何无效或不连贯。如果你不知道问题的答案,也请勿提供错误的信息。对公转账怎么操作?"""

def profile(args):
    model_path = args.model
    engine_config = TurbomindEngineConfig(quant_policy=args.kv_cache,
                                          tp=1,
                                          enable_prefix_caching=args.enable_prefix_caching,
                                          session_len=4096)
    pipe = pipeline(model_path=model_path,
                    backend_config=engine_config,
                    log_level='ERROR') 
    gen_config = GenerationConfig(max_new_tokens=args.output_len, 
                                  temperature=0.0,
                                  top_p=1.0,
                                  top_k=1,
                                  ignore_eos=False,
                                  stop_words=['<|im_end|>','</s>'],
                                  skip_special_tokens=False)
    # print(gen_config)

    for i in range(3):
        start_time = time.perf_counter()
        outputs = pipe([prompts], gen_config=gen_config,do_preprocess=False)
        end_time = time.perf_counter()
        latency = end_time - start_time
        print(f'---------------------------- Infer {i} -----------------------------------')
        print(f'Infer {i} elapsed : {latency*1000:.2f} ms')
        print(f"Generated text: {outputs[0].text!r}")
        print(f'prompt_tokens={outputs[0].input_token_len} gen_tokens={outputs[0].generate_token_len}')

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model',
                        type=str,
                        default='LLM model path',
                        help='')
    parser.add_argument('-i', '--input_len', type=int, default=1024)
    parser.add_argument('-o', '--output_len', type=int, default=512)
    parser.add_argument('--kv_cache', type=int, default=0, choices=[0,4,8])
    parser.add_argument('--enable_prefix_caching', action="store_true")
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_arguments()
    profile(args)  

TEST 1 Disable prefix caching

python internlm.py -m internlm2-chat-7b
All three requests generate exactly the same content.

[WARNING] gemm_config.in is not found; using default GEMM algo
---------------------------- Infer 0 -----------------------------------
Infer 0 elapsed : 10202.76 ms
Generated text: '\n以下是一些关于公转账操作的常见问题和解答:\n\n### 什么是公转账?\n\n公转账是指将资金从一个银行账户转移到另一个银行账户的过程,通常用于商业交易、支付账单或进行投资。\n\n### 如何进行公转账?\n\n以下是公转账的一般步骤:\n\n1. **确认账户信息**:\n   - 确保你有正确的收款人姓名、账户号码和银行信息。\n\n2. **登录银行账户**:\n   - 登录你的银行账户,进入转账页面。\n\n3. **填写转账信息**:\n   - 输入收款人的账户信息,包括姓名、账户号码和银行代码。\n   - 输入转账金额。\n   - 确认转账详情。\n\n4. **验证身份**:\n   - 根据银行要求,可能需要提供身份验证信息,如密码、验证码或指纹识别。\n\n5. **提交转账**:\n   - 点击“提交”或“确认”按钮,完成转账。\n\n6. **等待处理**:\n   - 银行通常会在一定时间内处理转账请求。\n\n7. **确认转账**:\n   - 转账完成后,你可以在银行账户中查看转账记录。\n\n### 公转账需要注意什么?\n\n- **准确性**:确保所有信息(如账户号码、姓名和金额)都是准确的。\n- **安全性**:保护你的账户信息,避免在公共场所或不安全的网络环境下进行转账。\n- **费用**:了解银行可能收取的转账费用,并考虑这些费用对你的财务影响。\n- **时效性**:了解转账的处理时间,并计划好转账时间以确保资金在需要时到达。\n\n### 公转账的常见问题:\n\n- **转账延迟**:如果转账延迟,可能是因为银行处理时间、节假日或系统问题。\n- **转账失败**:如果转账失败,可能是因为输入错误的信息、账户余额不足或银行系统问题。\n- **资金安全**:确保你的转账信息安全,避免在公共网络或不受信任的网站上进行转账。\n\n### 总结:\n\n公转账是银行账户间资金转移的一种方式,通常用于商业交易和支付。进行公转账时,请确保信息准确、账户安全,并了解可能涉及的费用和处理时间。如果你遇到问题,可以联系你的银行客服寻求帮助。\n\n请注意,以上信息基于一般情况,具体的公转账操作可能因银行和地区而异。在实际操作前,请查阅你'
prompt_tokens=111 gen_tokens=513
---------------------------- Infer 1 -----------------------------------
Infer 1 elapsed : 10118.08 ms
Generated text: '\n以下是一些关于公转账操作的常见问题和解答:\n\n### 什么是公转账?\n\n公转账是指将资金从一个银行账户转移到另一个银行账户的过程,通常用于商业交易、支付账单或进行投资。\n\n### 如何进行公转账?\n\n以下是公转账的一般步骤:\n\n1. **确认账户信息**:\n   - 确保你有正确的收款人姓名、账户号码和银行信息。\n\n2. **登录银行账户**:\n   - 登录你的银行账户,进入转账页面。\n\n3. **填写转账信息**:\n   - 输入收款人的账户信息,包括姓名、账户号码和银行代码。\n   - 输入转账金额。\n   - 确认转账详情。\n\n4. **验证身份**:\n   - 根据银行要求,可能需要提供身份验证信息,如密码、验证码或指纹识别。\n\n5. **提交转账**:\n   - 点击“提交”或“确认”按钮,完成转账。\n\n6. **等待处理**:\n   - 银行通常会在一定时间内处理转账请求。\n\n7. **确认转账**:\n   - 转账完成后,你可以在银行账户中查看转账记录。\n\n### 公转账需要注意什么?\n\n- **准确性**:确保所有信息(如账户号码、姓名和金额)都是准确的。\n- **安全性**:保护你的账户信息,避免在公共场所或不安全的网络环境下进行转账。\n- **费用**:了解银行可能收取的转账费用,并考虑这些费用对你的财务影响。\n- **时效性**:了解转账的处理时间,并计划好转账时间以确保资金在需要时到达。\n\n### 公转账的常见问题:\n\n- **转账延迟**:如果转账延迟,可能是因为银行处理时间、节假日或系统问题。\n- **转账失败**:如果转账失败,可能是因为输入错误的信息、账户余额不足或银行系统问题。\n- **资金安全**:确保你的转账信息安全,避免在公共网络或不受信任的网站上进行转账。\n\n### 总结:\n\n公转账是银行账户间资金转移的一种方式,通常用于商业交易和支付。进行公转账时,请确保信息准确、账户安全,并了解可能涉及的费用和处理时间。如果你遇到问题,可以联系你的银行客服寻求帮助。\n\n请注意,以上信息基于一般情况,具体的公转账操作可能因银行和地区而异。在实际操作前,请查阅你'
prompt_tokens=111 gen_tokens=513
---------------------------- Infer 2 -----------------------------------
Infer 2 elapsed : 10118.81 ms
Generated text: '\n以下是一些关于公转账操作的常见问题和解答:\n\n### 什么是公转账?\n\n公转账是指将资金从一个银行账户转移到另一个银行账户的过程,通常用于商业交易、支付账单或进行投资。\n\n### 如何进行公转账?\n\n以下是公转账的一般步骤:\n\n1. **确认账户信息**:\n   - 确保你有正确的收款人姓名、账户号码和银行信息。\n\n2. **登录银行账户**:\n   - 登录你的银行账户,进入转账页面。\n\n3. **填写转账信息**:\n   - 输入收款人的账户信息,包括姓名、账户号码和银行代码。\n   - 输入转账金额。\n   - 确认转账详情。\n\n4. **验证身份**:\n   - 根据银行要求,可能需要提供身份验证信息,如密码、验证码或指纹识别。\n\n5. **提交转账**:\n   - 点击“提交”或“确认”按钮,完成转账。\n\n6. **等待处理**:\n   - 银行通常会在一定时间内处理转账请求。\n\n7. **确认转账**:\n   - 转账完成后,你可以在银行账户中查看转账记录。\n\n### 公转账需要注意什么?\n\n- **准确性**:确保所有信息(如账户号码、姓名和金额)都是准确的。\n- **安全性**:保护你的账户信息,避免在公共场所或不安全的网络环境下进行转账。\n- **费用**:了解银行可能收取的转账费用,并考虑这些费用对你的财务影响。\n- **时效性**:了解转账的处理时间,并计划好转账时间以确保资金在需要时到达。\n\n### 公转账的常见问题:\n\n- **转账延迟**:如果转账延迟,可能是因为银行处理时间、节假日或系统问题。\n- **转账失败**:如果转账失败,可能是因为输入错误的信息、账户余额不足或银行系统问题。\n- **资金安全**:确保你的转账信息安全,避免在公共网络或不受信任的网站上进行转账。\n\n### 总结:\n\n公转账是银行账户间资金转移的一种方式,通常用于商业交易和支付。进行公转账时,请确保信息准确、账户安全,并了解可能涉及的费用和处理时间。如果你遇到问题,可以联系你的银行客服寻求帮助。\n\n请注意,以上信息基于一般情况,具体的公转账操作可能因银行和地区而异。在实际操作前,请查阅你'
prompt_tokens=111 gen_tokens=513

TEST 2 Enable prefix caching

python internlm.py -m internlm2-chat-7b --enable_prefix_caching
The generated content of the 2.3 request is different from that of the first request.

[WARNING] gemm_config.in is not found; using default GEMM algo
---------------------------- Infer 0 -----------------------------------
Infer 0 elapsed : 10185.02 ms
Generated text: '\n以下是一些关于公转账操作的常见问题和解答:\n\n### 什么是公转账?\n\n公转账是指将资金从一个银行账户转移到另一个银行账户的过程,通常用于商业交易、支付账单或进行投资。\n\n### 如何进行公转账?\n\n以下是公转账的一般步骤:\n\n1. **确认账户信息**:\n   - 确保你有正确的收款人姓名、账户号码和银行信息。\n\n2. **登录银行账户**:\n   - 登录你的银行账户,进入转账页面。\n\n3. **填写转账信息**:\n   - 输入收款人的账户信息,包括姓名、账户号码和银行代码。\n   - 输入转账金额。\n   - 确认转账详情。\n\n4. **验证身份**:\n   - 根据银行要求,可能需要提供身份验证信息,如密码、验证码或指纹识别。\n\n5. **提交转账**:\n   - 点击“提交”或“确认”按钮,完成转账。\n\n6. **等待处理**:\n   - 银行通常会在一定时间内处理转账请求。\n\n7. **确认转账**:\n   - 转账完成后,你可以在银行账户中查看转账记录。\n\n### 公转账需要注意什么?\n\n- **准确性**:确保所有信息(如账户号码、姓名和金额)都是准确的。\n- **安全性**:保护你的账户信息,避免在公共场所或不安全的网络环境下进行转账。\n- **费用**:了解银行可能收取的转账费用,并考虑这些费用对你的财务影响。\n- **时效性**:了解转账的处理时间,并计划好转账时间以确保资金在需要时到达。\n\n### 公转账的常见问题:\n\n- **转账延迟**:如果转账延迟,可能是因为银行处理时间、节假日或系统问题。\n- **转账失败**:如果转账失败,可能是因为输入错误的信息、账户余额不足或银行系统问题。\n- **资金安全**:确保你的转账信息安全,避免在公共网络或不受信任的网站上进行转账。\n\n### 总结:\n\n公转账是银行账户间资金转移的一种方式,通常用于商业交易和支付。进行公转账时,请确保信息准确、账户安全,并了解可能涉及的费用和处理时间。如果你遇到问题,可以联系你的银行客服寻求帮助。\n\n请注意,以上信息基于一般情况,具体的公转账操作可能因银行和地区而异。在实际操作前,请查阅你'
prompt_tokens=111 gen_tokens=513
---------------------------- Infer 1 -----------------------------------
Infer 1 elapsed : 8556.29 ms
Generated text: '\n以下是一些关于公转账操作的常见问题和解答:\n\n### 什么是公转账?\n\n公转账是指将资金从一个银行账户转移到另一个银行账户的过程,通常用于商业交易、支付账单或进行投资。\n\n### 如何进行公转账?\n\n以下是公转账的一般步骤:\n\n1. **确认账户信息**:\n   - 确保你有正确的收款人姓名、账户号码和银行信息。\n\n2. **登录银行账户**:\n   - 登录你的银行账户,进入转账页面。\n\n3. **填写转账信息**:\n   - 输入收款人的账户信息,包括姓名、账户号码和银行代码。\n   - 输入转账金额。\n   - 选择转账类型(如普通转账、定期转账等)。\n\n4. **验证转账**:\n   - 确认所有信息无误后,提交转账请求。\n\n5. **等待处理**:\n   - 银行处理转账请求,通常需要一定时间。\n\n6. **确认转账**:\n   - 一旦转账完成,你将收到确认信息。\n\n### 转账费用\n\n- **手续费**:大多数银行会对每笔转账收取一定的手续费。\n- **转账限额**:有些银行可能对单笔或每日转账金额有限制。\n\n### 安全注意事项\n\n- **保护个人信息**:确保你的银行账户和密码安全,避免在公共网络或不安全设备上进行转账操作。\n- **验证收款人信息**:在转账前,务必核实收款人的身份和账户信息,以防止诈骗。\n\n### 常见问题\n\n- **转账延迟**:银行处理转账需要时间,通常1-3个工作日。\n- **转账失败**:如果转账失败,可能是因为账户信息错误或账户余额不足。\n- **转账撤销**:如果你需要撤销转账,请尽快联系银行。\n\n### 结论\n\n公转账是进行商业交易和支付的重要方式。在进行公转账时,请确保提供正确的账户信息,并注意转账费用和安全问题。如果你遇到任何问题,请咨询你的银行或使用银行提供的客户服务。'
prompt_tokens=111 gen_tokens=435
---------------------------- Infer 2 -----------------------------------
Infer 2 elapsed : 8555.07 ms
Generated text: '\n以下是一些关于公转账操作的常见问题和解答:\n\n### 什么是公转账?\n\n公转账是指将资金从一个银行账户转移到另一个银行账户的过程,通常用于商业交易、支付账单或进行投资。\n\n### 如何进行公转账?\n\n以下是公转账的一般步骤:\n\n1. **确认账户信息**:\n   - 确保你有正确的收款人姓名、账户号码和银行信息。\n\n2. **登录银行账户**:\n   - 登录你的银行账户,进入转账页面。\n\n3. **填写转账信息**:\n   - 输入收款人的账户信息,包括姓名、账户号码和银行代码。\n   - 输入转账金额。\n   - 选择转账类型(如普通转账、定期转账等)。\n\n4. **验证转账**:\n   - 确认所有信息无误后,提交转账请求。\n\n5. **等待处理**:\n   - 银行处理转账请求,通常需要一定时间。\n\n6. **确认转账**:\n   - 一旦转账完成,你将收到确认信息。\n\n### 转账费用\n\n- **手续费**:大多数银行会对每笔转账收取一定的手续费。\n- **转账限额**:有些银行可能对单笔或每日转账金额有限制。\n\n### 安全注意事项\n\n- **保护个人信息**:确保你的银行账户和密码安全,避免在公共网络或不安全设备上进行转账操作。\n- **验证收款人信息**:在转账前,务必核实收款人的身份和账户信息,以防止诈骗。\n\n### 常见问题\n\n- **转账延迟**:银行处理转账需要时间,通常1-3个工作日。\n- **转账失败**:如果转账失败,可能是因为账户信息错误或账户余额不足。\n- **转账撤销**:如果你需要撤销转账,请尽快联系银行。\n\n### 结论\n\n公转账是进行商业交易和支付的重要方式。在进行公转账时,请确保提供正确的账户信息,并注意转账费用和安全问题。如果你遇到任何问题,请咨询你的银行或使用银行提供的客户服务。'
prompt_tokens=111 gen_tokens=435

Environment

sys.platform: linux
Python: 3.8.10 (default, Nov 22 2023, 10:22:35) [GCC 9.4.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0,1,2,3: NVIDIA A30
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 11.8, V11.8.89
GCC: x86_64-linux-gnu-gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
PyTorch: 2.1.0+cu118
PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 11.8
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_90,code=sm_90
  - CuDNN 8.7
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=8.7.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-invalid-partial-specialization -Wno-unused-private-field -Wno-aligned-allocation-unavailable -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.1.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,

TorchVision: 0.16.0+cu118
LMDeploy: 0.4.2+
transformers: 4.41.1
gradio: 3.50.2
fastapi: 0.111.0
pydantic: 2.7.1
triton: 2.1.0

Error traceback

No response

zhyncs commented 3 months ago

We will verify it locally and update you on any progress.

ispobock commented 3 months ago

It seems cannot be reproduced in llama2 (llama2_13B_chat) model, but can be reproduced in llama3, internlm2. So I guess this issue is for GQA models.

zhyncs commented 3 months ago

A workaround is to change bfloat16 to float16 in config.json. @DayDayupupupup cc @lvhan028 This bug was not introduced by the prefix cache from the root cause. @lzhangzz may take a look.

zhyncs commented 3 months ago

A workaround is to change bfloat16 to float16 in config.json.

https://huggingface.co/internlm/internlm2-chat-7b/blob/main/config.json#L28

ispobock commented 3 months ago

It seems cannot be reproduced in llama2 (llama2_13B_chat) model, but can be reproduced in llama3, internlm2. So I guess this issue is for GQA models.

The torch_dtype for llama3 and internlm2 is bfloat16, but for llama2 is float16. When changing the torch_dtype to float16 for internlm2 and llama3, the result is the same for enabling/disabling prefix caching.

ispobock commented 3 months ago

We print part of the KV cache values in each block to debug it: https://github.com/InternLM/lmdeploy/blob/9fd9c8c8b753db161f186b614f9d0e5688dd64ef/src/turbomind/models/llama/LlamaBatch.cc#L577

for (int i = 0; i < seq.blocks.size(); i++) {
    std::vector<half> v(20);
    Copy(static_cast<half*>(sequence_manager_->GetBlockPtr(seq.blocks[i])), 20, v.data());
    for (int k = 0; k < 20; k++) {
        std::cout << __half2float(v[k]) << " ";
    }
    std::cout << ", ";
}

We find some small values diff (probably caused by precision conversion) in the block after the cached blocks.

zhyncs commented 3 months ago

Discussion about float16 and bfloat16 can be found at https://github.com/InternLM/lmdeploy/pull/1140#issuecomment-1931703194. Currently the issue is caused by precision problems. The reused block KV cache value is consistent. When the type is bfloat16, inconsistencies in precision have emerged in the following generated token.

DayDayupupupup commented 3 months ago

https://huggingface.co/internlm/internlm2-chat-7b/commit/5b50661e5ba16c9ded1047a51e394280b3b9bda1 I have confirmed that the internlm2-chat-7b l I am using is not the latest version, but the above version. The default torch_dtype is float16.

zhyncs commented 3 months ago

A workaround is to change bfloat16 to float16 in config.json.

https://huggingface.co/internlm/internlm2-chat-7b/blob/main/config.json#L28

@DayDayupupupup May you try the latest version in this way

fiona-lxd commented 3 months ago

I have also got a different answer for internlm-xcomposer. When using the official code of internlm-xcompose, the model behaves correctly. However, it cannot output a same answer for lmdeploy. Changing bfloat16 to float16 doesn't help BTW.

zhyncs commented 3 months ago

I have also got a different answer

Are you referring to this https://github.com/InternLM/lmdeploy/issues/1688

DayDayupupupup commented 3 months ago

A workaround is to change bfloat16 to float16 in config.json.

https://huggingface.co/internlm/internlm2-chat-7b/blob/main/config.json#L28

@DayDayupupupup May you try the latest version in this way

Using the latest version(commit [3e6b81c]), and changing bf16 to f16. Enable prefix caching, the request results are consistent.

So why the old version with fp16 weight is not working?

lzhangzz commented 3 months ago

When prefix caching is enabled, cached part of the prompt will not be prefilled again. This leads to different GEMM problem size and the dispatched kernel may be different.

When doing GEMM, different concurrency level in the k-mode leads to different accumulation order and thus different floating point outcome.

zhyncs commented 3 months ago

This leads to different GEMM problem size and the dispatched kernel may be different.

This https://github.com/InternLM/lmdeploy/issues/1719#issuecomment-2152540734 may not be explained

zhyncs commented 3 months ago

ref https://github.com/InternLM/lmdeploy/issues/1688#issuecomment-2141305478

zhyncs commented 3 months ago

To summarize, there are several scenarios where using temperature 0 results in output differences:

  1. As split-kv is taking effect automatically, variable batch size and sequence length at runtime may result in different split-kv factor. This will lead to differnt accumulation order and thus differnt outcome. https://github.com/InternLM/lmdeploy/issues/1688#issuecomment-2141305478
  2. When prefix caching is enabled, cached part of the prompt will not be prefilled again. This leads to different GEMM problem size and the dispatched kernel may be different. https://github.com/InternLM/lmdeploy/issues/1719#issuecomment-2153872123
  3. When a request is Partial and Not Partial. Similar to 2, the dispatched kernel may be different.

Is there currently a plan to address this issue? In some scenarios, such as generative search, the temperature is usually set very low or even to 0. For instance, when it's at 0, if algorithm engineers find that the results are inconsistent with those from transformers, it can be quite perplexing. @lvhan028 @lzhangzz