LinJianping opened 1 week ago
It cannot be reproduced with the latest main branch.
I started getting this error with the PyTorch engine in the latest release for the Qwen2-VL model. I get the error with batch size >= 6; with batch size 1, everything runs fine.
I still cannot reproduce the error.
Since there is an sgemm cuBLAS error in the report, try replacing the original `freqs` computation with:

```python
freqs = (inv_freq_expanded.float()
         * position_ids_expanded.float()).transpose(1, 2)
```
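For intuition, the batched-matmul and broadcasted-elementwise formulations produce the same `freqs` tensor; the elementwise version simply sidesteps the batched sgemm kernel that cuBLAS was complaining about. A minimal NumPy sketch (the shapes are assumptions mirroring the typical HF rotary-embedding layout, not lmdeploy internals):

```python
import numpy as np

# Assumed shapes: inv_freq_expanded (batch, head_dim/2, 1),
# position_ids_expanded (batch, 1, seq_len).
batch, half_dim, seq_len = 2, 4, 6
rng = np.random.default_rng(0)
inv_freq_expanded = rng.random((batch, half_dim, 1)).astype(np.float32)
position_ids_expanded = rng.random((batch, 1, seq_len)).astype(np.float32)

# Batched matmul formulation (routed through an sgemm kernel on GPU):
freqs_matmul = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)

# Broadcasted elementwise formulation (no cuBLAS sgemm call):
freqs_mul = (inv_freq_expanded * position_ids_expanded).transpose(0, 2, 1)

assert np.allclose(freqs_matmul, freqs_mul)
```

Because one operand has a trailing dimension of 1 and the other a middle dimension of 1, the (b, h, 1) @ (b, 1, s) matmul degenerates into an outer product that broadcasting reproduces exactly.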
> It cannot be reproduced with the latest main branch.
When batch size is 1, everything runs fine. When batch_size is set to 8, the above error occasionally occurs during loop execution. I think it may be related to the CUDA version; mine is 12.2. With the default lmdeploy-0.6.1-cp39-cp39-manylinux2014_x86_64.whl, the error occurs very frequently, but with lmdeploy-0.6.1+cu118-cp39-cp39-manylinux2014_x86_64.whl it occurs less often.
I have asked an expert; the error might come from the vision model running on the default stream, which could corrupt the graph capture of the language model on the other stream. I will try to fix it ASAP.
Another question: when I deploy with the Triton Python backend and enable dynamic batching, is it also easy to trigger exceptions due to CUDA graph capture at different batch sizes?
We capture multiple graphs with different input sizes, and the input is padded up to the nearest capture size before the forward pass, so dynamic batching is safe.
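To make the padding behaviour concrete, here is a small sketch of rounding a batch up to the nearest captured size. The capture-size list and the pad id are assumptions for illustration, not lmdeploy internals:

```python
import bisect

# Hypothetical capture sizes, following the power-of-two pattern
# described later in this thread.
CAPTURE_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256]

def nearest_capture_size(num_tokens: int) -> int:
    """Smallest captured size that fits the input (input must not exceed 256)."""
    idx = bisect.bisect_left(CAPTURE_SIZES, num_tokens)
    return CAPTURE_SIZES[idx]

def pad_to_capture_size(token_ids: list) -> list:
    """Pad with a dummy id (0, an assumption) up to the capture size."""
    target = nearest_capture_size(len(token_ids))
    return token_ids + [0] * (target - len(token_ids))
```

Usage: a batch of 6 tokens is padded to 8 and replayed with the graph captured for size 8, so no new capture is triggered at runtime.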
What is the specific capture strategy? For example, if the default capture batch-size options are 1, 2, 4, 8, etc., I can set the corresponding preferred batch size to get the best inference performance.
Another curious question is why TurboMind supports the 2B-76B InternVL2 model but not the 1B model. Are there any plans to support it in the future? @grimoire
https://github.com/grimoire/lmdeploy/tree/fix-vl-graphcapture I have set the capture mode to thread_local, which might fix the bug.
> What is the specific capture strategy like?
The engine generates graphs for token counts [1, 2, 4, ..., 256]. You don't need to worry about this: the PyTorch engine schedules requests onto the best-fitting capture size.
> why TurboMind supports the 2B-76B InternVL2 model but not the 1B model
InternVL2-1B uses Qwen2-0.5B as its language model, which has head_size=64. TurboMind does not support head_size < 128.
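The head size follows directly from the model config as `hidden_size / num_attention_heads`. A quick check using the published Qwen2-0.5B values (hidden_size=896, 14 attention heads) shows why it falls below TurboMind's supported size:

```python
def head_size(hidden_size: int, num_heads: int) -> int:
    """head_size = hidden_size / num_attention_heads (must divide evenly)."""
    assert hidden_size % num_heads == 0
    return hidden_size // num_heads

# Qwen2-0.5B config: hidden_size=896, num_attention_heads=14
assert head_size(896, 14) == 64    # < 128 -> not supported by TurboMind
# A typical 7B-class config: hidden_size=4096, num_attention_heads=32
assert head_size(4096, 32) == 128  # supported
```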
> https://github.com/grimoire/lmdeploy/tree/fix-vl-graphcapture I have set the capture mode to thread_local, which might fix the bug.
It seems to be branched not from the latest version but from 0.4.2. When installing, it downgraded PyTorch, and I get this error for Qwen2-VL: "Unrecognized configuration class <class 'transformers.models.qwen2_vl.configuration_qwen2_vl.Qwen2VLConfig'> for this kind of AutoModel".
Are you using the main branch of my repo? I have created a draft PR https://github.com/InternLM/lmdeploy/pull/2560; please try it.
Sorry, forgot to switch branches! Yes, the issue doesn't occur when using the correct branch.
This method works well for me. Also, I'm curious to know if there's any plan to bring TurboMind support to smaller models like InternVL2-1B, or is the workload too heavy to make it happen in the near future?
> I'm curious to know if there's any plan to bring TurboMind support to smaller models like InternVL2-1B
@lvhan028 @lzhangzz
Checklist
Describe the bug
When loading InternVL2-1B with lmdeploy v0.6.0 and running inference in a loop, it raises "RuntimeError: CUDA error: operation not permitted when stream is capturing". I suspect this is related to the CUDA graph support introduced in v0.6.0.
Reproduction
```python
import os
import time

import torch
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image

device = "cuda"
pwd = os.path.abspath(os.path.dirname(__file__))
model_path = os.path.join(pwd, 'InternVL2-1B')
pipe = pipeline(model_path,
                backend_config=TurbomindEngineConfig(cache_max_entry_count=0.6))

BATCH_SIZE = 8
querys = ['图片中有海吗'] * BATCH_SIZE  # "Is there a sea in the picture?"
image_paths = [os.path.join(pwd, "warmup/flag.jpg")] * BATCH_SIZE

# Warm up with a single request, then run one batched request.
image = load_image(image_paths[1])
response = pipe((querys[1], image))
prompts = [(query, load_image(img_url)) for img_url, query in zip(image_paths, querys)]
response = pipe(prompts)
print(response)

_REPEAT = 100
tic = time.time()
torch.cuda.synchronize()
for _ in range(_REPEAT):
    response = pipe(prompts)
torch.cuda.synchronize()
toc = time.time()
print(response)
print(f'seconds per image: {(toc - tic) / BATCH_SIZE / _REPEAT}')
```
Environment
Error traceback