DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License

Current ONNX opset doesn't support StableBaselines3 natively, requires creating a wrapper class. #383

Closed (batu closed this issue 2 years ago)

batu commented 3 years ago

I am interested in using stable-baselines to train an agent and then exporting it through ONNX to put it inside the Unity engine via Barracuda. I was hoping to write up the documentation too!

Unfortunately, neither ONNX opset 9 nor opset 12 seems to support converting trained policies:

RuntimeError: Exporting the operator broadcast_tensors to ONNX opset version 9 is not supported.
Please open a bug to request ONNX export support for the missing operator.

While broadcast_tensors isn't called explicitly anywhere in the codebase, it is most likely related to the use of torch.distributions. Unfortunately, this has been an open issue since November 2019, so I am pessimistic about it being solved soon.

While very unlikely, do you think there might be a way around this? Either way, I wanted to raise this issue so the team is aware.
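For reference, the connection to torch.distributions can be illustrated with a tiny standalone module (a sketch; DistributionPolicy is just an illustrative name): constructing a Normal distribution broadcasts its loc and scale arguments via torch.broadcast_tensors, which is exactly the operator the exporter rejects.

import torch

class DistributionPolicy(torch.nn.Module):
    def forward(self, mean, std):
        # Normal.__init__ broadcasts loc and scale with torch.broadcast_tensors,
        # the operator that older ONNX opsets cannot export.
        return torch.distributions.Normal(mean, std).sample()

# Attempting the export reproduces the error quoted above:
# torch.onnx.export(DistributionPolicy(), (torch.zeros(2), torch.ones(2)), "repro.onnx", opset_version=9)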


JadenTravnik commented 3 years ago

I'm not sure what parts of the implementation of the policy make it so that it isn't supported by the opset (I think you're right about torch.distributions), but the mlp_extractor, action_net, and value_net modules of the ActorCriticPolicy are all "onnxable", as they don't include the broadcast_tensors operator. So one can do the following:

import torch

class OnnxablePolicy(torch.nn.Module):
    def __init__(self, extractor, action_net, value_net):
        super(OnnxablePolicy, self).__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

    def forward(self, observation):
        # NOTE: this assumes the observation needs no preprocessing
        # (e.g. a Box observation space with a flatten features extractor).
        action_hidden, value_hidden = self.extractor(observation)
        return (self.action_net(action_hidden), self.value_net(value_hidden))

Then one can use the OnnxablePolicy to do the actual ONNX export. A similar thing can be done for the ActorCriticCnnPolicy, keeping in mind that its features_extractor outputs a single hidden tensor which is consumed by both the value net and the action net. It's a little bit of surgery, but given this workaround perhaps one of the maintainers has a better approach.
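For completeness, a minimal export sketch using this wrapper could look like the following (assuming an MlpPolicy on an environment with a 1D Box observation space; the model and file names are illustrative):

from stable_baselines3 import PPO
import torch

# In practice you would load a trained agent, e.g. PPO.load("my_agent.zip")
model = PPO("MlpPolicy", "CartPole-v1")

onnxable_model = OnnxablePolicy(
    model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
)

observation_size = model.observation_space.shape[0]
dummy_input = torch.randn(1, observation_size)
torch.onnx.export(onnxable_model, dummy_input, "ppo_model.onnx", opset_version=9)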

batu commented 3 years ago

Thank you! I will look into this.

Are you saying that I should train with the OnnxablePolicy, or instantiate the OnnxablePolicy class after the fact with the pretrained networks?

It is true that I haven't tried exporting the individual action_net but instead worked at the model.policy level.

JadenTravnik commented 3 years ago

The "after the fact" instantiate option. I'm still not sure where exactly the broadcast is but I guess its not in these 3 modules :p

batu commented 3 years ago

Omg I think it works! Managed to export an onnx file without crashing, and the input-output shapes seem good.

Thank you so so much! Now to hope the exported onnx file also works.

Will write all these down when I get a reliable pipeline working.

PS: There are three underscores in your init. For a solid 5 minutes, I doubted I knew how classes worked :D

JadenTravnik commented 3 years ago

Hehe, glad to hear (and sorry for the typo :p). If you have the time, I know that a PR integrating your pipeline would be super appreciated by the community (myself included).

batu commented 3 years ago

What you suggested is working! Managed to put the SB3 trained agent into Unity with no python dependencies.

Will write up a notebook showing how to export it into ONNX soon (but it is basically what @JadenTravnik typed up!)

Would people be interested in the Barracuda Unity side of it?


Miffyli commented 3 years ago

Thanks to @JadenTravnik for this valuable piece of advice :).

@batu: Feel free to make a PR which updates the documentation regarding ONNX. It does not have to be lengthy; maybe link to the example you mentioned, link to this issue, and outline what needs to be done.

araffin commented 3 years ago

Hello, I'll re-open this issue as I think it would be a useful addition to the docs (https://stable-baselines3.readthedocs.io/en/master/guide/export.html#export-to-onnx).

I would also be interested in knowing which part exactly breaks the export (probably by deleting variables until it works?). The current solution only works for MlpPolicy when no preprocessing is needed (so only with a Box observation space).

For a more complete solution, you should preprocess the input, see https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/policies.py#L110

Another thing not mentioned is that we do a lot of magic in the predict() method, notably re-organizing image channels automatically and returning the correct shape (depending on the input shape): https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/policies.py#L238
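As an illustration of the kind of preprocessing that would be needed for an image-based policy, a wrapper could replicate the normalization that preprocess_obs applies to image observations (a sketch only; it assumes channel-first uint8 images and a default ActorCriticCnnPolicy, and does not handle the channel re-ordering mentioned above):

import torch

class OnnxableCnnPolicy(torch.nn.Module):
    def __init__(self, policy):
        super().__init__()
        self.policy = policy

    def forward(self, observation):
        # Replicate the image preprocessing normally done before the forward pass:
        # cast to float and scale pixel values to [0, 1].
        observation = observation.float() / 255.0
        features = self.policy.features_extractor(observation)
        action_hidden, value_hidden = self.policy.mlp_extractor(features)
        return self.policy.action_net(action_hidden), self.policy.value_net(value_hidden)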

Egiob commented 3 years ago

Hello, I'm having the same error but with the SAC algorithm, when trying to export model.policy.actor to ONNX:

RuntimeError: Exporting the operator broadcast_tensors to ONNX opset version 9 is not supported. Please open a bug to request ONNX export support for the missing operator.

It seems to me that the forward pass in the Actor class is much more intertwined with the distributions machinery than in the ActorCriticPolicy class, but maybe I'm wrong. Anyway, I can't figure out how to make a custom "Onnxable" class like @JadenTravnik proposed for solving the first issue.

Do you think such a simple fix is possible in this case?

batu commented 3 years ago

Yes, I believe the initial .flatten() causes the problem. The following works for me.

import torch

class SACOnnxablePolicy(torch.nn.Module):
    def __init__(self, actor):
        super(SACOnnxablePolicy, self).__init__()

        # Removing the flatten layer because it can't be onnxed
        self.actor = torch.nn.Sequential(actor.latent_pi, actor.mu)

    def forward(self, observation):
        return self.actor(observation)

# loaded_model is a trained SAC model, e.g. loaded_model = SAC.load("sac_model.zip")
new_model = SACOnnxablePolicy(loaded_model.policy.actor)
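The export call itself would then look roughly like this (a sketch, assuming loaded_model holds a trained SAC model on an environment with a 1D Box observation space; the file name is illustrative):

import torch

dummy_input = torch.randn(1, loaded_model.observation_space.shape[0])
torch.onnx.export(new_model, dummy_input, "sac_actor.onnx", opset_version=9)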
Egiob commented 3 years ago

Thanks @batu, it seems to solve the problem 😄

Miffyli commented 3 years ago

@batu @Egiob FYI we would be happy to take PRs adding documentation on exporting models to ONNX, as this seems to be a common thing people do :)

batu commented 3 years ago

will do it this weekend!

shehrozeee commented 2 years ago

Can you add an example for exporting the DQN model from the LunarLander example?

Miffyli commented 2 years ago

Hey. Unfortunately we do not do custom tech support requests (we do not have time for that). The things shoooould be somewhat clear from the documentation (and if not, let us know! We are happy to approve any PRs that update the documentation).

araffin commented 2 years ago

Can you add an example for exporting the DQN model from the LunarLander example?

I highly recommend learning more about DQN and reading the code, but this is all you need to export to ONNX (given the initial example in the doc). As DQN does not use any probability distribution, the export is straightforward:

from stable_baselines3 import DQN
import torch

model = DQN("MlpPolicy", "LunarLander-v2")  # in practice, load your trained model with DQN.load(...)
model.policy.to("cpu")
onnxable_model = model.policy

observation_size = model.observation_space.shape[0]

dummy_input = torch.randn(1, observation_size)
onnx_path = "dqn_model.onnx"
torch.onnx.export(onnxable_model, dummy_input, onnx_path, opset_version=9, input_names=["obs"])  # name the graph input so it matches the lookup below

##### Load and test with onnx

import onnx
import onnxruntime as ort
import numpy as np

onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

observation = np.zeros((1, observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
action = ort_sess.run(None, {'obs': observation})
gilzamir18 commented 1 year ago

I'm not sure what parts of the implementation of the policy make it so that it isn't supported by the opset (I think you're right about torch.distributions), but the mlp_extractor, action_net, and value_net modules of the ActorCriticPolicy are all "onnxable", as they don't include the broadcast_tensors operator. So one can do the following:

class OnnxablePolicy(torch.nn.Module):
    def __init__(self, extractor, action_net, value_net):
        super(OnnxablePolicy, self).__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

    def forward(self, input):
        action_hidden, value_hidden = self.extractor(input)
        return (self.action_net(action_hidden), self.value_net(value_hidden))

Then one can use the OnnxablePolicy to do the actual ONNX export. A similar thing can be done for the ActorCriticCnnPolicy, keeping in mind that its features_extractor outputs a single hidden tensor which is consumed by both the value net and the action net. It's a little bit of surgery, but given this workaround perhaps one of the maintainers has a better approach.

I made an example using a MultiInputPolicy with SAC:

import torch as th
from torch import nn
from stable_baselines3 import SAC
import gym

class OnnxablePolicy(th.nn.Module):
    def __init__(self, actor: th.nn.Module, extractors: nn.ModuleDict):
        super().__init__()
        # Removing the flatten layer because it can't be onnxed

        self.extractors = extractors

        self.actor = th.nn.Sequential(
            actor.latent_pi,
            actor.mu,
            # For gSDE
            # th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean),
            # Squash the output
            th.nn.Tanh(),
        )

    def forward(self, observations):
        input_tensor = self.extractors(observations)
        # NOTE: You may have to process (normalize) observation in the correct
        #       way before using this. See `common.preprocessing.preprocess_obs`
        return self.actor(input_tensor)

# Example: model = SAC("MlpPolicy", "Pendulum-v1")
model = SAC.load("sac_ai4u.zip", device="cpu")

onnxable_model = OnnxablePolicy(model.policy.actor, model.policy.actor.features_extractor)

dummy_inputs = {k: th.randn(1, *obs.shape) for k, obs in model.observation_space.items()}

th.onnx.export(
    onnxable_model,
    (dummy_inputs, {}),
    "my_sac_actor.onnx",
    opset_version=9,
    input_names=["input"],
)
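To run the exported actor with onnxruntime, one option is to query the input names generated during export instead of hard-coding them, since each key of the Dict observation space becomes its own graph input (a sketch; the file name matches the export above):

import numpy as np
import onnxruntime as ort

ort_sess = ort.InferenceSession("my_sac_actor.onnx")
# Build one dummy array per graph input, using the shapes stored in the ONNX model.
dummy_obs = {
    inp.name: np.random.randn(*inp.shape).astype(np.float32)
    for inp in ort_sess.get_inputs()
}
action = ort_sess.run(None, dummy_obs)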
araffin commented 1 year ago

I made an example using a MultiInputPolicy with SAC:

Thanks =), but it seems that you are passing a randomly initialized features extractor instead of the trained one: model.policy.make_features_extractor() instead of model.policy.actor.features_extractor.

gilzamir18 commented 1 year ago

I made an example using a MultiInputPolicy with SAC:

Thanks =), but it seems that you are passing a randomly initialized features extractor instead of the trained one: model.policy.make_features_extractor() instead of model.policy.actor.features_extractor.

True, I edited my comment to fix that.