pykeio / ort

Fast ML inference & training for Rust with ONNX Runtime
https://ort.pyke.io/
Apache License 2.0

Different output values than in python #13

LoipesMas closed 1 year ago

LoipesMas commented 1 year ago

Hi! Thanks for all your work! I gave this crate a try, and while it's easy to use, sadly I'm getting different (and thus wrong) outputs from my models.

Here is a gist for rust with ort: https://gist.github.com/LoipesMas/2d342b8087dbae4af31d8af2752e84de

Here is a gist for python with onnxruntime: https://gist.github.com/LoipesMas/d7258a3d009e9b06c3684d77e341251b

(Those are using the squeezenet models from here, but I originally ran into this issue with a different, yolov7 based model)

Example outputs

rust+ort:

```
[src/main.rs:29] &input.shape() = [
    1,
    3,
    224,
    224,
]
[src/main.rs:30] input.slice(s![0, .., 100, 100]) = [0.92156863, 0.9529412, 0.99607843], shape=[3], strides=[1], layout=CFcf (0xf), const ndim=1
[src/main.rs:31] input.slice(s![0, .., 180, 50]) = [0.9882353, 0.96862745, 0.95686275], shape=[3], strides=[1], layout=CFcf (0xf), const ndim=1
[src/main.rs:48] max_score = Some(
    (
        794,
        0.06649929,
    ),
)
[src/main.rs:49] scores.slice(s![0, 322, .., ..]) = [[1.08950435e-5]], shape=[1, 1], strides=[0, 0], layout=CFcf (0xf), const ndim=2
```

python:

```
frame.shape=(1, 3, 224, 224)
frame[0, :, 100, 100]=array([0.92156863, 0.9529412 , 0.99607843], dtype=float32)
frame[0, :, 180, 50]=array([0.9882353 , 0.96862745, 0.95686275], dtype=float32)
max_score=0.11967179
np.where(scores >= max_score)=(array([0]), array([669]), array([0]), array([0]))
scores[0][322]=array([[4.9559356e-05]], dtype=float32)
```

---

Input shapes and values are the same (I'm almost positive), but outputs are not even close (e.g. different indexes of max-value, different values (sometimes an order of magnitude different))

I'm not sure if that's an issue on my side (how I load data, how I use this crate or something else) or if it's on the crate's side. Since ort is based on onnxruntime-rs, this issue might be relevant (although probably not very helpful). Maybe those small errors add up somehow? No idea.

Thanks in advance!

decahedron1 commented 1 year ago

In Python you are using the softmax outputs, but in Rust you are using the raw logits. You can use the softmax() function implemented in ort::tensor::NdArrayTensor to get the expected result.

LoipesMas commented 1 year ago

Thanks for the fast reply. Unfortunately I don't think that's it. It seems that there already is a softmax in squeezenet, so I can't see why it wouldn't be included in the output. And the indexes of the max-values are different (794 in rust, 669 in python), which is something softmax would not change, AFAIK. I tried doing softmax on the output, but, as I expected, the index of the max-value didn't change.
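To illustrate the point about softmax not moving the max-value index: softmax is strictly increasing in each logit, so the argmax before and after is the same. The following is a minimal std-only sketch; `softmax` and `argmax` here are hypothetical helpers written for this illustration, not functions from ort.

```rust
/// Numerically stable softmax over a slice of logits.
fn softmax(logits: &[f32]) -> Vec<f32> {
    // Subtract the maximum before exponentiating to avoid overflow.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Index of the largest value (the first one wins on ties); None if empty.
fn argmax(values: &[f32]) -> Option<usize> {
    let mut best: Option<(usize, f32)> = None;
    for (i, &v) in values.iter().enumerate() {
        match best {
            // Keep the current best when the new value is not larger.
            Some((_, b)) if v <= b => {}
            _ => best = Some((i, v)),
        }
    }
    best.map(|(i, _)| i)
}
```

Calling `argmax(&logits)` and `argmax(&softmax(&logits))` on the same logits should return the same index, so a differing max index between two runtimes cannot be explained by a missing softmax.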

decahedron1 commented 1 year ago

Oops, you're right, there already is a softmax, sorry.

What's your onnxruntime version in Python? Is the output different with a non-quantized model?

Also, it seems your inputs may not be correct:

The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
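The preprocessing described above can be sketched as follows, assuming the pixels are already in planar CHW order as 8-bit values. `normalize_chw` is a hypothetical helper for illustration; only the mean and std constants come from the quoted text.

```rust
/// Per-channel statistics quoted in the comment above.
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Scales a planar CHW buffer of 8-bit pixels to [0, 1], then normalizes
/// each channel with its mean and standard deviation.
fn normalize_chw(pixels: &[u8], height: usize, width: usize) -> Vec<f32> {
    assert_eq!(pixels.len(), 3 * height * width);
    let plane = height * width;
    pixels
        .iter()
        .enumerate()
        .map(|(i, &p)| {
            // In planar layout, the channel is the index of the plane.
            let c = i / plane;
            (p as f32 / 255.0 - MEAN[c]) / STD[c]
        })
        .collect()
}
```

For example, `normalize_chw(&[255, 0, 128], 1, 1)` normalizes a single RGB pixel, one value per channel.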

LoipesMas commented 1 year ago

> What's your onnxruntime version in Python?

onnxruntime==1.13.1

> Is the output different with a non-quantized model?

For non-quantized model the outputs are wrong as well (python version still has max-value at index 669, rust now has max-value at 111). And the yolov7-based model isn't quantized at all.

> Also, it seems your inputs may not be correct:

Yes, because my goal wasn't to use squeezenet, I just chose it for demonstration. So the outputs might be "wrong", but they should be the same "wrong" in python and rust, right?

decahedron1 commented 1 year ago

Yes, the output should be the same with the same input data... This is really strange because ort doesn't mess with the input/output data. There shouldn't be any precision loss anywhere. I've deployed Stable Diffusion, GPT, and ALBERT with ort and haven't had any accuracy issues.

Whatever precision loss is occurring here may be exacerbated by the faulty preprocessing, because it shouldn't be this far off...

I think the Python API is hiding something. Can you try building the session with .with_optimization_level(GraphOptimizationLevel::Level1)? If that doesn't help, try with ::Disable or ::Level2.

LoipesMas commented 1 year ago

I tried all optimization levels, none of them changed anything.

I can try feeding squeezenet properly processed data, but the other model works fine with this data in python.

LoipesMas commented 1 year ago

For some reason I couldn't get squeezenet to give me meaningful results (even with correctly preprocessed data), so I went back to testing with the other model (the yolov7 based one). Here I can get correct results in python, while rust outputs are very wrong. Tried all optimization levels, again, no luck.

decahedron1 commented 1 year ago

Could you share your YOLOv7 Python & Rust code?

LoipesMas commented 1 year ago

Rust: https://gist.github.com/LoipesMas/c0271aeef1d45951c148c40312d91198 Python: https://gist.github.com/LoipesMas/c20233128d04e9a6daa8934cfb354f65

outputs

rust:

```
[src/main.rs:28] &input.shape() = [
    1,
    3,
    705,
    705,
]
[src/main.rs:29] input.slice(s![0, .., 100, 100]) = [0.0033525568, 0.0036139947, 0.003921569], shape=[3], strides=[1], layout=CFcf (0xf), const ndim=1
[src/main.rs:30] input.slice(s![0, .., 180, 50]) = [0.0035217225, 0.0037062669, 0.003921569], shape=[3], strides=[1], layout=CFcf (0xf), const ndim=1
[src/main.rs:43] scores.shape() = [
    1,
    77175,
]
[src/main.rs:49] max_score = Some(
    (
        75041,
        0.00014480948,
    ),
)
[src/main.rs:50] scores.slice(s![.., 322]) = [2.9802322e-8], shape=[1], strides=[0], layout=CFcf (0xf), const ndim=1
```

python:

```
frame.shape=(1, 3, 705, 705)
frame[0, :, 100, 100]=array([0.85490197, 0.92156863, 1. ], dtype=float32)
frame[0, :, 180, 50]=array([0.8980392 , 0.94509804, 1. ], dtype=float32)
(1, 77175)
max_score=0.9575529
np.where(scores >= max_score)=(array([0]), array([76457]))
scores[scores >= 0.9]=array([0.93757355, 0.9528866 , 0.9477465 , 0.95434356, 0.9529081 , 0.9575529 ], dtype=float32)
scores[0][322]=5.9604645e-08
```

---

On an image where it was supposed to detect something (used for the outputs above), python code correctly returns high-confidence values, while rust code does not. On an image where there is nothing to detect, python code correctly returns only low-confidence values, while rust returns roughly the same values as with the other image.

decahedron1 commented 1 year ago

I think there is an issue with preprocessing.

> [src/main.rs:29] input.slice(s![0, .., 100, 100]) = [0.0033525568, 0.0036139947, 0.003921569], shape=[3], strides=[1], layout=CFcf (0xf), const ndim=1

Is the Rust code dividing by 255 twice? 0.0033525568 * 255 = 0.854901984, same as the Python output.
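The arithmetic behind that suspicion can be checked with a short sketch. The raw pixel value 218 is an assumption derived here from 0.85490197 * 255; it does not appear in the thread, and `scale` is a hypothetical helper.

```rust
/// Scales an 8-bit intensity to the [0.0, 1.0] range.
fn scale(v: f32) -> f32 {
    v / 255.0
}

/// Applies the scaling twice, as the suspected bug would.
fn scale_twice(v: f32) -> f32 {
    scale(scale(v))
}
```

With these definitions, `scale(218.0)` is about 0.85490197 (the Python value) and `scale_twice(218.0)` is about 0.0033525568 (the Rust value in the pasted output).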

I notice you transpose (to GRB? correct me if wrong) in Python but not in Rust:

frame = np.transpose(frame, (2, 0, 1))

Try using np.save in Python and ndarray-npy in Rust to save & load the model input to rule out preprocessing differences.

LoipesMas commented 1 year ago

I copied wrong output, my bad (I was trying random things, such as dividing by 255, hoping to get any clues). Here's correct output:

[src/main.rs:27] &input.shape() = [
    1,
    3,
    705,
    705,
]
[src/main.rs:28] input.slice(s![0, .., 100, 100]) = [0.85490197, 0.92156863, 1.0], shape=[3], strides=[1], layout=CFcf (0xf), const ndim=1
[src/main.rs:29] input.slice(s![0, .., 180, 50]) = [0.8980392, 0.94509804, 1.0], shape=[3], strides=[1], layout=CFcf (0xf), const ndim=1
[src/main.rs:42] scores.shape() = [
    1,
    77175,
]
[src/main.rs:48] max_score = Some(
    (
        76654,
        0.0004245937,
    ),
)
[src/main.rs:49] scores.slice(s![.., 322]) = [1.7881393e-7], shape=[1], strides=[0], layout=CFcf (0xf), const ndim=1

That transpose changes the shape from Height x Width x Channels to Channels x Height x Width. In rust the input is already in that format.
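For reference, here is a minimal sketch of what that Height x Width x Channels to Channels x Height x Width reordering does on a flat buffer. `hwc_to_chw` is a hypothetical helper for illustration and is not taken from either gist.

```rust
/// Reorders an interleaved HWC buffer into planar CHW order.
fn hwc_to_chw(src: &[f32], height: usize, width: usize, channels: usize) -> Vec<f32> {
    assert_eq!(src.len(), height * width * channels);
    let mut dst = vec![0.0; src.len()];
    for y in 0..height {
        for x in 0..width {
            for c in 0..channels {
                // HWC index: channels of one pixel are adjacent.
                let from = (y * width + x) * channels + c;
                // CHW index: each channel is a full height x width plane.
                let to = c * height * width + y * width + x;
                dst[to] = src[from];
            }
        }
    }
    dst
}
```

For a 1 x 2 image with pixels (1, 2, 3) and (4, 5, 6), the interleaved input `[1, 2, 3, 4, 5, 6]` becomes `[1, 4, 2, 5, 3, 6]`.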

I'll try saving and comparing.

LoipesMas commented 1 year ago

write_npy("rust.npy", &input).unwrap();
np.save("python.npy", frame)

Both saved just before running the inference.

In [12]: r = np.load("rust.npy")

In [13]: p = np.load("python.npy")

In [14]: np.array_equal(r,p)
Out[14]: True

Sadly they are the same

decahedron1 commented 1 year ago

I pushed a change a while ago that I think may fix it. The output isn't just precision errors, it's completely wrong data. I think the problem is that nshare returns arrays with a non-contiguous memory layout. Such an array looks identical when saved with ndarray-npy (because ndarray applies the strides correctly), but it fails when passed to ONNX Runtime, which expects a contiguous memory layout and thus doesn't accept a stride parameter.
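The effect described here can be reproduced without any crates. The sketch below uses a hypothetical `View2` type (not ndarray's or ort's actual types) to show how the same buffer yields different data depending on whether the reader honors the strides. It does not show the actual change pushed to ort, which is not included in this thread.

```rust
/// Hypothetical 2-D view over a flat buffer, described by a shape and
/// per-axis strides measured in elements.
struct View2<'a> {
    data: &'a [f32],
    shape: (usize, usize),
    strides: (usize, usize),
}

impl<'a> View2<'a> {
    /// Elements in logical row-major order, honoring the strides.
    /// Copying this way produces a contiguous buffer.
    fn logical(&self) -> Vec<f32> {
        let mut out = Vec::with_capacity(self.shape.0 * self.shape.1);
        for r in 0..self.shape.0 {
            for c in 0..self.shape.1 {
                out.push(self.data[r * self.strides.0 + c * self.strides.1]);
            }
        }
        out
    }

    /// Simplified check: true when raw memory order equals logical order.
    fn is_standard_layout(&self) -> bool {
        self.strides == (self.shape.1, 1)
    }

    /// What a consumer sees if it ignores the strides and reads raw memory.
    fn raw(&self) -> Vec<f32> {
        self.data[..self.shape.0 * self.shape.1].to_vec()
    }
}
```

With the buffer `[1, 2, 3, 4, 5, 6]` viewed as shape (3, 2) with strides (1, 3), `logical()` gives `[1, 4, 2, 5, 3, 6]` while `raw()` gives `[1, 2, 3, 4, 5, 6]`. A stride-aware writer saves the former, a pointer-only consumer reads the latter. In ndarray, `as_standard_layout()` performs the equivalent copy into contiguous memory when needed.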

Could you try using ort on the main branch?

ort = { git = "https://github.com/pykeio/ort", branch = "main" }

> That transpose changes the shape from Height x Width x Channels to Channels x Height x Width. In rust the input is already in that format.

my brain is not working today 🙃

LoipesMas commented 1 year ago

It worked! Thank you so much. I was already using the main branch, but it didn't check for updates.

I expected this to be image-loading related (since models you mentioned only had text input), but I wouldn't expect it to be this (but I'm new to ndarray and stuff, so I still have a lot to learn). Thanks again, glad you managed to solve it!

decahedron1 commented 1 year ago

Awesome!

Funnily enough, I think this also solves another problem I had with Stable Diffusion's safety checker a month ago... I thought I just did the preprocessing wrong, turns out it was an ort issue! 😅

LoipesMas commented 1 year ago

Hmm, there still seems to be an issue with precision. For one file it's exactly the same, but for others there are small differences (e.g. 0.008985966 in rust, 0.008349687 in python). Max-value is the same, but the error is not negligible. This one looks more like the onnxruntime-rs issue I linked. (And it's not just for really small numbers, 0.9680922 vs 0.96587086)
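To put the quoted numbers in perspective, a relative difference can be computed as below. `relative_diff` is a hypothetical helper, and scaling by the larger magnitude is one of several reasonable conventions.

```rust
/// Relative difference between two values, scaled by the larger magnitude.
/// Returns 0.0 when both values are zero.
fn relative_diff(a: f32, b: f32) -> f32 {
    let scale = a.abs().max(b.abs());
    if scale == 0.0 {
        0.0
    } else {
        (a - b).abs() / scale
    }
}
```

For the two pairs above this gives roughly 7% (0.008985966 vs 0.008349687) and roughly 0.2% (0.9680922 vs 0.96587086). Both are several orders of magnitude above `f32::EPSILON` (about 1.2e-7), which is consistent with the observation that the error is not negligible.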

decahedron1 commented 1 year ago

Are graph optimizations still disabled in Rust? Python may be defaulting to Level3 (aka "all" in official docs) which may cause this problem.

LoipesMas commented 1 year ago

I tried all optimization levels, but it didn't help. Disable might actually be a bit closer to python, because some small values result in 0.0 in python, 2.983123e-8 in rust with Level3, but 0.0 in rust with Disable. But the bigger values don't change at all.

soloist-v commented 1 year ago

I have the same problem. My model is yolov5n.onnx, and the inference result is correct in C++ onnxruntime. The rust ort is wrong. In C++ output is:

5.53997, 7.43146, 10.5765, 12.0833, 6.16908e-06, 0.784669, 0.198309, 10.3879, 8.485, 21.0019, 15.7535, 
8.82149e-06, 0.781702, 0.210636, 18.8944, 7.6956, 26.7782, 13.5962, 2.95043e-06, 0.849125, 0.153245

In Rust output is:

5.762603, 7.143219, 11.919465, 12.539226, 6.2584877e-6, 0.83790576, 0.15439245, 11.427677, 7.2222357, 
20.642508, 13.55667, 2.014637e-5, 0.85138524, 0.14452386, 19.017769, 5.5255632, 18.779839, 9.80006, 
6.4373016e-6, 0.88766444, 0.110286444

That looks like a big error. I checked the input values as follows. C++:

0.454902, 0.454902, 0.462745, 0.462745, 0.458824, 0.447059, 0.454902, 0.462745, 0.454902, 0.45098, 
0.45098, 0.447059, 0.45098, 0.458824, 0.454902, 0.454902, 0.454902, 0.454902, 0.45098, 0.439216, 
0.435294, 0.435294, 0.435294, 0.423529, 0.447059, 0.439216, 0.454902, 0.470588, 0.454902, 0.462745,
 0.470588, 0.466667, 0.462745, 0.45098, 0.454902, 0.45098, 0.439216, 0.447059, 0.458824, 0.454902, 
0.462745, 0.454902, 0.454902, 0.45098, 0.466667, 0.458824, 0.466667, 0.47451, 0.470588, 0.462745

Rust:

0.45490196, 0.45490196, 0.4627451, 0.4627451, 0.45882353, 0.44705883, 0.45490196, 0.4627451, 0.45490196, 
0.4509804, 0.4509804, 0.44705883, 0.4509804, 0.45882353, 0.45490196, 0.45490196, 0.45490196, 0.45490196, 
0.4509804, 0.4392157, 0.43529412, 0.43529412, 0.43529412, 0.42352942, 0.44705883, 0.4392157, 0.45490196, 
0.47058824, 0.45490196, 0.4627451, 0.47058824, 0.46666667, 0.4627451, 0.4509804, 0.45490196, 0.4509804, 
0.4392157, 0.44705883, 0.45882353, 0.45490196, 0.4627451, 0.45490196, 0.45490196, 0.4509804, 0.46666667, 
0.45882353, 0.46666667, 0.4745098, 0.47058824, 0.4627451

The input values look OK, and I cloned the latest version.

soloist-v commented 1 year ago

I fixed it. I replaced ort with onnxrt and the problem disappeared. Please note that I only replaced the session.run call; the rest of the code is the same. ort code:

pub fn detect_f32(&mut self, img: &Mat, res: &mut Vec<DetResult>) -> Result<(), YoloError> {
    let ratio = auto_resize(img, &mut self.input_tmp, &mut self.input_img,
                            self.width as i32, self.height as i32)?;
    let mut input_tensor = self.input_data.clone();
    let input_tensor_raw = unsafe { std::slice::from_raw_parts_mut(input_tensor.as_mut_ptr(), input_tensor.len()) };
    norm_hwc2chw_bgr2rgb(&self.input_img, input_tensor_raw, 255_f32);

    let outputs: Vec<DynOrtTensor<IxDyn>> = self.sess.run([InputTensor::from_array(input_tensor.into_dyn())])?;
    let pred: OrtOwnedTensor<f32, _> = outputs[0].try_extract()?;
    let pred = pred.view();
    let pred = pred.to_slice_memory_order().ok_or("pred.as_slice error")?;

    self.postprocess(pred, res, ratio);
    Ok(())
}

onnxrt code:

pub fn detect_f32(&mut self, img: &Mat, res: &mut Vec<DetResult>) -> Result<(), YoloError> {
    let ratio = auto_resize(img, &mut self.input_tmp, &mut self.input_img,
                            self.width as i32, self.height as i32)?;
    let mut input_data = vec![0_f32; 3 * 416 * 416].into_boxed_slice();
    norm_hwc2chw_bgr2rgb(&self.input_img, &mut input_data, 255_f32);

    let memory_info = MemoryInfo::new_for_cpu(OrtAllocatorType::OrtDeviceAllocator, OrtMemTypeDefault);
    let input_tensor =
        Value::new_tensor_with_data(&memory_info, &mut input_data, &[1, 3, 416, 416])?;
    let mut output_buffer = vec![0_f32; 10647 * 7].into_boxed_slice();
    let output_tensor = Value::new_tensor_with_data(&memory_info, &mut output_buffer, &[1, 10647, 7])?;
    self.sess.run(None, &self.input_names, &[input_tensor], &self.output_names, &mut [output_tensor])?;
    let pred = &output_buffer;

    self.postprocess(pred, res, ratio);
    Ok(())
}

decahedron1 commented 1 year ago

I cannot replicate this with SqueezeNet 1.1. I think this problem is exclusive to YOLO models - the linked issue in onnxruntime-rs was with YOLOv3, and this issue involves YOLOv7 and YOLOv5. My outputs for SqueezeNet, Stable Diffusion, BERT, ALBERT, and GPT are all perfectly fine, but it seems only YOLO has problems.

I have no idea why this is happening, my only thought is that it still involves pre/post-processing somehow. ort doesn't modify, copy, or even look at model outputs from ONNX Runtime, it just creates an ndarray using the tensor data provided to it. There's no copying involved, and AFAIK whatever values are in this array are exactly what was given by ONNX Runtime.

@LoipesMas @soloist-v what operating system/execution providers (if any) are you using?


2023-02-12: Maybe try with v1.14.0-beta.0, though I still cannot replicate it with this version.

soloist-v commented 1 year ago

OS: Windows 11 22H2
rustc 1.67.0-nightly (95a3a7277 2022-10-31)

https://github.com/soloist-v/ort/blob/ac6092a34f392fbee2ae73aa5a965e11fdb6e9d4/examples/yolov5.rs

decahedron1 commented 1 year ago

Is there a problem with float16 then? I noticed all the yolov5 ONNX models take float16 inputs/outputs. The only float16 model I am using is Stable Diffusion, but the model has casts from/to float32 at the input/output which could be why I'm not seeing errors.

soloist-v commented 1 year ago

I suspect that the input data may have undergone some conversion or the reading order may have caused the problem, but the C API is not the problem and it's also unlikely that the output is an issue.

soloist-v commented 1 year ago

I initially used an fp32 model, which had accuracy issues, and the problems persisted even after switching to f16. But when I called the C API directly, all the problems disappeared, so I suspect the input underwent an incorrect conversion or was read in the wrong order.

decahedron1 commented 1 year ago

In #18 the only differences I see are using different MemTypes and using pre-allocated output memory. This is almost certainly not what the Python bindings do, so I don't think that's what fixed it. I pushed https://github.com/pykeio/ort/commit/3b74b6293aef5b339668200386776555bea31061 using the "proper" MemTypes now but I don't think it'll make a difference.

ort doesn't read or write any tensor data directly, besides a single clone on line 44 of src/tensor/ort_tensor.rs. Everything else is just reading from pointers.

chertov commented 1 year ago

I had a similar problem in Rust, where the value returned by rust ort differed from the value returned by Python for the same custom ONNX model. I created a demo in pure C++, and the value was the same as in Rust. Then I replaced the image with a set of fixed data, all pixels set to 0.5. Rust, C++, and Python all started giving the same result.

I checked the pixels and noticed that in Rust, the x and y coordinates were swapped. To load the image in Rust, I'm using the image crate:

let input = ndarray::Array::from_shape_fn((1, 256, 256, 3), |(_, x, y, channel)| {
    let pixel = img.get_pixel(y as u32, x as u32); // y, x!!!!!!!!!
    let channels = pixel.channels();
    (channels[channel] as f32) / 255.0 // range [0, 255] -> range [0.0, 1.0]
});
let inputs = [InputTensor::FloatTensor(input.into_dyn())];

After swapping the arguments of img.get_pixel from (x, y) to (y, x), I obtained the correct value.
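The underlying convention mismatch can be shown with a std-only sketch. `TinyImage` is a hypothetical stand-in for an image type whose accessor takes (x, y), meaning column first and then row, while a row-major tensor is indexed (row, column).

```rust
/// Hypothetical single-channel image with an (x, y) pixel accessor.
struct TinyImage {
    width: usize,
    height: usize,
    pixels: Vec<u8>, // row-major storage: index = y * width + x
}

impl TinyImage {
    /// Returns the pixel at column `x`, row `y`.
    fn get_pixel(&self, x: usize, y: usize) -> u8 {
        assert!(x < self.width && y < self.height);
        self.pixels[y * self.width + x]
    }
}

/// Builds a row-major [height, width] tensor scaled to [0.0, 1.0].
fn to_hw_tensor(img: &TinyImage) -> Vec<f32> {
    let mut out = Vec::with_capacity(img.width * img.height);
    for row in 0..img.height {
        for col in 0..img.width {
            // Tensor axes are (row, col); the accessor wants (x, y) = (col, row).
            out.push(img.get_pixel(col, row) as f32 / 255.0);
        }
    }
    out
}
```

On a non-square image, passing (row, col) to the accessor would trip the bounds assertion. On a square image such as the 256 x 256 input above, it would silently produce a transposed picture, which is why this kind of bug is easy to miss.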

It is possible that the issue lies elsewhere, but it is worth double-checking the input data to eliminate any potential sources of error.