sonos / tract

Tiny, no-nonsense, self-contained, Tensorflow and ONNX inference

Creating executable bundle for ONNX model #393

Closed. antimora closed this issue 3 years ago.

antimora commented 4 years ago

Is it possible to create a standalone bundle from an ONNX model? Currently the examples listed require the ONNX file to be loaded at runtime. It would be awesome if one could compile the model statically, for example if the ONNX file could be loaded by a Rust macro and compiled into optimized static code.

kali commented 4 years ago

I've never tried, but theoretically at least, you could use model_for_read and std's include_bytes!. Have you tried it?
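(For what it's worth, a minimal sketch of that approach; the file name is a placeholder and, depending on the model, an explicit input fact may be needed before optimizing:)

use tract_onnx::prelude::*;

fn load_embedded() -> TractResult<()> {
    // Embed the ONNX file into the binary at compile time; no file I/O at runtime.
    let bytes = include_bytes!("model.onnx");
    let model = tract_onnx::onnx()
        .model_for_read(&mut &bytes[..])?
        // Depending on the model, you may need .with_input_fact(...) here.
        .into_optimized()?
        .into_runnable()?;
    let _ = model; // run inference with model.run(...) as usual
    Ok(())
}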

sonovice commented 4 years ago

I tried it a few weeks ago using include_bytes! and it worked like a charm.

antimora commented 4 years ago

Thanks, @kali, @sonovice. This works, but I was hoping there wouldn't be a need to parse and optimize the ONNX model every time the inference engine is loaded. There are additional advantages to compiling statically: in theory, only the required subset of operators could be included in the bundle instead of preloading all runtime operators, and the Rust compiler could optimize the executable further.

Could you please point me in the right direction if I were to try to build the graph myself? Does each node, such as Conv, contain its computation logic? Or is there some executor that takes the graph and computes the nodes? In that case, I suppose I could not compute a Conv node independently from the rest of the framework.

kali commented 4 years ago

tract does not compile models to native code. tract is basically an optimizing interpreter. Operators are the interpreter primitives.

If you prefer to define your model as code, the best way is to include tract-core, create a TypedModel and call wire_node for each node (and it may actually not be very hard to write a generator from tract-core for this code). I do not want to make any commitment on API stability here, and core operators are not documented, but they are not evolving very quickly.

See the — very limited — example here: https://docs.rs/tract-core/0.11.2/tract_core/index.html . The CLI can dump a model in tract-core form to use as a starting point. It may not be trivial to condition constant tensors the right way (for instance, convolution filters have to be prepacked), but maybe I can help.
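(For illustration, a rough sketch in the spirit of that docs example. This assumes the 0.11-era tract-core API; fact and operator constructors have shifted between releases, so treat it as a starting point rather than a drop-in snippet.)

use tract_core::internal::*;

fn build_by_hand() -> TractResult<()> {
    // A tiny TypedModel that adds 3 to each component of a 3-element f32 input.
    let mut model = TypedModel::default();
    // Fact construction differs between tract versions; this follows the linked docs example.
    let input_fact = TypedFact::dt_shape(f32::datum_type(), [3].as_ref())?;
    let input = model.add_source("input", input_fact)?;
    let three = model.add_const("three".to_string(), tensor1(&[3f32]))?;
    let _sum = model.wire_node(
        "add".to_string(),
        // Operator constructor names also vary by version.
        tract_core::ops::math::add::bin_typed(),
        &[input, three],
    )?;
    model.auto_outputs()?;

    // From here it behaves like any other tract model.
    let plan = SimplePlan::new(&model)?;
    let outputs = plan.run(tvec!(tensor1(&[1.0f32, 2.5, 5.0])))?;
    println!("{:?}", outputs[0]);
    Ok(())
}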

Another option is to use the OPL serialization: tract-OPL is an extended version of NNEF, and it's closer to tract-core than ONNX is. All operators in tract-core can dump themselves in OPL, plus quite a few from ONNX. You should be able to shed tract-hir and tract-onnx from your binary, along with the protobuf parser and a few more. Model loading is also much faster, as we start from a form very close to tract-core. The CLI can dump a network to OPL.

I need to write a few recipes for the CLI...

antimora commented 4 years ago

@kali Thank you. I'll look into your suggestions. I need more time to get accustomed to the APIs and code.

hexgnu commented 3 years ago

Hey @kali, I want something similar to @antimora: basically, I have an ONNX file that I want to load statically if at all possible.

What I have tried following your advice above:

My questions: Is there a magic combination of parameters to get the tract CLI to dump to tract-core format? My model is LightGBM -> hummingbird -> ONNX; should I maybe try a different model while avoiding ONNX?

kali commented 3 years ago

@hexgnu

  0. Is your example working with ONNX at least?

  1. tract-core is not a "dumpable" format, just an in-memory representation, but tract-core is pretty close to NNEF semantically. Almost any network in tract-core can be translated to NNEF with OPL extensions. So there is no command line trick to convert to tract-core, but tract model.onnx -i .... dump --nnef-tar model.nnef.tar will load your model, convert it to tract-core, then dump it to NNEF.

  2. OneHot is supposed to be supported in tract-core now, with NNEF/OPL serialisation. OPL extensions are opt-in: you will need to add them explicitly.

  3. We have support in tract-core for OneHot now, so I think LightGBM models should work out of the box (activating the extensions when needed).

  4. We also have support for trees encoded with the ONNX TreeEnsembleClassifier (I think xgboost generates them like that) as an onnx-opl extension. For this one, you will need to include --nnef-tract-onnx on the command line, and call with_tract_onnx() on the nnef framework. You will also need to include tract-onnx-opl in your build.
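(For the loading side, a rough sketch of registering the ONNX/OPL extensions described in point 4, using the method name given above; depending on the tract version, a helper trait from tract-onnx-opl may need to be in scope for with_tract_onnx() to resolve:)

use tract_nnef::prelude::*;

fn load_nnef_with_onnx_ops(path: &str) -> TractResult<()> {
    // Register the tract-onnx-opl operators (e.g. TreeEnsembleClassifier) with the NNEF framework.
    let nnef = tract_nnef::nnef().with_tract_onnx();
    let model = nnef
        .model_for_path(path)?
        .into_optimized()?
        .into_runnable()?;
    let _ = model; // run inference with model.run(...) as usual
    Ok(())
}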

Hope this helps. Tell me how it goes :)

hexgnu commented 3 years ago

Hey @kali thanks for the advice.

  1. It does work with ONNX just fine :). I think my big issue is that, since I am compiling this to WASM, it loads the ONNX file every single time the page refreshes in the browser. My previous model, a straight feed-forward neural net, took roughly 50ms to load and predict, while the ONNX load and inference takes roughly 1500ms. Unfortunately, speed is essential for this problem.

  2. I am able to dump to NNEF using your suggestions above, which is great! However, I am finding the performance degrades even more, and I feel like I have something miswired since it's returning bogus numbers now. So I will continue to debug it.

My feeling, looking at the graph.nnef inside the tarball, is that it wouldn't be too hard to write a code generator that writes all of this out as static Rust code. I'll see what I can accomplish and report back. Pointers are definitely welcome, and thank you.

-Matt

kali commented 3 years ago

Just to make sure: you still need to call into_optimized() on the model loaded from NNEF. It should be an order of magnitude faster than into_optimized() over the ONNX model.

hexgnu commented 3 years ago

Yeah, I am using into_optimized() on the NNEF model.

Basically what I have done is

tract -O -i 1x19xf32 -f onnx ~/git/foreshadow/psychic-octopus-extension/src/model.onnx --nnef-tract-onnx  dump --nnef-tar model.tar

Then load that up using

use std::io::Cursor;
use tract_nnef::prelude::*;

// Embed the NNEF tarball produced by the CLI and load it from memory.
let mut cursor = Cursor::new(include_bytes!("./model.tar") as &[u8]);

let nnef = tract_nnef::nnef();
let model = nnef.model_for_read(&mut cursor)?.into_optimized()?.into_runnable()?;

Weirdly this is taking forever (2000ms instead of 1500ms with ONNX) and returning what looks like an int version of my model's response. Not entirely sure what I'm doing wrong here.

I also tried to add with_onnx but couldn't figure out how to make that work with tract_onnx_opl.

Thanks for helping me out with this!

antimora commented 3 years ago

@hexgnu were you able to make any progress on this issue? I am curious about your approach. Have you attempted to write your own implementations for the 15 operations you mentioned? Have you come across any reference material for the operations? I couldn't find good, simple examples; many open-source implementations are specific to their framework, e.g. PyTorch.

hexgnu commented 3 years ago

@antimora so I ended up doing a few optimizations of the code using a lazy_static, and that seemed to help quite a bit... for our purposes it seems to work. It still would be awesome to push it down a layer, because we are building this for WASM and avoiding cold starts is very useful for us.
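(For illustration, a rough sketch of that lazy_static approach, assuming the NNEF tarball from earlier in the thread is embedded with include_bytes!; the model is then parsed and optimized only once per WASM instance:)

use lazy_static::lazy_static;
use std::io::Cursor;
use tract_nnef::prelude::*;

lazy_static! {
    // Parsed and optimized once, on first use, then reused by every prediction call.
    static ref MODEL: TypedRunnableModel<TypedModel> = {
        let mut cursor = Cursor::new(include_bytes!("./model.tar") as &[u8]);
        tract_nnef::nnef()
            .model_for_read(&mut cursor)
            .unwrap()
            .into_optimized()
            .unwrap()
            .into_runnable()
            .unwrap()
    };
}

// Later calls just do MODEL.run(...) without paying the load/optimize cost again.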

One thing I did find on the first iteration of what I built in PyTorch (feed-forward only) was that, fundamentally, building sigmoid, softmax, argmax and similar functions is very simple in pure Rust. The problem, honestly, started when I began using things like one-hots or decision trees.
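(For example, the simple activations mentioned above are a handful of lines of dependency-free Rust:)

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

fn softmax(xs: &[f32]) -> Vec<f32> {
    // Subtract the max for numerical stability.
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn argmax(xs: &[f32]) -> usize {
    xs.iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap_or(0)
}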

antimora commented 3 years ago

@hexgnu thanks for your answers. I am curious which library you use for your WASM math functions, such as log? Do you use something like the MicroMath crate that supports no_std?

BTW, you might want to check out the PyTorch Glow project. It has a facility for compiling ONNX to static code ahead of time: https://github.com/pytorch/glow/blob/master/docs/AOT.md

hexgnu commented 3 years ago

@antimora I used straight calculations in Rust and relied on the WASM compiler to compile them, so nothing special really. I will check out the Glow project; it looks great! I do think that with the right compiler a lot of these calculations can be compiled down to bytecode fairly simply; it just hasn't been a huge focus of the community.

kali commented 3 years ago

Going to close this.

benmkw commented 2 years ago

Could this issue be reopened?

In addition to the portability, performance, and size improvements mentioned, I also think that the API could be strongly typed when a specific ONNX file is specified, which would make running the model more robust against user errors.
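(To illustrate the idea, a hypothetical sketch of what a wrapper generated from a specific ONNX file could look like; nothing like this exists in tract today, and the struct name and the [1, 19] shape are made up for the example:)

use tract_nnef::prelude::*;

// Hypothetical generated wrapper: the input dtype and shape are baked in from the model file.
pub struct GeneratedModel {
    plan: TypedRunnableModel<TypedModel>,
}

impl GeneratedModel {
    // A fixed-size array makes it impossible to call the model with a wrongly shaped input.
    pub fn predict(&self, input: [f32; 19]) -> TractResult<Vec<f32>> {
        let tensor = Tensor::from_shape(&[1, 19], &input)?;
        let outputs = self.plan.run(tvec!(tensor.into()))?;
        Ok(outputs[0].as_slice::<f32>()?.to_vec())
    }
}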

A possible data point for performance is NN512 (https://jonatron.github.io/test_nn512/), which seems promising.

"tract is basically an optimizing interpreter. Operators are the interpreter primitives."

I was wondering: what is the reason for this setup, and why is it preferable?