InternLM / lmdeploy

LMDeploy is a toolkit for compressing, deploying, and serving LLMs.
https://lmdeploy.readthedocs.io/en/latest/
Apache License 2.0

[Bug] flash_attention error when using the InternVL2-4B model #2216

Open wonderingtom opened 3 months ago

wonderingtom commented 3 months ago


Describe the bug

Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Phi3ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the with torch.autocast(device_type='torch_device'): decorator, or load the model with the torch_dtype argument. Example: model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)

Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Phi3Model is torch.float32. You should run training or inference using Automatic Mixed-Precision via the with torch.autocast(device_type='torch_device'): decorator, or load the model with the torch_dtype argument. Example: model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)

Inference completes normally, but it is unclear whether this warning affects inference speed.
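As the warning itself suggests, one way to avoid it is to load the model in half precision explicitly. A minimal sketch of that suggestion (not verified against this setup; the model id and trust_remote_code flag are assumptions for illustration):

# Sketch of the warning's own suggestion: load in bf16 so Flash Attention 2 is usable.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "OpenGVLab/InternVL2-4B",                # assumed HF model id
    torch_dtype=torch.bfloat16,              # FA2 requires fp16/bf16, not fp32
    attn_implementation="flash_attention_2",
    trust_remote_code=True,                  # assumed: InternVL2 ships custom code
)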

Reproduction

The main code only runs pipe((prompt, imgs), gen_config=gen_config).
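For context, a hypothetical sketch of the surrounding reproduction code, based on the standard lmdeploy VLM pipeline usage; the model id, prompt, and image URL are assumptions, and the real code may differ in those details:

# Hypothetical reproduction sketch around the single call above.
from lmdeploy import pipeline, GenerationConfig
from lmdeploy.vl import load_image

pipe = pipeline('OpenGVLab/InternVL2-4B')               # assumed model id
gen_config = GenerationConfig(max_new_tokens=512)       # assumed generation settings
imgs = [load_image('https://example.com/demo.jpg')]     # assumed image
prompt = 'Describe the image.'                          # assumed prompt
response = pipe((prompt, imgs), gen_config=gen_config)
print(response.text)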

Environment

sys.platform: linux
Python: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0]
CUDA available: True
MUSA available: False
numpy_random_seed: 2147483648
GPU 0,1,2,3,4,5,6,7,8,9: NVIDIA A100 80GB PCIe
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 11.6, V11.6.55
GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.3) 9.4.0
PyTorch: 2.2.2+cu121
PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.3.2 (Git Hash 2dc95a2ad0841e29db8b22fbccaf3e5da7992b01)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 12.1
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90
  - CuDNN 8.9.2
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wsuggest-override -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=2.2.2, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, USE_ROCM_KERNEL_ASSERT=OFF, 

TorchVision: 0.17.2+cu121
LMDeploy: 0.5.1+6245346
transformers: 4.43.0.dev0
gradio: Not Found
fastapi: 0.111.1
pydantic: 2.8.2
triton: 2.2.0
NVIDIA Topology: 
        GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    GPU8    GPU9    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      PXB     PXB     PXB     PXB     SYS     SYS     SYS     SYS     SYS     0-39,80-119     0               N/A
GPU1    PXB      X      PIX     PXB     PXB     SYS     SYS     SYS     SYS     SYS     0-39,80-119     0               N/A
GPU2    PXB     PIX      X      PXB     PXB     SYS     SYS     SYS     SYS     SYS     0-39,80-119     0               N/A
GPU3    PXB     PXB     PXB      X      PIX     SYS     SYS     SYS     SYS     SYS     0-39,80-119     0               N/A
GPU4    PXB     PXB     PXB     PIX      X      SYS     SYS     SYS     SYS     SYS     0-39,80-119     0               N/A
GPU5    SYS     SYS     SYS     SYS     SYS      X      PXB     PXB     PXB     PXB     40-79,120-159   1               N/A
GPU6    SYS     SYS     SYS     SYS     SYS     PXB      X      PIX     PXB     PXB     40-79,120-159   1               N/A
GPU7    SYS     SYS     SYS     SYS     SYS     PXB     PIX      X      PXB     PXB     40-79,120-159   1               N/A
GPU8    SYS     SYS     SYS     SYS     SYS     PXB     PXB     PXB      X      PIX     40-79,120-159   1               N/A
GPU9    SYS     SYS     SYS     SYS     SYS     PXB     PXB     PXB     PIX      X      40-79,120-159   1               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

Error traceback

No response

RunningLeon commented 3 months ago

@wonderingtom hi, I could not reproduce this. Did you modify the model? What is torch_dtype in config.json? Please confirm where the warning is reported from. In theory, phi3's flash-attn should already have been replaced (by lmdeploy's patched model), so flash-attn should not be invoked. Please debug and check the dtype of the patched model here: https://github.com/InternLM/lmdeploy/blob/030c501615ee5aae6be124dc794ca701eb025d2a/lmdeploy/pytorch/models/phi3.py#L207
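A minimal sketch of such a dtype check (a hypothetical helper, not part of lmdeploy) that can be dropped at a breakpoint around the line linked above:

# Hypothetical helper: summarize parameter dtypes of the patched model object.
from collections import Counter
import torch

def report_dtypes(model: torch.nn.Module) -> None:
    counts = Counter(str(p.dtype) for p in model.parameters())
    print(dict(counts))  # e.g. {'torch.bfloat16': N} if everything is bf16
    offenders = [n for n, p in model.named_parameters() if p.dtype == torch.float32]
    if offenders:
        print('float32 parameters:', offenders[:10])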

wonderingtom commented 3 months ago

@RunningLeon Hello, the model I use was copied directly from Hugging Face without any modification, and I have confirmed torch_dtype=torch.bfloat16. The dtype of the model you pointed to is also torch.bfloat16. I have not yet located where the warning is emitted.
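One way to locate where the warning is emitted, assuming it goes through Python's logging module (which is how transformers routes such warnings), is to attach a handler that prints the call stack when the message appears; a sketch:

# Hypothetical sketch: print the call stack whenever the flash-attn warning is logged.
import logging
import traceback

class StackOnFlashAttnWarning(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        if 'Flash Attention 2.0 only supports' in record.getMessage():
            traceback.print_stack()  # show where the warning was triggered
        return True  # keep the record

handler = logging.StreamHandler()
handler.addFilter(StackOnFlashAttnWarning())
logging.getLogger('transformers').addHandler(handler)
# then run the pipeline as before and inspect the printed stack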

RunningLeon commented 2 months ago

hi, you can debug it by bisection. We cannot reproduce your issue on our side.