pykeio / ort

Fast ML inference & training for Rust with ONNX Runtime
https://ort.pyke.io/
Apache License 2.0

Error: `GetTensorMutableData` should not be a null pointer #185

Closed jamjamjon closed 5 months ago

jamjamjon commented 5 months ago

Problem

Error: GetTensorMutableData should not be a null pointer

This ONNX model has 4 outputs. The error occurs when fetching the "batchno_classid_y1x1y2x2" output.

Version

version = "2.0.0-alpha.4"

Model

https://github.com/jamjamjon/assets/releases/download/v0.0.1/yolopv2_dyn-480x800-f16.onnx

Code


        let ys = self.session.run(xs_.as_ref())?;

        // output
        let mut ys_ = Vec::new();

        for (dtype, name) in self.odtypes.iter().zip(self.onames.iter()) {
            let y = &ys[name.as_str()];

            let y_ = match &dtype {
                TensorElementType::Float32 => {
                    y.extract_tensor::<f32>()?.view().to_owned()
                },
                TensorElementType::Float16 => {
                    y.extract_tensor::<f16>()?.view().mapv(f16::to_f32)
                },
                TensorElementType::Int64 => {
                    y.extract_tensor::<i64>()?.view().mapv(|x| x as f32)
                },
                _ => todo!(),
            };
            ys_.push(y_);
        }

decahedron1 commented 5 months ago

I'd imagine this just happens when there are no detections (N is 0). In that case (with ort v2.0.0-rc.1), check .shape() first and make sure N > 0 before trying to extract:

let t = ys["batchno_classid_y1x1y2x2"].upcast_ref::<Tensor<i64>>()?;
let n_detections = t.shape()?[0];
if n_detections > 0 {
    ...
}
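
To illustrate the guard independently of any particular ort version, here is a minimal sketch in plain Rust. `RawOutput` and `to_f32_or_empty` are hypothetical stand-ins, not part of ort; the sketch assumes an output can be described by its shape plus a data buffer that may be absent when the tensor is empty. It generalizes the "N > 0" check to "the product of all dimensions is nonzero".

```rust
/// Hypothetical stand-in for a model output (NOT an ort type):
/// a shape plus a data buffer that may be missing for empty tensors.
struct RawOutput {
    shape: Vec<i64>,
    data: Option<Vec<i64>>,
}

/// Converts an i64 output to f32, returning an empty vector when the
/// tensor has zero elements instead of touching the (possibly null) data.
fn to_f32_or_empty(out: &RawOutput) -> Result<Vec<f32>, String> {
    // Element count is the product of all dimensions; any zero dimension
    // (e.g. N = 0 detections) makes the whole tensor empty.
    let len: i64 = out.shape.iter().product();
    if len == 0 {
        return Ok(Vec::new());
    }

    // Only a non-empty tensor is expected to carry a data buffer.
    let data = out
        .data
        .as_ref()
        .ok_or_else(|| String::from("non-empty tensor has no data buffer"))?;

    Ok(data.iter().map(|&x| x as f32).collect())
}
```

With this shape of guard, the no-detections case becomes an ordinary empty result that downstream code can handle, rather than an error raised during extraction.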