microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

Cannot register custom DML operator #16347

Open carsonswope opened 1 year ago

carsonswope commented 1 year ago

Describe the issue

Hello,

I'm attempting to register a custom operator for the DmlExecutionProvider, and I'm getting a crash from GraphPartitioner.cpp.

The message is: Assertion failed: createInfo != nullptr, file <onnxruntime source>\core\providers\dml\DmlExecutionProvider\src\GraphPartitioner.cpp, line 166.

struct MyCustomOpKernel { ... };
struct MyCustomOp : Ort::CustomOpBase<MyCustomOp, MyCustomOpKernel> { ... };

Ort::CustomOpDomain custom_op_domain("my_domain");
MyCustomOp custom_op{};
custom_op_domain.Add(&custom_op);

Ort::SessionOptions sessionOptions;
sessionOptions.Add(custom_op_domain);
// ... other DML-specific session option configuration ...
ortDmlApi->SessionOptionsAppendExecutionProvider_DML1(sessionOptions, dmlDevice, dmlQueue);
// crash occurs on this line:
ortSession = std::make_unique<Ort::Session>(ortEnv, modelData.data(), modelData.size(), sessionOptions);

I know that onnxruntime is able to recognize my op declaration, because if the domain name or the op name doesn't match what's in the model, I get a helpful message letting me know that the op/function was not found. When correctly configured, it gets past that point, only to crash deeper in the Ort::Session initialization.

Based on where the crash happens in the ort source code, it seems like I need to access the RegisterDmlOperators function from the DmlExecutionProvider.h file. But it's not clear that this is part of the public API. It's not exported from dml_provider_factory.h, which I had been using as my entry point to the onnxruntime DirectML functionality. So I'm wondering if implementing a custom DML operator is even supported? And if so, what am I missing here?

To reproduce

I am happy to upload a more complete repro but I wanted to get a confirmation that what I'm trying to do is even supported first :)

Urgency

:)

Platform

Windows

OS Version

10

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

latest (1.15)

ONNX Runtime API

C++

Architecture

X64

Execution Provider

DirectML

Execution Provider Library Version

No response

fdwr commented 1 year ago

> So I'm wondering if implementing a custom DML operator is even supported?

I'm pretty sure this approach of using Ort::CustomOpBase is not blessed or tested with the DML EP. However, I thought there was a way to implement custom ORT operators using DML (or rather D3D) resources through the custom MLOperatorAuthor API, because I'm sure I saw some WinML samples using it (@martinb35 does that sound correct to you?). @jeffbloo, since you may know this from memory, do you know of any good samples of MLOperatorAuthor usage?

RandySheriffH commented 1 year ago

@carsonswope: we are looking to support DML custom ops, which requires a few infrastructural changes. Do you mind sharing a complete repro?

carsonswope commented 1 year ago

Hi @RandySheriffH, I will try to get one up here for you in the next couple of weeks. I ended up just forking the repo and implementing my operator directly in onnxruntime, following the recent implementation of grid_sample as an example.

In the meantime, the first crash that stopped me had to do with the custom op not being registered with the IMLOperatorRegistry for DML. It seems like this happens automatically with a CUDA or CPU custom operator, but maybe that step is skipped for a DML custom operator.
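
To illustrate the failure mode described above, here is a toy model of a two-stage lookup: one table answers "is this op known at all?" and a second, per-provider table answers "does this provider have a kernel for it?". Every type and function name below (ToyRegistries, Place, and so on) is hypothetical and is not part of onnxruntime; the sketch only assumes that the assertion fires when the second lookup returns null.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>

// Toy model only: none of these are onnxruntime types.
struct KernelCreateInfo { std::string provider; };

using OpKey = std::pair<std::string, std::string>;     // (domain, op name)
using KernelKey = std::pair<std::string, OpKey>;       // (provider, op key)

struct ToyRegistries {
    std::map<OpKey, bool> schemas;                      // stage 1: op is declared
    std::map<KernelKey, KernelCreateInfo> kernels;      // stage 2: provider has a kernel

    // registerKernel == false models a custom op whose schema is added
    // but whose kernel never reaches the provider's own registry.
    void AddCustomOp(const std::string& domain, const std::string& op,
                     const std::string& provider, bool registerKernel) {
        schemas[{domain, op}] = true;
        if (registerKernel) {
            kernels[{provider, {domain, op}}] = KernelCreateInfo{provider};
        }
    }

    // Returns nullptr when the provider has no kernel for the op.
    const KernelCreateInfo* FindKernel(const std::string& provider,
                                       const std::string& domain,
                                       const std::string& op) const {
        auto it = kernels.find({provider, {domain, op}});
        return it == kernels.end() ? nullptr : &it->second;
    }
};

enum class Placement { OpNotFound, NoKernelForProvider, Assigned };

// Checks the pointer instead of asserting on it, so the missing-kernel
// case becomes a reportable result rather than a crash.
Placement Place(const ToyRegistries& r, const std::string& provider,
                const std::string& domain, const std::string& op) {
    if (r.schemas.find({domain, op}) == r.schemas.end()) {
        return Placement::OpNotFound;
    }
    const KernelCreateInfo* createInfo = r.FindKernel(provider, domain, op);
    return createInfo == nullptr ? Placement::NoKernelForProvider
                                 : Placement::Assigned;
}
```

In this model, an op added with registerKernel set to false passes the first check (so no "op not found" message) and fails only at the second, which matches the behavior reported in this thread.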

I'm guessing that there's also another piece missing in the current API, which is that the D3D12 command buffer and buffer resources need to somehow be made available to the custom operator when the Compute function is called.
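
As a rough sketch of that second missing piece, a Compute call would need some context object carrying the command list and the GPU buffers. The types below are hypothetical stand-ins (FakeCommandList and FakeGpuBuffer take the place of real D3D12 objects) and do not reflect any existing onnxruntime API; they only show the shape such a contract might take.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for a D3D12 command list and buffer resource.
struct FakeGpuBuffer { std::size_t sizeInBytes; };
struct FakeCommandList { std::vector<const FakeGpuBuffer*> recorded; };

// Hypothetical context the runtime would hand to a custom kernel.
struct ToyComputeContext {
    FakeCommandList* commandList = nullptr;
    std::vector<FakeGpuBuffer*> inputs;
    std::vector<FakeGpuBuffer*> outputs;
};

struct ToyCustomKernel {
    // Returns false when the runtime did not supply the GPU state the
    // kernel needs; otherwise "records" work touching every buffer.
    bool Compute(ToyComputeContext& ctx) const {
        if (ctx.commandList == nullptr || ctx.inputs.empty() || ctx.outputs.empty()) {
            return false;
        }
        for (const FakeGpuBuffer* in : ctx.inputs) {
            ctx.commandList->recorded.push_back(in);
        }
        for (const FakeGpuBuffer* out : ctx.outputs) {
            ctx.commandList->recorded.push_back(out);
        }
        return true;
    }
};
```

The point of the sketch is that the kernel records work onto a command list owned by the runtime rather than creating its own, so the runtime keeps control of submission order.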

varunsh-xilinx commented 2 weeks ago

@RandySheriffH is there any update on this? I was able to use this branch to make some progress after resolving the conflicts but it would be great for this to officially work. I also ran into some issues calling the native operator from the custom op.