microsoft / onnxruntime-extensions

onnxruntime-extensions: A specialized pre- and post- processing library for ONNX Runtime
MIT License

Using onnxruntime custom ops that are written with the libtorch library #631

Closed: natangoldgm closed this issue 8 months ago

natangoldgm commented 9 months ago

Hi, in my team we already have a flow for deploying neural networks on our target. We use ONNX to compile our models, but inference on the target does not use ONNX Runtime; it uses a different package. For these models we have a set of custom operators that are written in libtorch and used for PyTorch inference, while on the target we have different implementations that utilize the target's DSP. I want to use the custom operators already written in libtorch in ONNX Runtime. Is that possible? I don't want to rewrite my entire set of custom ops against ORT tensors because it would take a long time. Thanks, Natan.

wenbingl commented 8 months ago

onnxruntime-extensions supports using a native C++ function as an ONNX custom op. So if you are able to wrap your libtorch custom ops as plain C++ functions, you are good to go.

natangoldgm commented 8 months ago

@wenbingl In this case, will the function signature take `float*` instead of ORT tensors?

wenbingl commented 8 months ago


Here is an example: https://github.com/microsoft/onnxruntime-extensions/blob/44e494bab471a942aa3302fca16d3bef07744dae/operators/math/negpos.hpp#L13. Because of tensor ownership, plain floating-point data (a bare `float*`) is insufficient, but you can write functions that operate on floating-point data and then wrap them with `ortc::Tensor` to register them with ORT.
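A minimal, self-contained sketch of that layering, assuming the math of an existing op can be pulled out of libtorch into a function over raw `float` buffers. `ToyTensor`, `scale_kernel`, and `scale_op` are hypothetical names invented for illustration. `ToyTensor` only mimics the ownership role described above (the tensor owns its shape and its buffer, and the kernel just fills memory it is handed); it is not the real `ortc::Tensor` API, whose actual usage is shown in the linked `negpos.hpp`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical stand-in for a tensor type that owns its shape and data.
// It only mimics the ownership role of ortc::Tensor; it is NOT that API.
class ToyTensor {
 public:
  ToyTensor() = default;
  ToyTensor(std::vector<int64_t> shape, std::vector<float> data)
      : shape_(std::move(shape)), data_(std::move(data)) {
    // The buffer must match the number of elements implied by the shape.
    assert(static_cast<int64_t>(data_.size()) == NumberOfElements());
  }

  const std::vector<int64_t>& Shape() const { return shape_; }
  const float* Data() const { return data_.data(); }

  int64_t NumberOfElements() const {
    return std::accumulate(shape_.begin(), shape_.end(), int64_t{1},
                           std::multiplies<int64_t>());
  }

  // The tensor, not the kernel, decides where output memory lives:
  // the caller asks for a shape and receives a writable pointer.
  float* Allocate(const std::vector<int64_t>& shape) {
    shape_ = shape;
    data_.assign(static_cast<std::size_t>(NumberOfElements()), 0.0f);
    return data_.data();
  }

 private:
  std::vector<int64_t> shape_;
  std::vector<float> data_;
};

// Plain C++ kernel on raw float pointers: the part that would be
// extracted from an existing libtorch op (hypothetical example).
void scale_kernel(const float* in, float* out, int64_t n, float factor) {
  for (int64_t i = 0; i < n; ++i) {
    out[i] = in[i] * factor;
  }
}

// Thin tensor-typed wrapper: allocate the output through the tensor,
// then hand raw pointers to the kernel.
void scale_op(const ToyTensor& input, ToyTensor& output) {
  float* out = output.Allocate(input.Shape());
  scale_kernel(input.Data(), out, input.NumberOfElements(), 2.0f);
}
```

With this split, only the wrapper knows about the tensor type, so the numeric kernel stays free of both libtorch and ORT types and could, in principle, be shared by more than one front end.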

natangoldgm commented 8 months ago

OK, I took your recommendation and am now trying to create the interface for the ORT custom op by implementing the function `RegisterCustomOps`. I get a segmentation fault when trying to create the `CustomOpDomain`. Any idea what could cause it? This is the function I have implemented so far:

```cpp
OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
  printf("register custom ops");
  OrtStatus* result = nullptr;

  printf("init api");
  Ort::InitApi(api->GetApi(ORT_API_VERSION));

  printf("creating session options");
  Ort::UnownedSessionOptions session_options(options);

  ORT_TRY {
    printf("creating domain");
    Ort::CustomOpDomain domain{c_OpDomain};
    printf("adding domain");
    session_options.Add(domain);
    printf("move domain");
    AddOrtCustomOpDomainToContainer(std::move(domain));
  }
  ORT_CATCH(const std::exception& e) {
    ORT_HANDLE_EXCEPTION([&]() {
      printf("caught exception");
      Ort::Status status{e};
      result = status.release();
    });
  }

  printf("\n\nfinished subscribing");
  return result;
}
```

natangoldgm commented 8 months ago

Thanks. I will open a separate ticket for my current issue.