sonos / tract

Tiny, no-nonsense, self-contained, Tensorflow and ONNX inference

Possible wrong computations while outputting a probability vector #732

Closed: francoisWeber closed this issue 2 years ago

francoisWeber commented 2 years ago

Hi there,

TL;DR: I noticed a wrong computation while outputting a prediction probability vector for a classification tree.

Disclaimer: I'm new to Rust. I tried to reuse a tract snippet to run inference on an ONNX file that embeds a simple binary classification tree. The output of the model is:

The model comes from sklearn-onnx. Here is a snippet of code to reproduce such a model:

from sklearn.tree import DecisionTreeClassifier
import sklearn
import numpy as np

import skl2onnx
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from onnxruntime import InferenceSession
import onnx

# fix the random seed for reproducibility
np.random.seed(seed=1)

# check versions
assert skl2onnx.__version__ == "1.11.2"
assert sklearn.__version__ == "0.24.2"
assert onnx.version.version == "1.11.0"

# generate random X and random binary ground-truth
X = np.random.randn(1000, 5)  # 1000 observations with 5 features each
y = np.random.choice(2, size=1000)  # 1000 binary labels (0 or 1)
tree = DecisionTreeClassifier(max_leaf_nodes=42)
tree.fit(X, y)

# now convert to ONNX
onnx_options = {"zipmap": False, "output_class_labels": False}
initial_type = [("X", FloatTensorType([None, X.shape[1]]))]
model_proto = convert_sklearn(
    tree,
    initial_types=initial_type,
    options={id(tree): onnx_options},
)

# Now draw a prediction based on a randomly chosen X
X_test = np.array(
    [[0.23556215, 2.848049, 0.38786993, -1.50578322, -0.17152989]], dtype=np.float32
)
sess = InferenceSession(model_proto.SerializeToString())
inference_result = sess.run(None, {"X": X_test})
assert np.allclose(inference_result[1], [[0.78571427, 0.21428572]])

# save it to load it in Rust
onnx.save(model_proto, "./dummy_tree.onnx")

So, as you can see, the session returns the predicted label as its first output, and the second output (inference_result[1]) contains the prediction probabilities [0.78571427, 0.21428572].
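The label output can be cross-checked against the probabilities. A small sketch, assuming the predicted label is the index of the largest probability (scikit-learn's DecisionTreeClassifier.predict takes the argmax of predict_proba, and the classes here are 0 and 1); the helper name is hypothetical:

```python
def label_from_probabilities(rows):
    """Return the index of the largest probability in each row (argmax)."""
    labels = []
    for row in rows:
        # pick the index whose probability is highest
        labels.append(max(range(len(row)), key=lambda i: row[i]))
    return labels


# Probabilities reported by onnxruntime for X_test in the script above
probabilities = [[0.78571427, 0.21428572]]
print(label_from_probabilities(probabilities))  # [0]
```

The result, 0, agrees with the I64 label that tract returns as its first output.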

Now to load it back into a Rust program, I used the following Rust x Tract x ONNX helper:

use std::boxed::Box;

use std::io::{self, Error};
use std::ops::Deref;
use tract_ndarray::Array;

use tract_onnx::prelude::tract_ndarray::ShapeError;
use tract_onnx::prelude::*;
use tract_onnx::tract_core::anyhow::Error as TractError;

type OnnxModel = SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>;

pub enum InferenceError {
    TractError(TractError),
    Error(Error),
    ShapeError(ShapeError),
}

impl From<io::Error> for InferenceError {
    fn from(err: io::Error) -> Self {
        InferenceError::Error(err)
    }
}

impl From<TractError> for InferenceError {
    fn from(err: TractError) -> Self {
        InferenceError::TractError(err)
    }
}

impl From<ShapeError> for InferenceError {
    fn from(err: ShapeError) -> Self {
        InferenceError::ShapeError(err)
    }
}

pub struct ONNXRunnerRaw {
    onnx_runnable_model: OnnxModel,
}

impl ONNXRunnerRaw {
    pub fn init(model_path: &str, input_fact_vector: Vec<i64>) -> Result<Self, TractError> {
        // Load the ONNX model, pin the input type and shape, then optimize it into a runnable plan.
        let onnx_model = tract_onnx::onnx()
            .model_for_path(&model_path)?
            .with_input_fact(
                0,
                InferenceFact::dt_shape(f32::datum_type(), input_fact_vector),
            )?
            .into_optimized()?
            .into_runnable()?;
        // Wrap the runnable plan in the runner struct.
        Ok(ONNXRunnerRaw {
            onnx_runnable_model: onnx_model,
        })
    }

    pub fn classify(&self, input: Vec<f32>) -> Result<f32, TractError> {
        // Build a 1 x n input tensor from the feature vector.
        let tensor = Array::from_shape_vec((1, input.len()), input)?.into();
        let result = self.onnx_runnable_model.run(tvec!(tensor))?;
        dbg!(&result);
        // Output 0: predicted label (i64); output 1: class probabilities (f32).
        let res = &*result[0];
        let proba: Vec<f32> = result[1].deref().to_array_view::<f32>()?.iter().copied().collect();
        dbg!(&proba);
        println!(
            "The above result should sum to 1 because it's supposed to be a probability vector"
        );
        res.cast_to_scalar::<f32>()
    }
}

And I call ONNXRunnerRaw from the following main.rs file:

use onnx_runner::ONNXRunnerRaw;

fn main() {
    println!("hello");
    let input_fact_vector: Vec<i64> = vec![1, 5];
    let model_path = "./assets/dummy_tree.onnx";
    let runner = match ONNXRunnerRaw::init(model_path, input_fact_vector) {
        Ok(runner) => runner,
        Err(e) => panic!("{:?}", e),
    };
    let input_vector: Vec<f32> = vec![0.23556215, 2.848049, 0.38786993, -1.50578322, -0.17152989];
    let result = runner.classify(input_vector).unwrap();
    println!("{}", result);
}

And it outputs the following:

hello
[onnx_runner/src/lib.rs:62] &result = [
    1,I64 0,
    1,2,F32 0.25, 0,
]
[onnx_runner/src/lib.rs:66] &proba = [
    0.21,
    0.0,
]
The above result should sum to 1 because it's supposed to be a probability vector
0

So the Rust version does retrieve the Python probability of 0.214 of being in class 1, but in the first slot, and the complementary probability (0.786 in Python) comes out as 0.0!
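The println! in classify only asks the reader to check that the row sums to 1. A minimal Python sketch of that check follows; the helper name and the 1e-6 tolerance are arbitrary choices for illustration, not something tract or onnxruntime provides:

```python
import math


def is_probability_vector(row, tol=1e-6):
    """True if every entry lies in [0, 1] and the entries sum to 1 within tol."""
    in_range = all(0.0 <= p <= 1.0 for p in row)
    return in_range and math.isclose(sum(row), 1.0, abs_tol=tol)


# Expected output (onnxruntime) vs. what tract returned before the fix
print(is_probability_vector([0.78571427, 0.21428572]))  # True
print(is_probability_vector([0.21428572, 0.0]))         # False
```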

I think it's a bug in the way tract handles the output of the ONNX model. I hope this feedback is helpful.

François Weber

kali commented 2 years ago

@francoisWeber thanks for the report. This does look like a bug.

I tried to run the Python pseudo-training script, but could not manage to get the right versions of some dependencies.

Can you please send me the .onnx model, plus the input and expected output of the model as an .npz file, as shown in https://github.com/sonos/tract/blob/main/doc/cli-recipe.md#running-a-test-case ? The input and output names in the network must match the tensor names in io.npz, and you can "test" the test case with the cargo command line.
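For reference, an .npz bundle is a zip archive with one .npy file per tensor, each file named after its tensor; with NumPy, np.savez writes one directly. The sketch below spells the format out with the Python standard library only (version 1.0 .npy headers, little-endian data). The tensor names X, label and probabilities are an assumption about the names skl2onnx gives this model's input and outputs, so check them against the actual graph before using the bundle:

```python
import struct
import zipfile

# dtype code -> (NumPy descr string, struct format character)
DTYPES = {"f32": ("<f4", "f"), "i64": ("<i8", "q")}


def npy_bytes(values, shape, dtype):
    """Serialize a flat list of numbers as a version 1.0 .npy payload."""
    descr, fmt = DTYPES[dtype]
    header = "{'descr': '%s', 'fortran_order': False, 'shape': %r, }" % (descr, shape)
    # Pad with spaces so that magic + version (8 bytes) + length field (2)
    # + header + trailing newline is a multiple of 64 bytes, as NumPy does.
    pad = 64 - (10 + len(header) + 1) % 64
    header = header + " " * pad + "\n"
    data = struct.pack("<%d%s" % (len(values), fmt), *values)
    return b"\x93NUMPY\x01\x00" + struct.pack("<H", len(header)) + header.encode("latin1") + data


def write_npz(path, tensors):
    """tensors maps a tensor name to (flat values, shape tuple, dtype code)."""
    with zipfile.ZipFile(path, "w") as zf:
        for name, (values, shape, dtype) in tensors.items():
            zf.writestr(name + ".npy", npy_bytes(values, shape, dtype))


# Assumed names: they must match the model's input/output names.
write_npz("io.npz", {
    "X": ([0.23556215, 2.848049, 0.38786993, -1.50578322, -0.17152989], (1, 5), "f32"),
    "label": ([0], (1,), "i64"),
    "probabilities": ([0.78571427, 0.21428572], (1, 2), "f32"),
})
```

The resulting io.npz is the kind of file that tract's --input-bundle and --assert-output-bundle options take.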

francoisWeber commented 2 years ago

Hi @kali, here is the dummy_tree.onnx.zip file (sorry, I had to zip it to make GitHub happy...). I also generated an io.npz.zip as described in your tutorial. With these two assets, the output of the command tract -v --input-bundle onnx_runner/assets/io.npz onnx_runner/assets/dummy_tree.onnx -O run --assert-output-bundle onnx_runner/assets/io.npz is the following:

[2022-06-06T12:01:17.271420671Z INFO  tract] Resource usage init: vsz:26910720 rsz:6307840 rszmax:6307840
[2022-06-06T12:01:17.272891005Z INFO  tract] Resource usage loaded framework (onnx): vsz:26910720 rsz:6307840 rszmax:6307840
[2022-06-06T12:01:17.280994755Z INFO  tract] Resource usage proto model loaded: vsz:26910720 rsz:6307840 rszmax:6307840
[2022-06-06T12:01:17.281151880Z WARN  tract_onnx::model] ONNX operator for your model is 15, tract is tested against operator set 9, 10, 11 and 12 only. Your model may still work so this is not a hard fail.
[2022-06-06T12:01:17.283553796Z INFO  tract::params] Model Fs("onnx_runner/assets/dummy_tree.onnx") loaded
[2022-06-06T12:01:17.283756838Z INFO  tract] Resource usage model loaded: vsz:26910720 rsz:6307840 rszmax:6307840
[2022-06-06T12:01:17.295279796Z INFO  tract::params] Will stop at optimize
[2022-06-06T12:01:17.295314755Z INFO  tract::params] Running 'analyse'
[2022-06-06T12:01:17.296104046Z INFO  tract] Resource usage after analyse: vsz:26910720 rsz:10252288 rszmax:10252288
[2022-06-06T12:01:17.296119963Z INFO  tract::params] Running 'incorporate'
[2022-06-06T12:01:17.296178171Z INFO  tract] Resource usage after incorporate: vsz:26910720 rsz:10252288 rszmax:10252288
[2022-06-06T12:01:17.296192713Z INFO  tract::params] Running 'type'
[2022-06-06T12:01:17.296785005Z INFO  tract] Resource usage after type: vsz:26910720 rsz:10252288 rszmax:10252288
[2022-06-06T12:01:17.296802255Z INFO  tract::params] Running 'declutter'
[2022-06-06T12:01:17.297420338Z INFO  tract] Resource usage after declutter: vsz:26910720 rsz:10252288 rszmax:10252288
[2022-06-06T12:01:17.297449463Z INFO  tract::params] Running 'before-optimize'
[2022-06-06T12:01:17.297471546Z INFO  tract] Resource usage after before-optimize: vsz:26910720 rsz:10252288 rszmax:10252288
[2022-06-06T12:01:17.297474713Z INFO  tract::params] Running 'optimize'
[2022-06-06T12:01:17.297513255Z INFO  tract] Resource usage after optimize: vsz:26910720 rsz:10252288 rszmax:10252288
[2022-06-06T12:01:17.297518505Z INFO  tract::params] Model ready
[2022-06-06T12:01:17.297547880Z INFO  tract] Resource usage model ready: vsz:26910720 rsz:10252288 rszmax:10252288
[2022-06-06T12:01:17.297652588Z INFO  tract::tensor] Using fixed input for input called X (1 turn(s))
[2022-06-06T12:01:17.298345046Z INFO  tract::utils] Checked output #0, ok.
[2022-06-06T12:01:17.298428796Z ERROR tract] Checking output 1 (expected 1,2,F32 0.78571427, 0.21428572, got 1,2,F32 0.21428572, 0

    Caused by:
        Mismatch at [0, 0] 0.78571427 != 0.21428572

Notice that $0.21428572 \approx 1 - 0.78571427$.
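The failing check compares the expected and actual tensors element by element and reports the first disagreement. A rough Python illustration of that kind of comparison (hypothetical helper and tolerances; this is not tract's implementation) reproduces the reported mismatch position:

```python
import math


def first_mismatch(expected, got, rel_tol=1e-5, abs_tol=1e-6):
    """Compare two 2-D lists of the same shape element by element.

    Returns ((row, col), expected_value, got_value) for the first pair that
    differs beyond the tolerances, or None when everything matches.
    """
    for i, (exp_row, got_row) in enumerate(zip(expected, got)):
        for j, (e, g) in enumerate(zip(exp_row, got_row)):
            if not math.isclose(e, g, rel_tol=rel_tol, abs_tol=abs_tol):
                return (i, j), e, g
    return None


expected = [[0.78571427, 0.21428572]]  # onnxruntime
got = [[0.21428572, 0.0]]              # tract, before the fix
print(first_mismatch(expected, got))   # ((0, 0), 0.78571427, 0.21428572)
```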

Hope this will help you understand the bug :)

EDIT: I just tried to get rid of the warning about the ONNX opset by setting my target opset to 12, and the problem remains.

kali commented 2 years ago

Thanks for taking the time! Having a look at the test case right now.

kali commented 2 years ago

Thanks. I took a dive into the ONNX Runtime code to see what they are doing, and... wow, that's not nice. They do a lot of post-processing specifically tailored to the "binary" case (when there are two categories), while cases with more categories are more or less left alone. The ONNX documentation does not describe this, so I don't know what tract is supposed to do. I could try to mimic ONNX Runtime's behaviour, but I would prefer to understand where this is coming from first, because the ONNX Runtime code is pretty complicated. Do you know where it comes from ? SciKit maybe ?

francoisWeber commented 2 years ago

Do you know where it comes from ? SciKit maybe ?

If you are talking about my .onnx, then yes: it comes from a scikit-learn sklearn.tree.DecisionTreeClassifier converted to ONNX through sklearn-onnx, which is part of the official ONNX framework: https://onnx.ai/sklearn-onnx/ .

Does the binary-tailored post-processing you're talking about refer to the computation of the complementary probability in the binary case ? 'Cause it's only tractable in the binary case... just a clue ?
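The complement idea can be illustrated with a deliberately simplified sketch. Assume, hypothetically (this is not taken from ONNX Runtime's source), that in the two-class case the trees produce a score only for class 1. A runtime that does not complete the row ends up with a single non-zero entry, much like tract's [0.21428572, 0] in the log, whereas completing it as [1 - p1, p1] gives the onnxruntime result. With three or more classes every class carries its own score, so no such completion applies. The function name is hypothetical:

```python
def complete_binary_scores(scores, n_classes):
    """Hypothetical post-processing of one classifier score row.

    Assumes that in the binary case only the class-1 score is produced
    (scores has a single entry) and derives class 0 as its complement.
    """
    if n_classes == 2 and len(scores) == 1:
        p1 = scores[0]
        return [1.0 - p1, p1]
    # Multi-class (or already complete) rows are returned unchanged.
    return list(scores)


row = complete_binary_scores([0.21428572], n_classes=2)
print(row)  # approximately [0.786, 0.214]
```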

kali commented 2 years ago

Hey @francoisWeber, I did... something on #734, branch name is fix-732. Do you want to check that it gives the results you are expecting?

francoisWeber commented 2 years ago

Using your fix-732 branch, it outputs:

[onnx_runner/src/lib.rs:62] &result = [
    1,I64 0,
    1,2,F32 0.78571427, 0.21428572,
]
[onnx_runner/src/lib.rs:66] &proba = [
    0.78571427,
    0.21428572,
]

Well done, @kali 😎

kali commented 2 years ago

@francoisWeber thanks for checking it out. I will have to be cautious about merging, because it's a breaking change... I need to look around and think about what other breaking changes should go in with it.

kali commented 2 years ago

@francoisWeber FYI, the fix was released as part of 0.17.0.