InternLM / lmdeploy

LMDeploy is a toolkit for compressing, deploying, and serving LLMs.
https://lmdeploy.readthedocs.io/en/latest/
Apache License 2.0

[Bug] Error when running inference on the Huawei Ascend platform with an Ascend 910A card: Get regInfo failed, The binary_info_config.json of socVersion [ascend910] does not support opType [ApplyRotaryPosEmb]. #2467

Closed: XYZliang closed this issue 4 days ago

XYZliang commented 1 week ago

Checklist

Describe the bug

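Thanks for supporting domestic Chinese hardware platforms. When I saw that version 0.6.0 added Huawei Ascend support to the PyTorchEngine, I tested it right away. The model appears to load successfully, but inference fails with an error; see the log below for details.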

Reproduction

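lmdeploy serve api_server --backend pytorch --device ascend /home/ma-user/work/model/qwen/Qwen2-7B-Instruct

Here /home/ma-user/work/model/qwen/Qwen2-7B-Instruct is the downloaded Qwen2-7B-Instruct model, with torch_dtype in its config.json changed to float16 (the Ascend 910 series does not support bf16, and lmdeploy does not yet support setting the dtype explicitly). The command itself, lmdeploy serve api_server --backend pytorch --device ascend, follows the example in the documentation's guide to deploying on Ascend devices.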
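The config.json change described above can be scripted. This is a minimal sketch using only the standard library; set_torch_dtype is a hypothetical helper, and it assumes the checkpoint's config.json is a standard Hugging Face config with a top-level torch_dtype key. It overwrites the file in place, so keep a copy of the original.

```python
import json
from pathlib import Path
from typing import Optional


def set_torch_dtype(model_dir: str, dtype: str = "float16") -> Optional[str]:
    """Rewrite torch_dtype in <model_dir>/config.json and return the previous value.

    Hypothetical helper: the reporter notes that Ascend 910-series NPUs do not
    support bf16, so the checkpoint's dtype is switched to float16 before serving.
    """
    config_path = Path(model_dir) / "config.json"
    config = json.loads(config_path.read_text(encoding="utf-8"))
    previous = config.get("torch_dtype")  # None if the key was absent
    config["torch_dtype"] = dtype
    # Write the edited config back, leaving all other keys untouched.
    config_path.write_text(
        json.dumps(config, indent=2, ensure_ascii=False) + "\n", encoding="utf-8"
    )
    return previous
```

For the model in this report, set_torch_dtype("/home/ma-user/work/model/qwen/Qwen2-7B-Instruct") returns whatever dtype the download specified and leaves "float16" in the file.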

Environment

Output of lmdeploy check_env:

sys.platform: linux
Python: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:27:42) [GCC 10.4.0]
CUDA available: False
MUSA available: False
numpy_random_seed: 2147483648
GCC: gcc (GCC) 14.2.0
PyTorch: 2.1.0
PyTorch compiling details: PyTorch built with:
  - GCC 10.2
  - C++ Version: 201703
  - Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: NO AVX
  - Build settings: BLAS_INFO=open, BUILD_TYPE=Release, CXX_COMPILER=/opt/rh/devtoolset-10/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-invalid-partial-specialization -Wno-unused-private-field -Wno-aligned-allocation-unavailable -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=open, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.1.0, USE_CUDA=OFF, USE_CUDNN=OFF, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 

TorchVision: 0.16.0
LMDeploy: 0.6.0+e2aa4bd
transformers: 4.44.2
gradio: Not Found
fastapi: 0.114.0
pydantic: 2.9.1
triton: Not Found

Machine environment: Huawei ModelArts platform running Huawei Euler OS on a domestic ARM CPU with 8 Ascend 910A NPUs, using Huawei's official ModelArts image pytorch_2.1.0-cann_8.0.rc1-py_3.9-euler_2.10.7-aarch64.

The model is the 7B model of the Qwen2 series, which is on the supported-models list, downloaded from ModelScope.

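Installation followed the order given in the documentation: after creating a new Python 3.10 environment, I first installed torch and torch-npu (2.1.0) manually, then cloned the latest repository (2024-09-14, around noon) and installed from source.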
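Because torch and torch-npu were installed by hand, it can help to confirm that the two releases agree before digging further. This is a minimal sketch assuming, as with the 2.1.0 pair used here, that a torch_npu release is numbered after the torch release it targets; the helper names are hypothetical, and only importlib.metadata from the standard library is used.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def release_tuple(ver: str) -> Tuple[int, ...]:
    """Leading numeric release segment, e.g. '2.1.0.post6' -> (2, 1, 0)."""
    public = ver.split("+", 1)[0]  # drop local labels such as '+cpu'
    parts = []
    for piece in public.split("."):
        if not piece.isdigit():  # stop at 'post6', '0rc1', etc.
            break
        parts.append(int(piece))
    return tuple(parts[:3])


def installed_version(dist_name: str) -> Optional[str]:
    """Installed distribution version, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def torch_pair_consistent(torch_ver: Optional[str], npu_ver: Optional[str]) -> bool:
    """True when both packages are present and share a major.minor.patch release."""
    if torch_ver is None or npu_ver is None:
        return False
    return release_tuple(torch_ver) == release_tuple(npu_ver)


if __name__ == "__main__":
    t, n = installed_version("torch"), installed_version("torch_npu")
    print(f"torch={t} torch_npu={n} consistent={torch_pair_consistent(t, n)}")
```

A matching pair only rules out an installation mismatch; it says nothing about which operators the NPU itself supports.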

Error traceback

❯ lmdeploy serve api_server --backend pytorch --device ascend /home/ma-user/work/model/qwen/Qwen2-7B-Instruct 
/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch_npu/utils/path_manager.py:79: UserWarning: Warning: The /usr/local/Ascend/ascend-toolkit/latest owner does not match the current user.
  warnings.warn(f"Warning: The {path} owner does not match the current user.")
/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch_npu/utils/path_manager.py:79: UserWarning: Warning: The /usr/local/Ascend/ascend-toolkit/8.0.RC1/aarch64-linux/ascend_toolkit_install.info owner does not match the current user.
  warnings.warn(f"Warning: The {path} owner does not match the current user.")
/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch_npu/contrib/transfer_to_npu.py:211: ImportWarning: 
    *************************************************************************************************************
    The torch.Tensor.cuda and torch.nn.Module.cuda are replaced with torch.Tensor.npu and torch.nn.Module.npu now..
    The torch.cuda.DoubleTensor is replaced with torch.npu.FloatTensor cause the double type is not supported now..
    The backend in torch.distributed.init_process_group set to hccl now..
    The torch.cuda.* and torch.cuda.amp.* are replaced with torch.npu.* and torch.npu.amp.* now..
    The device parameters have been replaced with npu in the function below:
    torch.logspace, torch.randint, torch.hann_window, torch.rand, torch.full_like, torch.ones_like, torch.rand_like, torch.randperm, torch.arange, torch.frombuffer, torch.normal, torch._empty_per_channel_affine_quantized, torch.empty_strided, torch.empty_like, torch.scalar_tensor, torch.tril_indices, torch.bartlett_window, torch.ones, torch.sparse_coo_tensor, torch.randn, torch.kaiser_window, torch.tensor, torch.triu_indices, torch.as_tensor, torch.zeros, torch.randint_like, torch.full, torch.eye, torch._sparse_csr_tensor_unsafe, torch.empty, torch._sparse_coo_tensor_unsafe, torch.blackman_window, torch.zeros_like, torch.range, torch.sparse_csr_tensor, torch.randn_like, torch.from_file, torch._cudnn_init_dropout_state, torch._empty_affine_quantized, torch.linspace, torch.hamming_window, torch.empty_quantized, torch._pin_memory, torch.autocast, torch.load, torch.Generator, torch.Tensor.new_empty, torch.Tensor.new_empty_strided, torch.Tensor.new_full, torch.Tensor.new_ones, torch.Tensor.new_tensor, torch.Tensor.new_zeros, torch.Tensor.to, torch.nn.Module.to, torch.nn.Module.to_empty
    *************************************************************************************************************

  warnings.warn(msg, ImportWarning)
<frozen importlib._bootstrap>:914: ImportWarning: TEMetaPathFinder.find_spec() not found; falling back to find_module()
HINT:    Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
HINT:    Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
HINT:    Please open http://0.0.0.0:23333 in a browser for detailed api usage!!!
INFO:     Started server process [1720388]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:23333 (Press CTRL+C to quit)
INFO:     127.0.0.1:54202 - "POST /v1/chat/completions HTTP/1.1" 200 OK
/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch_npu/utils/storage.py:38: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if self.device.type != 'cpu':
2024-09-14 15:08:17,653 - lmdeploy - ERROR - Engine loop failed with error: call aclnnApplyRotaryPosEmb failed, detail:EZ1001: 2024-09-14-15:08:17.648.113 Get regInfo failed, The binary_info_config.json of socVersion [ascend910] does not support opType [ApplyRotaryPosEmb].
        TraceBack (most recent call last):
        Check nnopExecutor != nullptr failed

[ERROR] 2024-09-14-15:08:17 (PID:1720388, Device:0, RankID:-1) ERR01005 OPS internal error
Traceback (most recent call last):
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/request.py", line 17, in _raise_exception_on_finish
    task.result()
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 941, in async_loop
    await self._async_loop()
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 931, in _async_loop
    await __step(True)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 917, in __step
    raise e
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 909, in __step
    raise out
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 853, in _async_loop_background
    await self._async_step_background(
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 732, in _async_step_background
    output = await self._async_model_forward(
  File "/home/ma-user/work/lmdeploy/lmdeploy/utils.py", line 237, in __tmp
    return (await func(*args, **kwargs))
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 630, in _async_model_forward
    ret = await __forward(inputs)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 608, in __forward
    return await self.model_agent.async_forward(
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 332, in async_forward
    output = self._forward_impl(inputs,
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 299, in _forward_impl
    output = model_forward(
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 154, in model_forward
    output = model(**input_dict)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/backends/graph_runner.py", line 25, in __call__
    return self.model(**kwargs)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 340, in forward
    hidden_states = self.model(
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 278, in forward
    hidden_states, residual = decoder_layer(
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 194, in forward
    hidden_states = self.self_attn(
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 80, in forward
    query_states, key_states = self.apply_rotary_pos_emb(
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/nn/rotary_embedding.py", line 48, in forward
    return self.impl.forward(query, key, cos, sin, inplace)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/backends/ascend/apply_rotary_emb.py", line 26, in forward
    return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py", line 22, in apply_rotary_pos_emb
    ext_ops.apply_rotary_pos_emb(query_states_reshaped,
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/dlinfer/ops/llm.py", line 82, in apply_rotary_pos_emb
    return vendor_ops_registry["apply_rotary_pos_emb"](
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/dlinfer/vendor/ascend/torch_npu_ops.py", line 51, in apply_rotary_pos_emb
    return torch.ops.npu.npu_apply_rotary_pos_emb(query, key, cos, sin, "BSND")
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/_ops.py", line 692, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: call aclnnApplyRotaryPosEmb failed, detail:EZ1001: 2024-09-14-15:08:17.648.113 Get regInfo failed, The binary_info_config.json of socVersion [ascend910] does not support opType [ApplyRotaryPosEmb].
        TraceBack (most recent call last):
        Check nnopExecutor != nullptr failed

[ERROR] 2024-09-14-15:08:17 (PID:1720388, Device:0, RankID:-1) ERR01005 OPS internal error
ERROR:    Traceback (most recent call last):
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/queues.py", line 159, in get
    await getter
asyncio.exceptions.CancelledError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/tasks.py", line 456, in wait_for
    return fut.result()
asyncio.exceptions.CancelledError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/request.py", line 171, in __no_threadsafe_get
    return await asyncio.wait_for(self.resp_que.get(), timeout)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/tasks.py", line 458, in wait_for
    raise exceptions.TimeoutError() from exc
asyncio.exceptions.TimeoutError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/base_events.py", line 633, in run_until_complete
    self.run_forever()
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/base_events.py", line 600, in run_forever
    self._run_once()
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/base_events.py", line 1896, in _run_once
    handle._run()
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/responses.py", line 253, in wrap
    await func()
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/responses.py", line 242, in stream_response
    async for chunk in self.body_iterator:
  File "/home/ma-user/work/lmdeploy/lmdeploy/serve/openai/api_server.py", line 437, in completion_stream_generator
    async for res in result_generator:
  File "/home/ma-user/work/lmdeploy/lmdeploy/serve/async_engine.py", line 553, in generate
    async for outputs in generator.async_stream_infer(
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine_instance.py", line 175, in async_stream_infer
    resp = await self.req_sender.async_recv(req_id)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/request.py", line 312, in async_recv
    resp: Response = await self._async_resp_get()
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/request.py", line 187, in _async_resp_get
    return await __no_threadsafe_get()
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/request.py", line 175, in __no_threadsafe_get
    exit(1)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/_sitebuiltins.py", line 26, in __call__
    raise SystemExit(code)
SystemExit: 1

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/routing.py", line 700, in lifespan
    await receive()
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/uvicorn/lifespan/on.py", line 137, in receive
    return await self.receive_queue.get()
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/queues.py", line 159, in get
    await getter
asyncio.exceptions.CancelledError

ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/responses.py", line 257, in __call__
    await wrap(partial(self.listen_for_disconnect, receive))
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/responses.py", line 253, in wrap
    await func()
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/responses.py", line 230, in listen_for_disconnect
    message = await receive()
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/uvicorn/protocols/http/h11_impl.py", line 534, in receive
    await self.message_event.wait()
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/locks.py", line 214, in wait
    await fut
asyncio.exceptions.CancelledError

During handling of the above exception, another exception occurred:

  + Exception Group Traceback (most recent call last):
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/uvicorn/protocols/http/h11_impl.py", line 406, in run_asgi
  |     result = await app(  # type: ignore[func-returns-value]
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/uvicorn/middleware/proxy_headers.py", line 70, in __call__
  |     return await self.app(scope, receive, send)
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/fastapi/applications.py", line 1054, in __call__
  |     await super().__call__(scope, receive, send)
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/applications.py", line 113, in __call__
  |     await self.middleware_stack(scope, receive, send)
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/middleware/errors.py", line 165, in __call__
  |     await self.app(scope, receive, _send)
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/middleware/cors.py", line 85, in __call__
  |     await self.app(scope, receive, send)
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/middleware/exceptions.py", line 62, in __call__
  |     await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
  |     await app(scope, receive, sender)
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/routing.py", line 715, in __call__
  |     await self.middleware_stack(scope, receive, send)
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/routing.py", line 735, in app
  |     await route.handle(scope, receive, send)
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/routing.py", line 288, in handle
  |     await self.app(scope, receive, send)
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/routing.py", line 76, in app
  |     await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
  |     await app(scope, receive, sender)
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/routing.py", line 74, in app
  |     await response(scope, receive, send)
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/responses.py", line 250, in __call__
  |     async with anyio.create_task_group() as task_group:
  |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 680, in __aexit__
  |     raise BaseExceptionGroup(
  | exceptiongroup.BaseExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception)
  +-+---------------- 1 ----------------
    | Traceback (most recent call last):
    |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/queues.py", line 159, in get
    |     await getter
    | asyncio.exceptions.CancelledError
    | 
    | During handling of the above exception, another exception occurred:
    | 
    | Traceback (most recent call last):
    |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/tasks.py", line 456, in wait_for
    |     return fut.result()
    | asyncio.exceptions.CancelledError
    | 
    | The above exception was the direct cause of the following exception:
    | 
    | Traceback (most recent call last):
    |   File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/request.py", line 171, in __no_threadsafe_get
    |     return await asyncio.wait_for(self.resp_que.get(), timeout)
    |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/tasks.py", line 458, in wait_for
    |     raise exceptions.TimeoutError() from exc
    | asyncio.exceptions.TimeoutError
    | 
    | During handling of the above exception, another exception occurred:
    | 
    | Traceback (most recent call last):
    |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/runners.py", line 44, in run
    |     return loop.run_until_complete(main)
    |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/base_events.py", line 633, in run_until_complete
    |     self.run_forever()
    |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/base_events.py", line 600, in run_forever
    |     self._run_once()
    |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/base_events.py", line 1896, in _run_once
    |     handle._run()
    |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/asyncio/events.py", line 80, in _run
    |     self._context.run(self._callback, *self._args)
    |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/responses.py", line 253, in wrap
    |     await func()
    |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/starlette/responses.py", line 242, in stream_response
    |     async for chunk in self.body_iterator:
    |   File "/home/ma-user/work/lmdeploy/lmdeploy/serve/openai/api_server.py", line 437, in completion_stream_generator
    |     async for res in result_generator:
    |   File "/home/ma-user/work/lmdeploy/lmdeploy/serve/async_engine.py", line 553, in generate
    |     async for outputs in generator.async_stream_infer(
    |   File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine_instance.py", line 175, in async_stream_infer
    |     resp = await self.req_sender.async_recv(req_id)
    |   File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/request.py", line 312, in async_recv
    |     resp: Response = await self._async_resp_get()
    |   File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/request.py", line 187, in _async_resp_get
    |     return await __no_threadsafe_get()
    |   File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/request.py", line 175, in __no_threadsafe_get
    |     exit(1)
    |   File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/_sitebuiltins.py", line 26, in __call__
    |     raise SystemExit(code)
    | SystemExit: 1
    +------------------------------------
/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/tempfile.py:860: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmprizhx3x0'>
  _warnings.warn(warn_message, ResourceWarning)
Future exception was never retrieved
future: <Future finished exception=RuntimeError('call aclnnApplyRotaryPosEmb failed, detail:EZ1001: 2024-09-14-15:08:17.648.113 Get regInfo failed, The binary_info_config.json of socVersion [ascend910] does not support opType [ApplyRotaryPosEmb].\n        TraceBack (most recent call last):\n        Check nnopExecutor != nullptr failed\n\n[ERROR] 2024-09-14-15:08:17 (PID:1720388, Device:0, RankID:-1) ERR01005 OPS internal error')>
Traceback (most recent call last):
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/request.py", line 17, in _raise_exception_on_finish
    task.result()
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 941, in async_loop
    await self._async_loop()
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 931, in _async_loop
    await __step(True)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 917, in __step
    raise e
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 909, in __step
    raise out
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 853, in _async_loop_background
    await self._async_step_background(
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 732, in _async_step_background
    output = await self._async_model_forward(
  File "/home/ma-user/work/lmdeploy/lmdeploy/utils.py", line 237, in __tmp
    return (await func(*args, **kwargs))
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 630, in _async_model_forward
    ret = await __forward(inputs)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/engine.py", line 608, in __forward
    return await self.model_agent.async_forward(
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 332, in async_forward
    output = self._forward_impl(inputs,
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 299, in _forward_impl
    output = model_forward(
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/engine/model_agent.py", line 154, in model_forward
    output = model(**input_dict)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/backends/graph_runner.py", line 25, in __call__
    return self.model(**kwargs)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 340, in forward
    hidden_states = self.model(
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 278, in forward
    hidden_states, residual = decoder_layer(
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 194, in forward
    hidden_states = self.self_attn(
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/models/qwen2.py", line 80, in forward
    query_states, key_states = self.apply_rotary_pos_emb(
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/nn/rotary_embedding.py", line 48, in forward
    return self.impl.forward(query, key, cos, sin, inplace)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/backends/ascend/apply_rotary_emb.py", line 26, in forward
    return apply_rotary_pos_emb(query, key, cos, sin, q_embed, k_embed)
  File "/home/ma-user/work/lmdeploy/lmdeploy/pytorch/kernels/ascend/apply_rotary_pos_emb.py", line 22, in apply_rotary_pos_emb
    ext_ops.apply_rotary_pos_emb(query_states_reshaped,
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/dlinfer/ops/llm.py", line 82, in apply_rotary_pos_emb
    return vendor_ops_registry["apply_rotary_pos_emb"](
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/dlinfer/vendor/ascend/torch_npu_ops.py", line 51, in apply_rotary_pos_emb
    return torch.ops.npu.npu_apply_rotary_pos_emb(query, key, cos, sin, "BSND")
  File "/home/ma-user/anaconda3/envs/lmdeploy/lib/python3.10/site-packages/torch/_ops.py", line 692, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: call aclnnApplyRotaryPosEmb failed, detail:EZ1001: 2024-09-14-15:08:17.648.113 Get regInfo failed, The binary_info_config.json of socVersion [ascend910] does not support opType [ApplyRotaryPosEmb].
        TraceBack (most recent call last):
        Check nnopExecutor != nullptr failed

[ERROR] 2024-09-14-15:08:17 (PID:1720388, Device:0, RankID:-1) ERR01005 OPS internal error
sys:1: ResourceWarning: unclosed <socket.socket fd=93, family=AddressFamily.AF_INET, type=SocketKind.SOCK_STREAM, proto=6, laddr=('0.0.0.0', 23333)>
CyCle1024 commented 1 week ago

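At present, the lmdeploy PyTorch engine supports the Atlas 800T A2 on the Ascend platform (the Atlas 800I inference series has not been tested yet, but it should work); the 910A is not supported. According to Huawei's official CANN documentation, none of the fused operators for large models support the 910A. As the documentation shows, the ordinary operator aclnnAbs supports the 910A, but the large-model fused operators aclnnIncreFlashAttentionV4 and aclnnApplyRotaryPosEmb do not. See the linked documentation for details. [Screenshots: CANN operator documentation pages]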
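To tell whether a failure is this same operator-support gap rather than an installation problem, one can extract the socVersion and opType from the error text. This is a minimal sketch; the regular expression is derived only from the single CANN message in this issue's log, and parse_unsupported_op, explain, and UnsupportedOp are hypothetical names.

```python
import re
from typing import NamedTuple, Optional

# Pattern taken from the EZ1001 message in this log, e.g.
# "The binary_info_config.json of socVersion [ascend910] does not support opType [ApplyRotaryPosEmb]."
_UNSUPPORTED_OP = re.compile(
    r"socVersion \[(?P<soc>[^\]]+)\] does not support opType \[(?P<op>[^\]]+)\]"
)


class UnsupportedOp(NamedTuple):
    soc_version: str
    op_type: str


def parse_unsupported_op(message: str) -> Optional[UnsupportedOp]:
    """Return (socVersion, opType) if the message reports an unsupported operator."""
    match = _UNSUPPORTED_OP.search(message)
    if match is None:
        return None
    return UnsupportedOp(match.group("soc"), match.group("op"))


def explain(message: str) -> str:
    """One-line summary suitable for a log or bug report."""
    info = parse_unsupported_op(message)
    if info is None:
        return "no unsupported-operator message found"
    return f"operator {info.op_type} is not registered for socVersion {info.soc_version}"
```

Run against the RuntimeError in this issue, it yields soc_version "ascend910" and op_type "ApplyRotaryPosEmb", consistent with the reading that the 910A lacks this fused operator.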

XYZliang commented 4 days ago

Thanks for the explanation. I hope Huawei speeds up its adaptation work.

XYZliang commented 4 days ago

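According to Huawei staff I asked, the Atlas 800T A2 contains 910B chips, so those of you on 910B can give it a try. I hope this issue serves as a heads-up for anyone using Huawei Ascend.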

CyCle1024 commented 3 days ago

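> Thanks for the explanation. I hope Huawei speeds up its adaptation work.

As far as I know, Huawei is unlikely to add large-model operator support for the older 910A chips. With older hardware, large-model workloads really are hard to handle.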

XYZliang commented 2 days ago

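> Thanks for the explanation. I hope Huawei speeds up its adaptation work.
>
> As far as I know, Huawei is unlikely to add large-model operator support for the older 910A chips. With older hardware, large-model workloads really are hard to handle.

Yes, I consulted an expert about it yesterday as well. It's a tough spot; I'm planning to switch to 910B.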