david-sitsky opened this issue 3 months ago
To save time, the whisper_small.onnx
file can be downloaded from here: https://drive.google.com/file/d/10yDz-VI-iKsszNgyOjbRqYTfwUHap_a-/view?usp=sharing.
It is a known issue that some operators are not thread-safe; for example, the Attention and MultiHeadAttention operators used in the Whisper encoder are not thread-safe. You may try setting the environment variable ORT_DISABLE_FUSED_ATTENTION=1 (see the sketch below). However, that will increase latency since some fused attention kernels are disabled.
Another workaround is to use a separate session per thread, which avoids the thread-safety issue but probably won't help performance, since multiple sessions compete for the same GPU resources.
What is the reason for using multiple threads against the same session? It usually has no performance benefit. Instead, you can try increasing the batch size to see whether that improves throughput.
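A minimal sketch of that workaround in Python, assuming the whisper_small.onnx model linked above; the variable has to be set before the session is created so the CUDA provider picks the unfused kernels:

```python
import os

# Workaround: disable fused attention kernels before the session is created,
# so the CUDA EP falls back to the unfused (non-fused-attention) kernels.
os.environ["ORT_DISABLE_FUSED_ATTENTION"] = "1"

import onnxruntime as ort

# One shared session; the alternative workaround is one InferenceSession per thread.
session = ort.InferenceSession(
    "whisper_small.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
```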
Sadly, setting those environment variables did not help. I still see similar errors, but also new ones that look even worse:
2024-07-19 05:32:03.173536756 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running WhisperBeamSearch node. Name:'BeamSearch_node' Status Message: Non-zero status code returned while running MatMul node. Name:'/whisper_decoder_init/proj_out/MatMul' Status Message: CUBLAS failure 14: CUBLAS_STATUS_INTERNAL_ERROR ; GPU=0 ; hostname=ip-172-31-31-23 ; file=/onnxruntime_src/onnxruntime/core/providers/cuda/math/matmul.cc ; line=312 ; expr=cublasGemmHelper( GetCublasHandle(ctx), transB, transA, static_cast<int>(helper.N()), static_cast<int>(helper.M()), static_cast<int>(helper.K()), &alpha, reinterpret_cast<const CudaT*>(right_X->Data<T>()), ldb, reinterpret_cast<const CudaT*>(left_X->Data<T>()), lda, &zero, reinterpret_cast<CudaT*>(Y->MutableData<T>()), ldc, device_prop, UseTF32());
Aborted (core dumped)
I definitely want to use the same session so the model is only loaded into GPU memory once. It is true that batching could be used here instead, and I'll look into that.
I was initially looking at threading because the Whisper pre-processing, which reads the audio data and converts it into the appropriate format (padding, then log-mel spectrogram conversion), is CPU-only work that can be parallelised. With a batching approach this pre-processing would be serialised.
The documents I read indicated that an ORT session is thread-safe. Since that does not appear to be the case, is there a documented list of the operators that are not?
@david-sitsky, you may try using multiple threads for the CPU pre-processing to see whether that helps (see the sketch below).
The new error indicates that other places in the CUDA provider also have thread-safety issues. It may take time to nail down the root cause.
A related older issue: https://github.com/microsoft/onnxruntime/issues/18806
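A minimal sketch of that split, using transformers' WhisperFeatureExtractor for the padding/log-mel step; the checkpoint name and the "input_features" input name are assumptions and will differ if the export bundles pre-processing and beam search into the graph:

```python
import threading
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from transformers import WhisperFeatureExtractor  # one way to do padding + log-mel on CPU

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
run_lock = threading.Lock()  # serialise GPU inference on the single shared session

def preprocess(chunk: np.ndarray) -> np.ndarray:
    # CPU-only work: pad to 30 s and compute the log-mel spectrogram for one chunk.
    return feature_extractor(chunk, sampling_rate=16000, return_tensors="np").input_features

def transcribe(session, chunks):
    # Fan the CPU pre-processing out across threads...
    with ThreadPoolExecutor(max_workers=8) as pool:
        features = list(pool.map(preprocess, chunks))
    # ...then batch the results and make a single, serialised session.run call.
    batch = np.concatenate(features, axis=0)
    with run_lock:
        # Input name is an assumption; check session.get_inputs() for the real ones.
        return session.run(None, {"input_features": batch})
```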
@tianleiwu - any ideas on next steps for the thread-safety issue? I am using djl-serving, so the server sometimes receives multiple unrelated requests from different clients against the same Whisper model, and the issue is hit fairly easily. I am using ONNX Runtime 1.17.3.
2024-08-06 04:54:13.828882272 [E:onnxruntime:ort-java, cuda_call.cc:116 CudaCall] CUBLAS failure 14: CUBLAS_STATUS_INTERNAL_ERROR ; GPU=0 ; hostname=ip-172-31-31-23 ; file=/onnxruntime_src/onnxruntime/core/providers/cuda/math/matmul.cc ; line=324 ; expr=cublasGemmHelper( GetCublasHandle(ctx), transB, transA, static_cast<int>(helper.N()), static_cast<int>(helper.M()), static_cast<int>(helper.K()), &alpha, reinterpret_cast<const CudaT*>(right_X->Data<T>()), ldb, reinterpret_cast<const CudaT*>(left_X->Data<T>()), lda, &zero, reinterpret_cast<CudaT*>(Y->MutableData<T>()), ldc, device_prop);
2024-08-06 04:54:13.828924164 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running MatMul node. Name:'/whisper_decoder_init/proj_out/MatMul' Status Message: CUBLAS failure 14: CUBLAS_STATUS_INTERNAL_ERROR ; GPU=0 ; hostname=ip-172-31-31-23 ; file=/onnxruntime_src/onnxruntime/core/providers/cuda/math/matmul.cc ; line=324 ; expr=cublasGemmHelper( GetCublasHandle(ctx), transB, transA, static_cast<int>(helper.N()), static_cast<int>(helper.M()), static_cast<int>(helper.K()), &alpha, reinterpret_cast<const CudaT*>(right_X->Data<T>()), ldb, reinterpret_cast<const CudaT*>(left_X->Data<T>()), lda, &zero, reinterpret_cast<CudaT*>(Y->MutableData<T>()), ldc, device_prop);
2024-08-06 04:54:13.829005579 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running WhisperBeamSearch node. Name:'BeamSearch_node' Status Message: Non-zero status code returned while running MatMul node. Name:'/whisper_decoder_init/proj_out/MatMul' Status Message: CUBLAS failure 14: CUBLAS_STATUS_INTERNAL_ERROR ; GPU=0 ; hostname=ip-172-31-31-23 ; file=/onnxruntime_src/onnxruntime/core/providers/cuda/math/matmul.cc ; line=324 ; expr=cublasGemmHelper( GetCublasHandle(ctx), transB, transA, static_cast<int>(helper.N()), static_cast<int>(helper.M()), static_cast<int>(helper.K()), &alpha, reinterpret_cast<const CudaT*>(right_X->Data<T>()), ldb, reinterpret_cast<const CudaT*>(left_X->Data<T>()), lda, &zero, reinterpret_cast<CudaT*>(Y->MutableData<T>()), ldc, device_prop);
terminate called after throwing an instance of 'onnxruntime::OnnxRuntimeException'
@david-sitsky, the BeamSearch operator is not thread-safe since it has internal state.
One option for serving is to put users' requests on a queue and batch them into a single onnxruntime inference call, which avoids multi-threaded Run calls (see the sketch below). I am not sure whether djl-serving supports this.
I could take a look at a design change to make it thread-safe. That might be targeted for the 1.20 release (3+ months away).
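A minimal sketch of that queue-and-batch pattern; MAX_BATCH, the Future-based hand-off, and the "input_features" input name are illustrative assumptions rather than djl-serving or ORT API:

```python
import queue
import threading
from concurrent.futures import Future

import numpy as np

MAX_BATCH = 8                 # illustrative batching limit
request_queue = queue.Queue()

def submit(features: np.ndarray) -> Future:
    """Called from any request thread; returns a Future that will hold the result."""
    fut = Future()
    request_queue.put((features, fut))
    return fut

def batch_worker(session):
    """Single thread that owns all session.run calls, so ORT never runs concurrently."""
    while True:
        items = [request_queue.get()]                 # block for the first request
        while len(items) < MAX_BATCH:                 # then drain whatever else is queued
            try:
                items.append(request_queue.get_nowait())
            except queue.Empty:
                break
        batch = np.concatenate([features for features, _ in items], axis=0)
        # Input name is an assumption; check session.get_inputs() for the real ones.
        outputs = session.run(None, {"input_features": batch})[0]
        for (_, fut), out in zip(items, outputs):
            fut.set_result(out)

# threading.Thread(target=batch_worker, args=(session,), daemon=True).start()
```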
Many thanks @tianleiwu - that would be great!
Describe the issue
I created a Whisper ONNX model using https://github.com/microsoft/Olive/blob/main/examples/whisper/README.md, specifically by following the commands in that README on a machine with a GPU.
In my application, I break large audio files into chunks and then transcribe each chunk against the Whisper model using threads.
On a CPU machine (with a model generated on a non-GPU machine) this works fine. However, on a GPU machine it fails.
To reproduce
On a g5.4xlarge instance in AWS, using Ubuntu 22.04, CUDA 11.8, Nvidia A10G GPU, the following program reproduces the error:
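A minimal sketch of such a multi-threaded repro, assuming the Olive all-in-one export (audio decoding plus WhisperBeamSearch in the graph); the input names follow the Olive whisper example and should be checked against session.get_inputs():

```python
import threading

import numpy as np
import onnxruntime as ort

MODEL_PATH = "whisper_small.onnx"   # the Olive-exported model from this issue
AUDIO_PATH = "speech.wav"           # https://resources.djl.ai/audios/speech.wav
THREAD_NUMBER = 4

session = ort.InferenceSession(
    MODEL_PATH, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)

with open(AUDIO_PATH, "rb") as f:
    audio_bytes = np.frombuffer(f.read(), dtype=np.uint8)

# Input names below are assumptions based on the Olive whisper example
# (all-in-one model with audio decoding + WhisperBeamSearch).
inputs = {
    "audio_stream": np.expand_dims(audio_bytes, axis=0),
    "max_length": np.array([200], dtype=np.int32),
    "min_length": np.array([0], dtype=np.int32),
    "num_beams": np.array([2], dtype=np.int32),
    "num_return_sequences": np.array([1], dtype=np.int32),
    "length_penalty": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
}

def worker(idx: int) -> None:
    # All threads share one session; with THREAD_NUMBER > 1 this triggers the crash.
    result = session.run(None, inputs)
    print(f"thread {idx}: {result[0]}")

threads = [threading.Thread(target=worker, args=(i,)) for i in range(THREAD_NUMBER)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```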
An example error run fails with the CUBLAS_STATUS_INTERNAL_ERROR from the WhisperBeamSearch / MatMul nodes shown in the logs above. The program works fine when THREAD_NUMBER=1. The speech.wav file can be downloaded from https://resources.djl.ai/audios/speech.wav.
If it helps, I can try to put the model I generated somewhere, but it is 1.1 GB in size.
Urgency
This is a blocker for deploying our application, so it is urgent. We are actually using the Java bindings of onnxruntime via https://github.com/deepjavalibrary/djl-serving, which hits this very issue when it receives concurrent requests. I wrote the Python program to make the issue easier to reproduce, but it produces exactly the same error message.
Platform
Linux
OS Version
Ubuntu 22.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
onnxruntime-gpu-1.18.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
CUDA
Execution Provider Library Version
CUDA 11.8