researchmm / Stark

[ICCV'21] Learning Spatio-Temporal Transformer for Visual Tracking
MIT License

stark_lightning model: ONNX conversion error #69

Open Aspirinkb opened 2 years ago

Aspirinkb commented 2 years ago

When converting the search-region branch of the stark_lightning model with python3 tracking/ORT_lightning_X_trt_complete.py, a shape-mismatch error occurs, as shown below. The template branch converts without problems:

[E:onnxruntime:, sequential_executor.cc:346 Execute] Non-zero status code returned while running Reshape node. Name:'Reshape_148' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:41 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape&, std::vector<long int>&, bool) gsl::narrow_cast<int64_t>(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{1,128,400}, requested shape:{400,128,20,20}

Traceback (most recent call last):
  File "tracking/ORT_lightning_X_trt_complete.py", line 166, in <module>
    ort_outs = ort_session.run(None, ort_inputs)
  File "/home/inspur/.local/lib/python3.6/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 188, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Reshape node. Name:'Reshape_148' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/reshape_helper.h:41 onnxruntime::ReshapeHelper::ReshapeHelper(const onnxruntime::TensorShape&, std::vector<long int>&, bool) gsl::narrow_cast<int64_t>(input_shape.Size()) == size was false. The input tensor cannot be reshaped to the requested shape. Input shape:{1,128,400}, requested shape:{400,128,20,20}

Aspirinkb commented 2 years ago

Both exported PyTorch models work fine.

notnitsuj commented 2 years ago

Hi @Aspirinkb, I'm facing the same issue. Did you find a way to resolve it?

xingxinghanzi commented 1 year ago

Change

fx_t = fx.view(fx.shape[:2], self.feat_sz_s, self.feat_sz_s).contiguous()

to

fx_t = fx.view(1, fx.shape[1], self.feat_sz_s, self.feat_sz_s).contiguous()

and the error goes away.
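
For anyone wondering why the hard-coded batch dimension helps, here is a minimal pure-Python sketch of the element-count check that the Reshape node fails. The input and requested shapes are copied from the error message above. The value feat_sz_s = 20 is an assumption inferred from the trailing {20,20} in the requested shape, and the helper names numel and can_reshape are hypothetical, made up for this illustration. It does not explain how the exporter arrived at 400 as the leading dimension; it only shows that the suggested shape is consistent with the input.

```python
from functools import reduce
from operator import mul


def numel(shape):
    """Return the number of elements in a tensor of the given shape."""
    return reduce(mul, shape, 1)


def can_reshape(input_shape, requested_shape):
    """A reshape is only valid when both shapes hold the same number of elements."""
    return numel(input_shape) == numel(requested_shape)


# Shapes copied from the ONNX Runtime error message.
input_shape = (1, 128, 400)
bad_shape = (400, 128, 20, 20)

# Assumption: feat_sz_s is 20, inferred from the trailing dims of the requested shape.
feat_sz_s = 20

# Shape produced by the suggested fix: fx.view(1, fx.shape[1], feat_sz_s, feat_sz_s)
fixed_shape = (1, input_shape[1], feat_sz_s, feat_sz_s)

print("input elements:    ", numel(input_shape))
print("requested elements:", numel(bad_shape), "valid:", can_reshape(input_shape, bad_shape))
print("fixed elements:    ", numel(fixed_shape), "valid:", can_reshape(input_shape, fixed_shape))
```

The input holds 1 * 128 * 400 = 51,200 elements, while the requested shape needs 400 * 128 * 20 * 20 = 20,480,000, so the Reshape cannot succeed. With the leading dimension fixed to 1, the target shape holds 1 * 128 * 20 * 20 = 51,200 elements again, which matches the input.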