InternLM / lmdeploy

LMDeploy is a toolkit for compressing, deploying, and serving LLMs.
https://lmdeploy.readthedocs.io/en/latest/
Apache License 2.0

[Bug] Is acceleration not supported for Qwen 0.5b? And what about AWQ quantization for Qwen 0.5b? #1870

Open qism opened 6 days ago

qism commented 6 days ago

Checklist

Describe the bug

Is acceleration not supported for Qwen 0.5b? And what about AWQ quantization for Qwen 0.5b? Inference latency for Qwen 0.5b on a single T4 GPU is as follows: vLLM: 1.3 s, LMDeploy: 3.2 s.

Reproduction

import time

from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig

# model_path and prompts are defined elsewhere in the original script
backend_config = TurbomindEngineConfig(cache_max_entry_count=0.9)
gen_config = GenerationConfig(top_p=0.95, temperature=0.001, max_new_tokens=512)

llm = pipeline(model_path, backend_config=backend_config)

arrival_time = time.time()
llm_results = llm(prompts, gen_config=gen_config)
finished_time = time.time()

Environment

lmdeploy 0.4.2

Error traceback

No response

zhyncs commented 2 days ago

Qwen2 0.5b is supported with the PyTorch engine. I tested it in a local environment and observed a performance gap compared with vLLM. @grimoire, could you please take a look?
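
For reference, here is a minimal sketch of running the model on the PyTorch engine via the pipeline API; the model path below is an assumption, substitute your local Qwen2 0.5b checkpoint:

import time

from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig

# Assumed model id/path; replace with your local checkpoint.
model_path = "Qwen/Qwen2-0.5B-Instruct"

# PytorchEngineConfig selects the PyTorch engine instead of TurboMind.
backend_config = PytorchEngineConfig()
gen_config = GenerationConfig(top_p=0.95, temperature=0.001, max_new_tokens=512)

llm = pipeline(model_path, backend_config=backend_config)

start = time.time()
results = llm(["Hello, who are you?"], gen_config=gen_config)
print(results, time.time() - start)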

grimoire commented 2 days ago

Turbomind does not support Qwen2 models <= 1.8b, and AWQ for the PyTorch engine is a work in progress. The problem is that Qwen2 0.5b doesn't have enough GPU computation to hide the kernel launch overhead. CUDA Graph might be a way forward, but all models and kernels would need to be redesigned to support it, not to mention that we also plan to support non-NVIDIA devices.
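
For context, here is a minimal sketch of the stock torch.cuda.CUDAGraph capture/replay pattern; it is not LMDeploy code, but it illustrates why CUDA Graph support forces static input/output buffers and fixed shapes, and hence a redesign of models and kernels:

import torch

# A toy module standing in for a decoding step; any CUDA model works here.
model = torch.nn.Linear(1024, 1024).cuda()
static_input = torch.randn(8, 1024, device="cuda")

# Warm up on a side stream before capture (required by the capture API).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture once; afterwards the whole sequence replays with a single launch,
# which is what hides per-kernel launch overhead for small models.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

# Replay with new data: copy into the captured buffer, then replay.
static_input.copy_(torch.randn(8, 1024, device="cuda"))
g.replay()
print(static_output.shape)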

zhyncs commented 2 days ago

> Turbomind does not support Qwen2 models <= 1.8b, and AWQ for the PyTorch engine is a work in progress. The problem is that Qwen2 0.5b doesn't have enough GPU computation to hide the kernel launch overhead. CUDA Graph might be a way forward, but all models and kernels would need to be redesigned to support it, not to mention that we also plan to support non-NVIDIA devices.

Makes sense, ref https://github.com/InternLM/lmdeploy/pull/1499#issuecomment-2084757953. And I am really looking forward to running the PyTorch engine on AMD GPUs.

zhyncs commented 2 days ago

@grimoire By the way, is CUDA Graph support a lower priority than torch.compile? The latter is currently the main recommended optimization path for native PyTorch.

grimoire commented 2 days ago

PyTorch does not support using custom Triton kernels inside torch.compile before 2.3.0. I will do some investigation on this.
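
For what it's worth, since PyTorch 2.3 torch.compile can trace calls to user-defined Triton kernels. A minimal sketch; the kernel and wrapper names here are illustrative, not LMDeploy's:

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # One program handles one BLOCK-sized chunk of the flat tensors.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

def triton_add(x, y):
    out = torch.empty_like(x)
    n = out.numel()
    grid = (triton.cdiv(n, 1024),)
    add_kernel[grid](x, y, out, n, BLOCK=1024)
    return out

# On PyTorch >= 2.3, torch.compile can trace through the Triton kernel call.
@torch.compile
def fused(x, y):
    return triton_add(x, y) * 2

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
print(fused(x, y)[:4])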