Open · tianleiwu opened this issue 1 year ago
Hi @feihugis - I recall you saying that the model your team flighted also used CUDA Graph. Did you run into issues like the above while trying to capture the graph? AFAIK, the CUDA stream synchronize has always existed in the code. I wonder why we didn't see something like this while testing your model.
@tianleiwu - Could it be that the "large" unet model uses a kernel that internally calls cudaStreamSynchronize()? This may be one of the cases where we unfortunately can't use CUDA Graphs.
For the "small" model, it may be that the op/kernel that uses stream synchronize doesn't kick in. If you look at the CUDA EP setup that captures the graph, we first finish capturing the graph in OnRunEnd() here - https://github.com/microsoft/onnxruntime/blob/a8ad0edbeb45a1733d5b062acc13c6b3ad08731b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc#L387 - and only then do the stream sync here - https://github.com/microsoft/onnxruntime/blob/a8ad0edbeb45a1733d5b062acc13c6b3ad08731b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc#L397 - before returning control back to the caller.
Unfortunately, if any kernel launched between graph capture begin and graph capture end contains synchronization logic, the graph cannot be captured.
> Hi @feihugis - I recall you saying that the model your team flighted also used CUDA Graph. Did you run into issues like the above while trying to capture the graph? AFAIK, the CUDA stream synchronize has always existed in the code. I wonder why we didn't see something like this while testing your model.
Hi @hariharans29 and @tianleiwu, sorry for the late response. I did not see this message and only stumbled on it just now while searching my email for something else.
Yes, the model we deployed around one year ago did not run into any issues when capturing the CUDA graph.
Recently, when I tried GPT2 + beam search, I ran into similar issues. After making some code changes (https://github.com/feihugis/onnxruntime/commit/de67b88bb775e7700f9a685511f0fab391c24cd6), CUDA graph capturing works, but as some of the ops are not on the GPU, the outputs are not correct.
Please feel free to ping me on Teams if I miss your comments.
I still see this error when running multiple models in parallel. You can reproduce the error by running:
```
./onnx_test_runner -e cuda /data/onnx
```
The folder /data/onnx holds test models and their input/output data from https://github.com/onnx/onnx
```
2024-07-23 16:30:08.420038342 [E:onnxruntime:Default, dataitem_request.cc:32 operator()] argmin_default_axis_random:Non-zero status code returned while running ArgMin node. Name:'' Status Message: CUDA error cudaErrorStreamCaptureUnsupported:operation not permitted when stream is capturing
2024-07-23 16:30:08.889316320 [E:onnxruntime:clip, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running Clip node. Name:'' Status Message: CUDA error cudaErrorStreamCaptureUnsupported:operation not permitted when stream is capturing
```
@snnn, this issue is about a CUDA graph error in a single thread. The error you reported is a separate multi-threading issue.
A stream-capture error should not appear when CUDA graph is not enabled. If you see that error in onnx_test_runner, it most likely means some ORT code is not thread-safe, causing a buffer overrun that corrupts the call stack.
Describe the issue
During CUDA graph capture, ORT triggers cudaStreamSynchronize, which is not allowed during capture. The call stack looks like the following:
The error looks like the following (I added the file and line):
To reproduce
The error is not always triggered with a small model, but with a larger model like unet it reproduces consistently.
Urgency
No response
Platform
Linux
OS Version
Ubuntu 20.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.14.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
CUDA
Execution Provider Library Version
No response