
FusedConv Cuda EP invalid argument error. #12321

Open · kiennguyen94 opened this issue 2 years ago

kiennguyen94 commented 2 years ago

Describe the bug
When running a model containing conv layers with graph optimizations enabled, ORT throws the following error:

2022-07-26 09:49:04.435452395 [E:onnxruntime:Default, cuda_call.cc:118 CudaCall] CUDNN failure 3: CUDNN_STATUS_BAD_PARAM ; GPU=1 ; hostname=tiger09.som.ma ; expr=cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data, &alpha, Base::s_.y_tensor, Base::s_.y_data);
2022-07-26 09:49:04.435491739 [E:onnxruntime:, sequential_executor.cc:368 Execute] Non-zero status code returned while running FusedConv node. Name:'Conv_125' Status Message: CUDNN error executing cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data, &alpha, Base::s_.y_tensor, Base::s_.y_data)
Traceback (most recent call last):
  File "final_repo.py", line 57, in <module>
    output = model.run(None, {"audio_signal": au_sig, "length": length}, run_options=run_opt)
  File "/n/w1-knguyen/conda3/install/envs/py3_cuda/lib/python3.7/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 200, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running FusedConv node. Name:'Conv_125' Status Message: CUDNN error executing cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data, &alpha, Base::s_.y_tensor, Base::s_.y_data)
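
For reference, here is roughly how the failing call is made, reconstructed from the traceback above (the model path, input shapes, and dtypes are assumptions, not taken from the original script):

import numpy as np
import onnxruntime as ort

# Reconstructed from the traceback; "citrinet.onnx" and the input shapes are assumptions
model = ort.InferenceSession("citrinet.onnx", providers=["CUDAExecutionProvider"])
run_opt = ort.RunOptions()
au_sig = np.random.randn(1, 80, 4000).astype(np.float32)  # [batch, features, time], assumed
length = np.array([4000], dtype=np.int64)                  # assumed dtype/shape
output = model.run(None, {"audio_signal": au_sig, "length": length}, run_options=run_opt)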

Urgency: None

RandySheriffH commented 2 years ago

@kiennguyen94: thanks for reporting this - we are working on a similar issue: https://github.com/microsoft/onnxruntime/issues/11548. And yes, the shape mismatch between the Z and Y tensors is the culprit.
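
To make the Z/Y mismatch concrete, here is a minimal sketch of the problematic pattern: a Conv whose output Y is consumed by an Add whose other operand Z only matches Y by broadcasting. The optimizer can fold the Add into a FusedConv Z input, but cudnnAddTensor then rejects the mismatched tensor descriptors. Whether the fusion fires on this exact graph depends on the ORT version; the shapes and opset below are assumptions.

import numpy as np
import onnx
from onnx import TensorProto, helper

# Conv output Y has shape [1, 8, 16]; Z is [1, 1, 16] and only matches Y by broadcasting
conv = helper.make_node("Conv", ["X", "W"], ["Y"], kernel_shape=[3], pads=[1, 1])
add = helper.make_node("Add", ["Y", "Z"], ["out"])
graph = helper.make_graph(
    [conv, add], "conv_add_mismatch",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 4, 16]),
     helper.make_tensor_value_info("Z", TensorProto.FLOAT, [1, 1, 16])],
    [helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 8, 16])],
    initializer=[helper.make_tensor(
        "W", TensorProto.FLOAT, [8, 4, 3], np.zeros(8 * 4 * 3, np.float32).tolist())],
)
onnx.save(helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]),
          "conv_add_mismatch.onnx")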

RandySheriffH commented 2 years ago

@kiennguyen94: BTW, for your case, would you mind sharing a model and a sample script with us so we can double-confirm the fix works?

kiennguyen94 commented 2 years ago

@RandySheriffH Sorry for the late reply, here's the repro: https://github.com/kiennguyen94/ort_load_repro. Just extract the model tarball with tar xvzf ./citrinet.tgz, then run python ./final_repo.py.

I have the fix in this temp draft PR https://github.com/microsoft/onnxruntime/pull/12366. This PR fixes this particular repro, but I'm not sure if it introduces unwanted behavior.

RandySheriffH commented 2 years ago

@kiennguyen94, thanks! For this issue your fix (https://github.com/microsoft/onnxruntime/pull/12366) should work; we plan to bring it in along with other fixes on https://github.com/microsoft/onnxruntime/tree/FuseConvShapeMismatch.

xfbs commented 7 months ago

For anyone else who runs into this issue: we were able to circumvent it by dialing the graph optimization level back to basic.

import onnxruntime as ort  # added import to make the snippet self-contained

sess_opt = ort.SessionOptions()
# Basic level skips the extended fusion that produces the failing FusedConv node
sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
session = ort.InferenceSession(onnx_path, providers=providers, sess_options=sess_opt)

The theory being that some of the extended optimizations (here, the FusedConv fusion) rewrite the graph into a form that cannot be executed by the CUDA execution provider, while the ONNX file itself is fine.
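
To check what the optimizer actually did, ORT can save the post-optimization graph it runs; a short sketch (the model path and provider list below are placeholders):

import onnxruntime as ort

sess_opt = ort.SessionOptions()
sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# Write out the graph as ORT executes it; open it (e.g. in Netron) and look for FusedConv nodes
sess_opt.optimized_model_filepath = "optimized_dump.onnx"
ort.InferenceSession("model.onnx", sess_options=sess_opt, providers=["CUDAExecutionProvider"])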