InternLM / lmdeploy

LMDeploy is a toolkit for compressing, deploying, and serving LLMs.
https://lmdeploy.readthedocs.io/en/latest/
Apache License 2.0

[Feature] update the range of torch versions #1857

Open zhyncs opened 3 days ago

zhyncs commented 3 days ago

Motivation

Current range: https://github.com/InternLM/lmdeploy/blob/a06174f836882d853d4eb18519c2245c2a7eae8c/requirements/runtime.txt#L16

vLLM's latest requirement:

torch == 2.3.0

https://github.com/vllm-project/vllm/blob/515080ad2fd93cc8e363ff43b90a9df18cfd71ff/requirements-cuda.txt#L7

To install vLLM and LMDeploy in the same image, I upgraded torch to 2.3.0 and installed LMDeploy with the --no-deps flag.
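Because --no-deps skips pip's dependency resolution for LMDeploy, it can be worth confirming which versions actually ended up in the image. Below is a minimal sketch using the standard library's importlib.metadata; the list of distribution names is my assumption about what matters here. Running pip check afterwards also reports any installed package whose declared requirements (such as LMDeploy's torch range) are not satisfied.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Assumed set of distributions worth inspecting after a --no-deps install.
DEFAULT_DISTS = ("torch", "triton", "vllm", "lmdeploy")


def report(dist_names=DEFAULT_DISTS):
    """Map each distribution name to its installed version (None if not installed)."""
    return {name: installed_version(name) for name in dist_names}


if __name__ == "__main__":
    for name, ver in report().items():
        print(f"{name:10s} {ver or 'not installed'}")
```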

To check how the torch upgrade affects the performance of the LMDeploy PyTorch Engine, I ran a simple benchmark.

The results below show that with torch 2.3.0 the PyTorch Engine's performance stays within a reasonable range: it reaches 5.22 req/s versus 4.73 req/s for vLLM on the same workload, while remaining behind TurboMind (7.60 req/s).

Could we consider extending LMDeploy's supported torch version range to include 2.3.0? @grimoire @lvhan028
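Extending the range amounts to moving the upper bound so that the installed build (2.3.0+cu118 in the environment listed further down) falls inside it. Below is a sketch of that check. The bounds >=2.0.0 and <=2.2.2 for the current range are placeholders I assumed, not values copied from runtime.txt. The helper compares only the numeric release segment and ignores local labels such as +cu118; for real requirement strings, packaging.specifiers.SpecifierSet implements the full PEP 440 rules.

```python
import re


def parse_release(ver):
    """Parse the numeric release part of e.g. '2.3.0+cu118' into (2, 3, 0).

    Local labels ('+cu118') and pre-release suffixes are ignored; this is a
    simplification of PEP 440, not a full implementation.
    """
    match = re.match(r"(\d+(?:\.\d+)*)", ver)
    if match is None:
        raise ValueError(f"unparseable version: {ver!r}")
    parts = tuple(int(p) for p in match.group(1).split("."))
    # Pad to three components so "2.3" compares equal to "2.3.0".
    return parts + (0,) * (3 - len(parts))


def in_range(ver, lower, upper):
    """True if lower <= ver <= upper (both bounds inclusive)."""
    return parse_release(lower) <= parse_release(ver) <= parse_release(upper)


if __name__ == "__main__":
    # Placeholder bounds (assumed, not copied from requirements/runtime.txt).
    current = ("2.0.0", "2.2.2")
    proposed = ("2.0.0", "2.3.0")
    installed = "2.3.0+cu118"
    print("current range:", in_range(installed, *current))
    print("proposed range:", in_range(installed, *proposed))
```

With inclusive upper bounds, a later patch release such as 2.3.1 would still be rejected, which is one reason to prefer an exclusive bound (for example <2.4.0) in a real requirements file.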

# server
# python3 -m lmdeploy serve api_server /workdir/Llama-2-13b-chat-hf
# python3 -m lmdeploy serve api_server /workdir/Llama-2-13b-chat-hf --backend pytorch
# python3 -m vllm.entrypoints.openai.api_server --model /workdir/Llama-2-13b-chat-hf

# client
# ignore_eos true
# https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_serving.py
# python3 benchmark_serving.py --backend lmdeploy --host 127.0.0.1 --port 23333 --dataset /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --model /workdir/Llama-2-13b-chat-hf --num-prompts 1000 --request-rate 128
# python3 benchmark_serving.py --backend vllm --host 127.0.0.1 --port 8000 --dataset /workdir/ShareGPT_V3_unfiltered_cleaned_split.json --model /workdir/Llama-2-13b-chat-hf --num-prompts 1000 --request-rate 128
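Some context on the client side. As I understand vLLM's benchmark_serving.py at the time, --request-rate sets the rate of a Poisson arrival process: the gaps between requests are drawn from an exponential distribution with mean 1/request_rate, and the first request goes out immediately. The sketch below reproduces that pacing with the standard library. It is an illustration, not the script's own code, and arrival_times is a hypothetical name.

```python
import random


def arrival_times(num_prompts, request_rate, seed=0):
    """Cumulative send times (s) for a Poisson process with the given rate (req/s).

    The first request is sent at t = 0; each later gap is exponential with
    mean 1 / request_rate.
    """
    rng = random.Random(seed)
    t = 0.0
    times = []
    for _ in range(num_prompts):
        times.append(t)
        t += rng.expovariate(request_rate)
    return times


if __name__ == "__main__":
    times = arrival_times(num_prompts=1000, request_rate=128)
    # Expected to land near 1000 / 128, i.e. roughly 7.8 s.
    print(f"last request sent at ~{times[-1]:.1f} s")
```

With 1000 prompts at 128 req/s, every request is submitted within roughly 7.8 s, while the servers in the results below complete between 4.73 and 7.60 req/s. The run is therefore heavily oversubscribed, which likely explains why queueing dominates the mean TTFT of roughly 42 to 76 seconds.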

# LMDeploy TurboMind
============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  131.50
Total input tokens:                      245995
Total generated tokens:                  236921
Request throughput (req/s):              7.60
Input token throughput (tok/s):          1870.72
Output token throughput (tok/s):         1801.72
---------------Time to First Token----------------
Mean TTFT (ms):                          42688.77
Median TTFT (ms):                        39322.31
P99 TTFT (ms):                           100236.83
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          66.10
Median TPOT (ms):                        61.97
P99 TPOT (ms):                           222.96
---------------Inter-token Latency----------------
Mean ITL (ms):                           61.03
Median ITL (ms):                         48.83
P99 ITL (ms):                            205.38
==================================================

# LMDeploy PyTorch Engine
============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  191.57
Total input tokens:                      245995
Total generated tokens:                  234774
Request throughput (req/s):              5.22
Input token throughput (tok/s):          1284.11
Output token throughput (tok/s):         1225.53
---------------Time to First Token----------------
Mean TTFT (ms):                          76415.95
Median TTFT (ms):                        75958.52
P99 TTFT (ms):                           151205.85
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          88.28
Median TPOT (ms):                        79.38
P99 TPOT (ms):                           224.32
---------------Inter-token Latency----------------
Mean ITL (ms):                           73.75
Median ITL (ms):                         57.45
P99 ITL (ms):                            450.47
==================================================

# vLLM
============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  211.53
Total input tokens:                      245995
Total generated tokens:                  235482
Request throughput (req/s):              4.73
Input token throughput (tok/s):          1162.90
Output token throughput (tok/s):         1113.21
---------------Time to First Token----------------
Mean TTFT (ms):                          71490.33
Median TTFT (ms):                        76263.93
P99 TTFT (ms):                           165901.91
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          164.11
Median TPOT (ms):                        134.97
P99 TPOT (ms):                           718.64
---------------Inter-token Latency----------------
Mean ITL (ms):                           135.97
Median ITL (ms):                         99.74
P99 ITL (ms):                            540.58
==================================================
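To compare the three runs programmatically, the summary blocks can be parsed into dictionaries. Below is a minimal sketch. parse_result and relative_change are hypothetical helper names, and the regex assumes the "Key:   value" layout of the output above.

```python
import re

# Matches lines like "Request throughput (req/s):              5.22".
LINE_RE = re.compile(r"^(?P<key>[^:]+):\s+(?P<value>[\d.]+)\s*$")


def parse_result(text):
    """Parse 'Key: value' lines of a benchmark summary into a dict of floats.

    Separator and section-title lines contain no 'key: number' pair and are skipped.
    """
    metrics = {}
    for line in text.splitlines():
        m = LINE_RE.match(line.strip())
        if m:
            metrics[m.group("key").strip()] = float(m.group("value"))
    return metrics


def relative_change(baseline, candidate):
    """Percent change of candidate relative to baseline."""
    return (candidate - baseline) / baseline * 100.0


if __name__ == "__main__":
    pytorch_engine = parse_result("""\
Successful requests:                     1000
Benchmark duration (s):                  191.57
Request throughput (req/s):              5.22
""")
    vllm = parse_result("""\
Successful requests:                     1000
Benchmark duration (s):                  211.53
Request throughput (req/s):              4.73
""")
    key = "Request throughput (req/s)"
    # Sanity check: throughput should equal successful requests / duration.
    derived = pytorch_engine["Successful requests"] / pytorch_engine["Benchmark duration (s)"]
    assert abs(derived - pytorch_engine[key]) < 0.01
    print(f"{relative_change(vllm[key], pytorch_engine[key]):+.1f}%")  # prints +10.4%
```

Applied to the full blocks, the same helpers put the PyTorch Engine at about +10.4% request throughput over vLLM and about -31.3% relative to TurboMind.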

env

sys.platform: linux
Python: 3.9.16 (main, Aug 15 2023, 19:38:56) [GCC 8.3.1 20190311 (Red Hat 8.3.1-3)]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0,1: NVIDIA A100-SXM4-80GB
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 11.8, V11.8.89
GCC: gcc (GCC) 10.2.1 20210130 (Red Hat 10.2.1-11)
PyTorch: 2.3.0+cu118
PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.3.6 (Git Hash 86e6af5974177e513fd3fee58425e1063e7f1361)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.8
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_90,code=sm_90
  - CuDNN 8.9.2  (built against CUDA 12.1)
    - Built with CuDNN 8.7
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=8.7.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.3.0, USE_CUDA=ON, USE_CUDNN=ON, USE_CUSPARSELT=1, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_GLOO=ON, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF,

LMDeploy: 0.4.2+
transformers: 4.40.0
gradio: 3.50.2
fastapi: 0.110.3
pydantic: 2.6.0
triton: 2.3.0

zhyncs commented 3 days ago

Nit: we may also need to upgrade FlashAttention when moving to torch 2.3.0:

https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.9.post1

lvhan028 commented 3 days ago

@grimoire found that torch 2.3.0 + triton 2.3.0 is slower than torch 2.2.2 + triton 2.2.0. PR #1499 shows throughput drops by about 8% when torch is upgraded to 2.3.0. That's why I decided not to upgrade it.

zhyncs commented 3 days ago

torch2.3.0 + triton 2.3.0 degrade the performance speed comparing the torch2.2.2 + triton 2.2.0

Okay, I'll verify whether this issue still exists on the latest main branch.

grimoire commented 3 days ago

Triton 2.3.0 spends more time on each kernel launch (checking the device/stream, generating the cache key, etc.). Models that do more GPU computation per kernel may suffer less from it.
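Why heavier models might suffer less can be illustrated with a simplified model. This is my own simplification, not a measurement. It assumes the host enqueues kernel i+1 while the GPU runs kernel i, so each kernel costs whichever is slower, the host-side launch or the GPU compute. All microsecond figures are hypothetical, and real engines also pay for synchronizations and may amortize launches with techniques such as CUDA graphs.

```python
def step_time_us(num_kernels, launch_us, compute_us):
    """Approximate wall time of one step under an idealized asynchronous-launch model.

    In steady state each kernel costs max(launch, compute); the first launch,
    synchronizations, and any launch amortization are ignored.
    """
    return num_kernels * max(launch_us, compute_us)


def slowdown(num_kernels, old_launch_us, new_launch_us, compute_us):
    """Relative increase in step time when per-launch overhead grows."""
    old = step_time_us(num_kernels, old_launch_us, compute_us)
    new = step_time_us(num_kernels, new_launch_us, compute_us)
    return new / old - 1.0


if __name__ == "__main__":
    # Hypothetical numbers: per-launch overhead rises from 20 us to 30 us.
    for compute_us in (10, 25, 100):
        print(f"compute {compute_us:>3} us/kernel -> slowdown {slowdown(500, 20, 30, compute_us):.0%}")
```

Under this model, small kernels that are launch-bound absorb the full increase in launch overhead, while kernels whose compute time exceeds the new launch cost see no change. That is consistent with the observation that models doing more GPU computation are less affected.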