microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

ONNX does not support Dirichlet distribution? #15016

Open littlesulley opened 1 year ago

littlesulley commented 1 year ago

Describe the issue

I am training a neural network in PyTorch that uses the Beta distribution to sample actions. The model exports successfully with `operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH` (but not with the other options). However, when I create an inference session from the exported model, I get the following error:

`onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from onnx/PPO_MLP.onnx failed:Fatal error: prim:PythonOp(-1) is not a registered function/op`

This seems to imply that ONNX does not support the Dirichlet distribution (PyTorch's `Beta` distribution draws its samples through an internal Dirichlet).
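The Beta/Dirichlet connection can be checked directly. The sketch below (example concentrations are arbitrary, not from the issue) seeds the RNG identically and compares `Beta(a, b).sample()` with the first component of a `Dirichlet` over the stacked concentrations. If they match, Beta sampling is delegated to the Dirichlet sampler, which is the `_Dirichlet` Python operator the exporter later fails on.

```python
import torch
from torch.distributions import Beta, Dirichlet

# Arbitrary example concentrations (not taken from the issue)
alpha = torch.tensor([[2.0, 1.5, 3.0]])
beta = torch.tensor([[1.2, 4.0, 2.5]])

# Draw from Beta(alpha, beta)
torch.manual_seed(0)
beta_sample = Beta(alpha, beta).sample()

# Draw from Dirichlet([alpha, beta]) with the same seed and keep component 0
torch.manual_seed(0)
dirichlet_sample = Dirichlet(torch.stack([alpha, beta], dim=-1)).sample()[..., 0]

# Matching values indicate that Beta sampling goes through the Dirichlet sampler
print(torch.allclose(beta_sample, dirichlet_sample))
```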

Here is my neural network definition:

```python
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Beta


# Assumed helper (not shown in the issue): CleanRL-style orthogonal init
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    nn.init.orthogonal_(layer.weight, std)
    nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent_Beta(nn.Module):
    def __init__(self, n_dim):
        super().__init__()
        # Value head
        self.critic = nn.Sequential(
            layer_init(nn.Linear(n_dim, 128)),
            nn.Tanh(),
            layer_init(nn.Linear(128, 128)),
            nn.Tanh(),
            layer_init(nn.Linear(128, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 32)),
            nn.Tanh(),
            layer_init(nn.Linear(32, 1), std=1.0)
        )
        # Shared policy trunk producing 32 features
        self.actor_main = nn.Sequential(
            layer_init(nn.Linear(n_dim, 128)),
            nn.Tanh(),
            layer_init(nn.Linear(128, 128)),
            nn.Tanh(),
            layer_init(nn.Linear(128, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 32)),
            nn.Tanh(),
        )
        # Heads for the 12 per-action Beta concentrations
        self.actor_alpha = nn.Sequential(
            layer_init(nn.Linear(32, 12), std=0.01)
        )
        self.actor_beta = nn.Sequential(
            layer_init(nn.Linear(32, 12), std=0.01)
        )

    def forward(self, x, action=None):
        # Exported entry point: returns only the (sampled) action
        action, _, _, _ = self.get_action_and_value(x, action)
        return action

    def get_value(self, x):
        return self.critic(x)

    def evaluate(self, x):
        probs = self.get_probs(x)
        action = probs.sample().detach().cpu().numpy()
        # scale_action is defined elsewhere (not shown in the issue)
        return self.scale_action(action)

    def get_probs(self, x):
        main = self.actor_main(x)
        # softplus(.) + 1 keeps both concentrations above 1
        alpha = F.softplus(self.actor_alpha(main)) + 1.0
        beta = F.softplus(self.actor_beta(main)) + 1.0
        return Beta(alpha, beta)

    def get_action_and_value(self, x, action=None):
        probs = self.get_probs(x)
        if action is None:
            # Beta.sample() runs through torch's internal Dirichlet sampler
            action = probs.sample()
        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
```
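One possible workaround, sketched here under the assumption that sampling can move to the caller, is to export only the deterministic part of this policy. The wrapper returns the Beta parameters `alpha` and `beta`, which are computed with `Linear`, `Tanh`, softplus and addition only, all of which have standard ONNX equivalents, so `ONNX_FALLTHROUGH` should not be needed. `BetaParamsExport`, `export_beta_params` and the input/output names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BetaParamsExport(nn.Module):
    """Hypothetical export wrapper: returns Beta parameters, never samples."""

    def __init__(self, agent):
        super().__init__()
        # agent is assumed to expose actor_main, actor_alpha and actor_beta
        # the way Agent_Beta does
        self.agent = agent

    def forward(self, x):
        # Mirrors Agent_Beta.get_probs, but stops before building Beta(...)
        main = self.agent.actor_main(x)
        alpha = F.softplus(self.agent.actor_alpha(main)) + 1.0
        beta = F.softplus(self.agent.actor_beta(main)) + 1.0
        return alpha, beta


def export_beta_params(agent, n_dim, path):
    """Export only deterministic ops with the default operator_export_type."""
    wrapper = BetaParamsExport(agent).eval()
    dummy_obs = torch.zeros(1, n_dim)
    torch.onnx.export(
        wrapper,
        (dummy_obs,),
        path,
        input_names=["obs"],
        output_names=["alpha", "beta"],
        dynamic_axes={"obs": {0: "batch"}, "alpha": {0: "batch"}, "beta": {0: "batch"}},
    )
```

For example, `export_beta_params(agent, n_dim, "onnx/PPO_MLP_params.onnx")` (hypothetical path), after which actions are drawn from `alpha` and `beta` outside the model. Recent PyTorch releases also ship a dynamo-based exporter, so the accepted keyword arguments are worth checking against the installed version.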

Here is the model graph (divided into two parts): [graph screenshots not preserved]

To reproduce

None.

Urgency

No response

Platform

Windows

OS Version

Win11

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

1.13

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 11.7

skottmckay commented 1 year ago

As per https://pytorch.org/docs/stable/onnx.html


You're telling the PyTorch exporter that it's okay to use operators that are not part of ONNX. That produces a model in ONNX format, but the nodes that use the non-ONNX operators need custom handling at runtime.

@thiagocrepaldi, is this something that could be handled by the ATen fallback, or would the user need to create a custom operator?

littlesulley commented 1 year ago

@skottmckay Thanks for your reply! I'm new to ONNX and PyTorch, so forgive me if I ask trivial questions ;)
By "something that could be handled by the aten fallback", do you mean setting the operator export type to `OperatorExportTypes.ONNX_ATEN_FALLBACK`? If so, unfortunately it doesn't work in my case: the export fails with an error like `RuntimeError: ONNX export failed: Couldn't export Python operator _Dirichlet`. And does creating a custom operator mean I should implement it in C++ and export it to ONNX format as described in this tutorial?

andife commented 4 months ago

Did you end up implementing Dirichlet separately? Or are there now ways to export torch.distributions.dirichlet to ONNX?
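If the model is exported without the sampling step (only the Beta parameters as outputs), the action can be drawn outside ONNX Runtime instead. The sketch below assumes a model whose input is named `obs` and whose outputs are named `alpha` and `beta`; `run_policy` and `sample_beta_actions` are hypothetical helpers. NumPy's `Generator.beta` draws the samples, and the Beta mean `alpha / (alpha + beta)` serves as a deterministic action for evaluation.

```python
import numpy as np


def sample_beta_actions(alpha, beta, rng=None, deterministic=False):
    """Draw actions in (0, 1) from Beta(alpha, beta), element-wise."""
    if deterministic:
        # Mean of Beta(alpha, beta)
        return alpha / (alpha + beta)
    rng = np.random.default_rng() if rng is None else rng
    return rng.beta(alpha, beta)


def run_policy(model_path, obs, rng=None):
    """Hypothetical helper: ONNX Runtime computes alpha/beta, NumPy samples."""
    # Imported here so sample_beta_actions works without onnxruntime installed
    import onnxruntime as ort

    session = ort.InferenceSession(
        model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
    )
    # Input/output names are assumptions; they must match the export call
    alpha, beta = session.run(["alpha", "beta"], {"obs": obs.astype(np.float32)})
    return sample_beta_actions(alpha, beta, rng)
```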