littlesulley opened this issue 1 year ago.
As per https://pytorch.org/docs/stable/onnx.html, by using ONNX_FALLTHROUGH you're telling the PyTorch exporter that it's okay to use operators that are not part of ONNX. That produces a model in ONNX format, but the nodes that use the non-ONNX operators need custom handling.
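A minimal sketch of how one might list the nodes that need this custom handling in an exported file. It assumes the onnx Python package is installed (only onnx.load and the domain and op_type fields of graph.node are used), and the helper names non_standard_ops and non_standard_ops_in_file are hypothetical. Per the ONNX Runtime error quoted in this issue, the offending node here would have domain "prim" and op type "PythonOp".

```python
from types import SimpleNamespace  # only used by the usage example below
from typing import Iterable, List, Tuple

# Domains defined by ONNX itself: "" / "ai.onnx" (default operator set) and "ai.onnx.ml".
STANDARD_DOMAINS = frozenset({"", "ai.onnx", "ai.onnx.ml"})


def non_standard_ops(nodes: Iterable) -> List[Tuple[str, str]]:
    """Return sorted, unique (domain, op_type) pairs for nodes outside the standard ONNX domains.

    `nodes` may be onnx's ModelProto.graph.node or any objects exposing
    `domain` and `op_type`. Only the given nodes are scanned; subgraphs
    inside If/Loop/Scan attributes are not visited.
    """
    found = {(n.domain, n.op_type) for n in nodes if n.domain not in STANDARD_DOMAINS}
    return sorted(found)


def non_standard_ops_in_file(path: str) -> List[Tuple[str, str]]:
    """Load an .onnx file and scan its top-level graph (assumes `onnx` is installed)."""
    import onnx  # imported lazily so the helper above works without the package

    model = onnx.load(path)
    return non_standard_ops(model.graph.node)


if __name__ == "__main__":
    # Stand-in nodes shaped like the exported graph described in this issue.
    fake_graph = [
        SimpleNamespace(domain="", op_type="Gemm"),
        SimpleNamespace(domain="", op_type="Softplus"),
        SimpleNamespace(domain="prim", op_type="PythonOp"),
    ]
    print(non_standard_ops(fake_graph))  # [('prim', 'PythonOp')]
```

Note that contrib domains such as com.microsoft would also be flagged even though ONNX Runtime implements many of those operators, so the output is a list to inspect rather than a verdict.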
@thiagocrepaldi is this something that could be handled by the aten fallback, or would the user need to create a custom operator?
@skottmckay Thanks for your reply! I'm new to ONNX and PyTorch, so forgive me if I ask stupidly trivial questions ;)
Regarding "something that could be handled by the aten fallback": do you mean setting the OperatorExportTypes to ONNX_ATEN_FALLBACK? If so, unfortunately it doesn't work in my case. An error is raised saying something like RuntimeError: ONNX export failed: Couldn't export Python operator _Dirichlet.
Does creating a custom operator mean I should implement it in C++ and export it to ONNX format as described in this tutorial?
Did you end up implementing Dirichlet separately? Or is there now a way to export torch.distributions.dirichlet to ONNX?
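One possible workaround, sketched here under assumptions rather than taken from this thread: export only the deterministic part of the network (the layers that output the concentration parameters) and draw the sample in the calling code. The sketch uses the standard Gamma construction: if X_i ~ Gamma(alpha_i, 1) independently, then X / sum(X) ~ Dirichlet(alpha), and Beta(a, b) is the first component of Dirichlet([a, b]). It uses only Python's random.gammavariate; the names sample_dirichlet and sample_beta are hypothetical.

```python
import random
from typing import List, Optional, Sequence


def sample_dirichlet(alpha: Sequence[float], rng: Optional[random.Random] = None) -> List[float]:
    """Draw one sample from Dirichlet(alpha) by normalizing independent Gamma(alpha_i, 1) draws."""
    if rng is None:
        rng = random.Random()
    # `not a > 0` also rejects NaN concentrations.
    if not alpha or any(not a > 0 for a in alpha):
        raise ValueError("all concentration parameters must be positive")
    gammas = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(gammas)
    if total == 0.0:
        # Very small concentrations can make every Gamma draw underflow to 0.
        raise FloatingPointError("all Gamma draws underflowed; concentrations too small")
    return [g / total for g in gammas]


def sample_beta(a: float, b: float, rng: Optional[random.Random] = None) -> float:
    """Draw one sample from Beta(a, b) as the first component of Dirichlet([a, b])."""
    return sample_dirichlet([a, b], rng)[0]


if __name__ == "__main__":
    rng = random.Random(0)  # seeded for reproducibility
    # In the split setup, (a, b) would come from the exported model's outputs.
    print(sample_beta(2.0, 5.0, rng))
    print(sample_dirichlet([2.0, 3.0, 5.0], rng))
```

With this split the exported graph contains only the layers that compute the concentrations, so ONNX Runtime never sees the _Dirichlet Python operator; the trade-off is that sampling (and any log-probability computation) moves into the host code.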
Describe the issue
I am training a neural network in PyTorch that uses the Beta distribution to sample actions. The model exports successfully with
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH
(but not with the other options). However, when I start an inference session from the exported model, I get the error
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from onnx/PPO_MLP.onnx failed:Fatal error: prim:PythonOp(-1) is not a registered function/op
which seems to imply that ONNX does not support the Dirichlet distribution (PyTorch's Beta distribution is implemented on top of Dirichlet). Here is my neural network definition:
Here is the model graph (divided into two parts):
To reproduce
None.
Urgency
No response
Platform
Windows
OS Version
Win11
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
1.13
ONNX Runtime API
Python
Architecture
X64
Execution Provider
CUDA
Execution Provider Library Version
CUDA 11.7