kiennguyen94 opened this issue 2 years ago
@kiennguyen94: thanks for reporting this; we are working on a similar issue, https://github.com/microsoft/onnxruntime/issues/11548. And yes, the mismatch between Z and Y is the culprit.
@kiennguyen94: BTW, for your case, would you mind sharing a model and a sample script so we can confirm that the fix works?
@RandySheriffH Sorry for the late reply, here's the repro https://github.com/kiennguyen94/ort_load_repro.
Just extract the model tarball and run the script:

```
tar xvzf ./citrinet.tgz
python ./final_repo.py
```
I have the fix in this temp draft PR https://github.com/microsoft/onnxruntime/pull/12366. This PR fixes this particular repro, but I'm not sure if it introduces unwanted behavior.
@kiennguyen94, thanks! For this issue your fix https://github.com/microsoft/onnxruntime/pull/12366 should work; we plan to bring it in along with other fixes into https://github.com/microsoft/onnxruntime/tree/FuseConvShapeMismatch.
For anyone else who runs into this issue: we were able to circumvent it by lowering the graph optimization level.
```python
import onnxruntime as ort

def make_session(onnx_path, providers):
    sess_opt = ort.SessionOptions()
    sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    return ort.InferenceSession(onnx_path, providers=providers, sess_options=sess_opt)
```
The theory is that some of the optimizations produce a graph that cannot be executed with CUDA, while the ONNX file itself is fine.
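One way to sanity-check this workaround is to run the same input through a fully optimized session and a basic-optimization session and compare outputs. This is a sketch, not code from the thread: `compare_sessions`, `onnx_path`, `feed`, and `providers` are illustrative placeholders for your own model and inputs.

```python
import numpy as np

def max_abs_diff(a, b):
    """Largest elementwise difference between two output arrays."""
    return float(np.max(np.abs(np.asarray(a) - np.asarray(b))))

def compare_sessions(onnx_path, feed, providers):
    # Hypothetical helper: requires onnxruntime and a real model file.
    import onnxruntime as ort
    outputs = []
    for level in (ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
                  ort.GraphOptimizationLevel.ORT_ENABLE_BASIC):
        opt = ort.SessionOptions()
        opt.graph_optimization_level = level
        sess = ort.InferenceSession(onnx_path, providers=providers,
                                    sess_options=opt)
        outputs.append(sess.run(None, feed)[0])
    # A small difference suggests the workaround preserves results;
    # a crash with ORT_ENABLE_ALL reproduces the bug in this issue.
    return max_abs_diff(outputs[0], outputs[1])
```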
**Describe the bug**
When running models with conv layers with optimizations enabled, ORT throws a `CUDNN_STATUS_BAD_PARAM` error (details under "Additional context").

**Urgency**
None

**System information**
ONNX Runtime at commit ddb45e9; also observed with 1.12.

**To Reproduce**

```
tar xvzf ./citrinet.tgz && python ./final_repo.py
```

**Expected behavior**

**Screenshots**
None
**Additional context**
The error is `CUDNN_STATUS_BAD_PARAM`, which (per the cuDNN docs) means either the dimensions of `Z` and `Y` don't match, or the datatypes are incompatible. `Z` is the output of some previous op and can be missing a dimension, e.g. `Z` is `[1, 1024, 16]` whereas `Y` is `[1, 1024, 16, 1]`.

Should the fix go in `cuda/FusedConv`? If so, I think it is simply a matter of adding a dimension check (if `len(Z.shape) == len(Y.shape) - 1`, extend `Z.shape` by 1) and then calling `ORT_RETURN_IF_ERROR(s_.z_tensor.Set(new_z_dim, CudnnTensor::GetDataType<CudaT>()));` right around here: https://github.com/microsoft/onnxruntime/blob/de57daaab055766f6a5fa0c4ef34318fe611174b/onnxruntime/core/providers/cuda/nn/conv.cc#L231. An alternative would be to insert a `Reshape` on `Z` before `FusedConv`, but given that `cpu/FusedConv` works fine, I don't think we want this.
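For illustration, the dimension check proposed above can be sketched in Python. This only mirrors the idea behind the suggested C++ change in conv.cc; `align_z_rank` and its signature are my own names for this sketch, not ORT's API.

```python
def align_z_rank(z_shape, y_shape):
    """If Z has exactly one fewer dimension than Y, pad Z's shape with a
    trailing 1 so the cuDNN tensor descriptors for Z and Y end up with
    matching ranks. Otherwise return Z's shape unchanged."""
    if len(z_shape) == len(y_shape) - 1:
        return list(z_shape) + [1]
    return list(z_shape)
```

With the shapes from the example above, `align_z_rank([1, 1024, 16], [1, 1024, 16, 1])` returns `[1, 1024, 16, 1]`, which matches `Y`'s rank.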