webonnx / wonnx

A WebGPU-accelerated ONNX inference run-time written 100% in Rust, ready for native and the web

Support data types other than f32 #48

Closed: pixelspark closed this 2 years ago

pixelspark commented 2 years ago

This PR (based on my earlier PR #45, which implemented the DAG) adds support for data types other than f32. So far I have implemented support for i32 inputs; other types (such as f16) should be fairly easy to add. The current version always casts outputs to Vec<f32>, but the output type should of course follow the actual output type declared in the model.
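As a rough sketch of what type-dependent outputs could look like (the names `TensorData`, `ScalarType` and `decode_output` are hypothetical and are not wonnx's actual API), the raw bytes read back from an output buffer can be decoded according to the element type the model declares for that output, instead of always producing `Vec<f32>`:

```rust
/// Hypothetical tensor container: one variant per supported element type.
#[derive(Debug, Clone, PartialEq)]
pub enum TensorData {
    F32(Vec<f32>),
    I32(Vec<i32>),
}

/// Hypothetical element-type tag, as it might be read from a model's output description.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalarType {
    F32,
    I32,
}

impl ScalarType {
    /// WGSL scalar type name to use when generating shader code for this type.
    pub fn wgsl_type_name(self) -> &'static str {
        match self {
            ScalarType::F32 => "f32",
            ScalarType::I32 => "i32",
        }
    }
}

/// Decode raw little-endian bytes (e.g. read back from a GPU buffer) into a typed
/// tensor, choosing the element type from the model output rather than always f32.
pub fn decode_output(bytes: &[u8], ty: ScalarType) -> TensorData {
    // Both supported types are 4 bytes wide.
    assert_eq!(bytes.len() % 4, 0, "buffer length must be a multiple of 4");
    match ty {
        ScalarType::F32 => TensorData::F32(
            bytes
                .chunks_exact(4)
                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                .collect(),
        ),
        ScalarType::I32 => TensorData::I32(
            bytes
                .chunks_exact(4)
                .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                .collect(),
        ),
    }
}
```

Adding another type would then mean one more variant in each enum and one more match arm. Note that f16 in WGSL additionally requires the `f16` extension (`enable f16;`), which this sketch does not handle.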

Notably, BERT and some other models appear to use int64 inputs. WGSL, however, does not (yet?) support 64-bit integers, so we should probably convert these inputs to int32, which is lossless as long as every value fits in the i32 range.
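A minimal sketch of that conversion, assuming inputs arrive as plain `&[i64]` slices (`narrow_i64_to_i32` and `NarrowingError` are hypothetical names, not part of wonnx). A bit-level reinterpretation that keeps only the low 32 bits would silently corrupt out-of-range values, so this version checks each element with `i32::try_from` and reports the first one that does not fit:

```rust
use std::convert::TryFrom;

/// Hypothetical error: the position and value of the first element outside the i32 range.
#[derive(Debug, PartialEq, Eq)]
pub struct NarrowingError {
    pub index: usize,
    pub value: i64,
}

/// Narrow an int64 input tensor to int32 so it can be uploaded to a WGSL `i32` buffer.
/// Returns an error instead of silently truncating when a value does not fit.
pub fn narrow_i64_to_i32(values: &[i64]) -> Result<Vec<i32>, NarrowingError> {
    values
        .iter()
        .enumerate()
        .map(|(index, &value)| {
            i32::try_from(value).map_err(|_| NarrowingError { index, value })
        })
        // Collecting an iterator of Results stops at the first Err.
        .collect()
}
```

Failing loudly seems preferable here: if a model really relies on values beyond the i32 range, a wrong inference result would be much harder to diagnose than an explicit error.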

Some implementation notes:

See tests/arithmetic.rs for an example of integer addition.
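The test itself is not reproduced here. As a hedged sketch, a CPU reference like the one below (the hypothetical `add_i32_reference`, equal shapes only, no broadcasting) is one way to check GPU results for integer addition. It uses wrapping arithmetic on overflow, which we assume matches WGSL's semantics for `i32`:

```rust
/// Hypothetical CPU reference for element-wise i32 addition (ONNX `Add` with equal
/// shapes; broadcasting is not implemented). Overflow wraps around, which we assume
/// matches how WGSL evaluates `i32` addition.
pub fn add_i32_reference(a: &[i32], b: &[i32]) -> Vec<i32> {
    assert_eq!(a.len(), b.len(), "this sketch does not implement broadcasting");
    a.iter().zip(b).map(|(&x, &y)| x.wrapping_add(y)).collect()
}
```

A GPU test could then compare the output of the integer `Add` model against this reference for the same inputs.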

pixelspark commented 2 years ago

@haixuanTao I rebased this on top of your workspace PR; this one should be ready to go now.

haixuanTao commented 2 years ago

Ok I think we can merge this branch.

pixelspark commented 2 years ago

> Ok I think we can merge this branch.

Yes, agreed. We might improve the API in a later PR (e.g. with the builder pattern or something else).