InternLM / lmdeploy

LMDeploy is a toolkit for compressing, deploying, and serving LLMs.
https://lmdeploy.readthedocs.io/en/latest/
Apache License 2.0

[Bug] Long GPU-to-CPU transfer time in the image encoding stage when running inference on InternVL models #2624

Open Dimensionzw opened 1 week ago

Dimensionzw commented 1 week ago


Describe the bug

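I am running inference on the internvl-26b model with lmdeploy on NVIDIA L20 GPUs, and the first-token latency exceeds 2 s. Profiling each stage shows that most of the latency comes from the GPU-to-CPU copy, located in lmdeploy/vl/engine.py: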

def forward(self, inputs: List[Image], params: List[Dict] = None):
    """Model forward."""
    params = self._init_input_params(inputs, params)
    time_start = time.perf_counter()
    start_time = time.time()
    func_params = inspect.signature(self.model.forward).parameters
    func_inputs = [inputs, params] if len(func_params) > 1 else [inputs]
    pre_time = round((time.time() - start_time) * 1000, 2)
    outputs = self.model.forward(*func_inputs)
    # Read before synchronizing: CUDA kernels are launched asynchronously,
    # so this may not include the time the GPU spends executing them.
    forward_time = round((time.time() - start_time) * 1000, 2)
    # With no argument, this waits only on the current CUDA device.
    torch.cuda.synchronize()
    other_time_start = time.time()
    # Copy the vision-encoder outputs from GPU to host memory.
    if isinstance(outputs[0], torch.Tensor):
        outputs = [x.cpu() for x in outputs]
    torch.cuda.synchronize()
    other_time = round((time.time() - other_time_start) * 1000, 2)
    time_end = time.perf_counter()
    logger.info(f'ImageEncoder forward {len(inputs)} images, '
                f'cost {time_end - time_start:.3f}s')
    logger.info('pre cost time {} ms, forward cost time {} ms'.format(pre_time, forward_time))
    logger.info('other cost {} ms'.format(other_time))
    return outputs

The code segment responsible for the large latency is:

if isinstance(outputs[0], torch.Tensor):
    outputs = [x.cpu() for x in outputs]

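After the server starts, this transfer alone takes more than 1 s, which severely increases the first-token latency. Is it possible to optimize this part to reduce the impact of the transfer?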
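Before optimizing the copy, it may be worth confirming where the time actually goes by synchronizing both before and after each phase, so that asynchronously launched GPU work is charged to the phase that launched it rather than to the next blocking call. The sketch below is a hypothetical helper (timed_phase is not part of lmdeploy); it uses only the standard library and takes the synchronization function as a parameter. On a GPU one would pass torch.cuda.synchronize. Note that with no argument this waits only on the current device, so if the vision model were spread across several GPUs (an assumption about this setup, not something the report confirms), each device would need to be synchronized, for example by calling torch.cuda.synchronize(i) for every i in range(torch.cuda.device_count()).

```python
import time
from typing import Callable, Dict, TypeVar

T = TypeVar("T")


def timed_phase(name: str, fn: Callable[[], T], timings: Dict[str, float],
                sync: Callable[[], None] = lambda: None) -> T:
    """Run fn and record its wall-clock time in milliseconds under `name`.

    Hypothetical helper. `sync` should block until all queued device work is
    done (e.g. torch.cuda.synchronize on a CUDA machine); the default no-op
    is only suitable for purely synchronous CPU work.
    """
    sync()  # drain earlier work so it is not charged to this phase
    start = time.perf_counter()
    result = fn()
    sync()  # wait for work launched by fn before stopping the clock
    timings[name] = round((time.perf_counter() - start) * 1000, 2)
    return result
```

Inside forward, this would wrap the two phases, e.g. outputs = timed_phase('forward', lambda: self.model.forward(*func_inputs), timings, torch.cuda.synchronize), followed by a second call for the [x.cpu() for x in outputs] list. If 'forward' then dominates, the copy itself is not the bottleneck.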

Reproduction

Launch command: lmdeploy serve api_server /multimodal/model-zoo/InternVL2-26B --backend turbomind --server-port 23333 --chat-template /multimodal/model-zoo/chat_template/chat_template.json --tp 4 --log-level DEBUG

Environment

sys.platform: linux
Python: 3.8.10 (default, Nov 22 2023, 10:22:35) [GCC 9.4.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0,1,2,3: NVIDIA GeForce RTX 4090
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 11.8, V11.8.89
GCC: x86_64-linux-gnu-gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
PyTorch: 2.1.0+cu118
PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 11.8
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_90,code=sm_90
  - CuDNN 8.7
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=8.7.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-invalid-partial-specialization -Wno-unused-private-field -Wno-aligned-allocation-unavailable -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.1.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 

TorchVision: 0.16.0+cu118
LMDeploy: 0.6.1+54b7230
transformers: 4.46.0.dev0
gradio: 3.50.2
fastapi: 0.111.0
pydantic: 2.7.1
triton: 2.1.0
NVIDIA Topology: 
        GPU0    GPU1    GPU2    GPU3    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      PIX     SYS     SYS     0-87    0               N/A
GPU1    PIX      X      SYS     SYS     0-87    0               N/A
GPU2    SYS     SYS      X      PIX     88-175  1               N/A
GPU3    SYS     SYS     PIX      X      88-175  1               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

Error traceback

No response

zhuchen1109 commented 1 week ago

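This time is most likely spent synchronously waiting for the GPU to finish and return its results. In other words, the actual GPU inference time includes this wait.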
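The effect described in this reply can be illustrated without a GPU. In the simulation below, which is illustrative only (FakeDeviceTensor, fake_forward and simulate are hypothetical stand-ins, not PyTorch or lmdeploy APIs), a background thread plays the role of the GPU: the forward call only queues work and returns immediately, and .cpu() blocks until the result is ready, much as copying a CUDA tensor to host memory waits for the kernels queued before it on the same stream. No synchronize call is made, so the encoder time shows up in the copy measurement.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import List, Tuple


class FakeDeviceTensor:
    """Hypothetical stand-in for a GPU tensor whose value is still being computed."""

    def __init__(self, future: Future):
        self._future = future

    def cpu(self) -> str:
        # Like Tensor.cpu() on a CUDA tensor: block until the queued work
        # is done, then hand back the host-side value.
        return self._future.result()


def _encode_image(compute_s: float) -> str:
    time.sleep(compute_s)  # stands in for the vision encoder's GPU kernels
    return "image embedding"


def fake_forward(executor: ThreadPoolExecutor, compute_s: float) -> List[FakeDeviceTensor]:
    # "Kernel launch": queue the work and return immediately.
    return [FakeDeviceTensor(executor.submit(_encode_image, compute_s))]


def simulate(compute_s: float = 0.2) -> Tuple[float, float, List[str]]:
    """Return (forward_ms, copy_ms, host_outputs), timed without any synchronize."""
    with ThreadPoolExecutor(max_workers=1) as executor:
        t0 = time.perf_counter()
        outputs = fake_forward(executor, compute_s)
        forward_ms = (time.perf_counter() - t0) * 1000
        t1 = time.perf_counter()
        host = [x.cpu() for x in outputs]
        copy_ms = (time.perf_counter() - t1) * 1000
    return forward_ms, copy_ms, host


if __name__ == "__main__":
    forward_ms, copy_ms, _ = simulate()
    print(f"forward {forward_ms:.1f} ms, copy {copy_ms:.1f} ms")
```

Measured this way, the "forward" phase looks almost free and nearly all of the simulated compute time is attributed to the copy, even though the copy itself does no real work.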

sjzhou4 commented 5 days ago

same as #2604