ztxz16 / fastllm

纯c++的全平台llm加速库,支持python调用,chatglm-6B级模型单卡可达10000+token / s,支持glm, llama, moss基座,手机端流畅运行
Apache License 2.0
3.28k stars 332 forks source link

加速chatglm2感觉没效果,和pytho直接调用都差不多是 30ms/token #369

Open 17714196157 opened 10 months ago

17714196157 commented 10 months ago

python代码直接加载模型调用 image

Explicitly passing a revision is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision. Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:32<00:00, 4.67s/it] 6.765336513519287, responses=你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。 0.16500820764681187s/token 1.5385632514953613, responses=你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。 0.0375259329633015s/token 1.5410957336425781, responses=你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。 0.037587700820550685s/token 1.5482840538024902, responses=你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。 0.03776302570249976s/token 1.5544860363006592, responses=你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。 0.03791429356830876s/token

导出flm模型

python3 tools/chatglm_export.py chatglm2-6b-int4.flm int4

image

修改脚本 ,统计生成速度

from fastllm_pytools import llm import time def args_parser(): parser = argparse.ArgumentParser(description = 'fastllm_chat_demo') parser.add_argument('-p', '--path', type = str, required = True, default = '', help = '模型文件的路径') args = parser.parse_args() return args

if name == "main": args = args_parser() model = llm.model(args.path)

history = []
print("输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
while True:
    query = input("\n用户:")
    if query.strip() == "stop":
        break
    if query.strip() == "clear":
        history = []
        print("输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
        continue
    print("AI:", end = "");
    curResponse = "";
    t1= time.time()
    for response in model.stream_response(query, history = history):
        curResponse += response;
        print(response, flush = True, end = "")
    t2 = time.time()
    len_n = len(curResponse)
    print(f"{t2 - t1}, responses={curResponse} {(t2 - t1) / len_n}s/token") # 0.03889209468190263s/token
    history.append((query, curResponse))
xiaoduozhou commented 9 months ago

有效果的, 我实验过差不多节省了一半的时间