DLR-RM / stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
https://stable-baselines3.readthedocs.io
MIT License
8.38k stars 1.61k forks

Exporting MultiInputActorCriticPolicy as ONNX #1873

Open MaximCamilleri opened 3 months ago

MaximCamilleri commented 3 months ago

❓ Question

Hi,

I am looking into using ONNX with SB3. I have tested two models (A2C and PPO) on a custom environment using a MultiInputActorCriticPolicy. The observation space of the environment is a Dict space. So far I have not been able to produce an ONNX-exportable policy.

The documentation says "The following examples are for MlpPolicy only, and are general examples." Is it possible to export a model like mine to ONNX? If so, could you provide an example?

Thanks


araffin commented 3 months ago

Hello, what have you tried so far? and what errors did you encounter?

Please provide a minimal and working code example (see link in issue template for what that means).

MaximCamilleri commented 3 months ago

Hello, thanks for your response.

I have tried a couple of things so far. First I tried converting my model into an onnxable policy using the method shown in the documentation. My code is as follows:

import numpy as np
import torch as th
from stable_baselines3 import PPO

class OnnxablePolicy(th.nn.Module):
    def __init__(self, policy):
        super().__init__()
        self.policy = policy

    def forward(self, input):
        return self.policy(input)

model = PPO.load("Models/ppo.zip")
onnx_policy = OnnxablePolicy(model.policy)

th.onnx.export(
    onnx_policy,
    obs_dict,
    "ONNX/ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)

To get the dummy input which I am here calling obs_dict, I used the following code snippet:

obs = env.reset()
obs_dict = {}
for key in obs.keys():
    obs_dict[key] = th.from_numpy(np.array([obs[key]])).float()

This creates an input with the same structure as the observation after common.preprocessing.preprocess_obs has been applied. The error I got at this point was:

TypeError: OnnxablePolicy2.forward() missing 1 required positional argument: 'input'

I also tried the approach seen here, and created the following code:

class OnnxablePolicy(th.nn.Module):
    def __init__(self, extractor, action_net, value_net):
        super(OnnxablePolicy, self).__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

    def forward(self, input):
        action_hidden = value_hidden = self.extractor(input)
        return self.action_net(action_hidden), self.value_net(value_hidden)

onnx_policy = OnnxablePolicy(model.policy.features_extractor, model.policy.action_net, model.policy.value_net)

th.onnx.export(
    onnx_policy,
    obs_dict,
    "ONNX/ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)

This resulted in the same error as before.
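Even once the argument-passing error is solved, this second wrapper would still not reproduce the policy's outputs: in SB3's ActorCriticPolicy, the output of features_extractor is not fed to action_net/value_net directly, but first goes through mlp_extractor, which returns separate actor and critic latents (latent_pi, latent_vf). The sketch below mirrors that order using toy stand-in modules; ConcatExtractor and TinyMlpExtractor are hypothetical names, not SB3 classes. With a real model you would pass model.policy.features_extractor, model.policy.mlp_extractor, model.policy.action_net and model.policy.value_net, assuming the default shared features extractor. Calling features_extractor directly also skips SB3's preprocess_obs (for example one-hot encoding of Discrete entries), which is only harmless for float Box observations.

```python
import torch as th
import torch.nn as nn


class ConcatExtractor(nn.Module):
    """Toy stand-in for SB3's CombinedExtractor on Box-only dict spaces:
    flattens each entry and concatenates them."""

    def __init__(self, keys):
        super().__init__()
        self.keys = list(keys)

    def forward(self, obs):
        return th.cat([th.flatten(obs[k], start_dim=1) for k in self.keys], dim=1)


class TinyMlpExtractor(nn.Module):
    """Toy stand-in for SB3's MlpExtractor: separate actor and critic latents."""

    def __init__(self, feature_dim, hidden_dim):
        super().__init__()
        self.policy_net = nn.Sequential(nn.Linear(feature_dim, hidden_dim), nn.Tanh())
        self.value_net = nn.Sequential(nn.Linear(feature_dim, hidden_dim), nn.Tanh())

    def forward(self, features):
        return self.policy_net(features), self.value_net(features)


class OnnxableActorCritic(nn.Module):
    """Returns raw action logits and state values, following the order of
    SB3's ActorCriticPolicy.forward (shared features extractor assumed)."""

    def __init__(self, features_extractor, mlp_extractor, action_net, value_net):
        super().__init__()
        self.features_extractor = features_extractor
        self.mlp_extractor = mlp_extractor
        self.action_net = action_net
        self.value_net = value_net

    def forward(self, obs):
        features = self.features_extractor(obs)
        # The step the wrapper above skips: split features into actor/critic latents
        latent_pi, latent_vf = self.mlp_extractor(features)
        return self.action_net(latent_pi), self.value_net(latent_vf)


# Toy setup: obs = {"a": Box(3,), "b": Box(2,)}, Discrete(2) actions
policy = OnnxableActorCritic(
    ConcatExtractor(["a", "b"]),
    TinyMlpExtractor(feature_dim=5, hidden_dim=64),
    nn.Linear(64, 2),  # action_net: one logit per discrete action
    nn.Linear(64, 1),  # value_net: one state value
)
obs = {"a": th.rand(1, 3), "b": th.rand(1, 2)}
with th.no_grad():
    logits, value = policy(obs)
print(logits.shape, value.shape)
```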

Finally I tried using the policy as is:

model = PPO.load("Models/ppo.zip")
obs = env.reset()
th.onnx.export(
    model.policy,
    obs,
    "ONNX/ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)

This seemingly got me the furthest, producing the new error:

.venv/Lib/site-packages/stable_baselines3/common/preprocessing.py
    110     assert isinstance(obs, Dict), f"Expected dict, got {type(obs)}"
    111     preprocessed_obs = {}
    112     for key, _obs in obs.items():
AssertionError: Expected dict, got <class 'torch.Tensor'>

araffin commented 3 months ago

I gave it a try, but this one seems a bit hard; you probably need to use PyTorch's experimental ONNX exporter (based on dynamo). What got me further was passing (obs_dict, {}) as the args; otherwise PyTorch tries to use the dict as keyword arguments.
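The (obs_dict, {}) trick follows a documented convention of the TorchScript-based torch.onnx.export: if the last element of args is a dict, it is interpreted as named arguments, and an empty dict must be appended to pass a dict as the last positional input. The pure-Python sketch below is a simplified model of that rule (split_export_args is a hypothetical helper, not PyTorch code). It reproduces the "missing 1 required positional argument: 'input'" error reported earlier in this thread, because no key of obs_dict is named input.

```python
import inspect


def split_export_args(forward, args):
    """Simplified model (not PyTorch's code) of how the TorchScript-based
    torch.onnx.export maps `args` onto the parameters of `forward`."""
    if not isinstance(args, tuple):
        args = (args,)  # a single input is treated as a 1-tuple
    positional = list(args)
    named = {}
    if positional and isinstance(positional[-1], dict):
        named = positional.pop()  # trailing dict = named arguments
    params = list(inspect.signature(forward).parameters.values())
    for param in params[len(positional):]:
        if param.name in named:
            positional.append(named[param.name])
        elif param.default is not inspect.Parameter.empty:
            positional.append(param.default)
        # dict keys that match no parameter name (e.g. "a") are ignored
    return tuple(positional)


def forward(input):  # stands in for OnnxablePolicy.forward(self, input)
    return sorted(input)


obs_dict = {"a": [0.1, 0.2, 0.3], "b": [1.0]}

try:
    # The dict alone is consumed as named arguments: nothing reaches `input`
    forward(*split_export_args(forward, obs_dict))
except TypeError as err:
    print("dict alone:", err)

# Appending an empty dict keeps obs_dict as the positional input
print("(dict, {}):", forward(*split_export_args(forward, (obs_dict, {}))))
```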

My current attempt (the export seems to work, but loading doesn't :/):


import torch as th
from typing import Tuple
import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.policies import BasePolicy

import onnx
import onnxruntime as ort
import numpy as np

class OnnxableSB3Policy(th.nn.Module):
    def __init__(self, policy: BasePolicy):
        super().__init__()
        self.policy = policy

    def forward(self, observation):
        # Debugging: print the input and return it unchanged;
        # the policy call below is unreachable while this early return is present
        print(observation)
        return observation["a"]
        # NOTE: Preprocessing is included, but postprocessing
        # (clipping/unscaling actions) is not.
        # If needed, you also need to transpose the images so that they are channel first.
        # Use deterministic=False if you want to export the stochastic policy.
        return self.policy._predict(observation, deterministic=True)

class Custom(gym.Env):
    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Dict(
            {
                "a": gym.spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32),
                # "b": gym.spaces.Discrete(5),
            }
        )
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, seed=None):
        return self.observation_space.sample(), {}

    def step(self, action):
        return self.observation_space.sample(), 0.0, False, False, {}

env = Custom()
obs, _ = env.reset()
# Example: model = PPO("MlpPolicy", "Pendulum-v1")
PPO("MultiInputPolicy", env).save("PathToTrainedModel")
model = PPO.load("PathToTrainedModel.zip", device="cpu")

onnx_policy = OnnxableSB3Policy(model.policy)

observation_size = model.observation_space.shape
# Add batch dimension
dummy_input = {
    # "a": np.array(obs["a"])[np.newaxis, ...],
    "a": np.array(obs["a"]),
    # "b": np.array(obs["b"])[np.newaxis, ...],
}
dummy_input_tensor = {
    "a": th.as_tensor(dummy_input["a"]),
    # "b": th.as_tensor(dummy_input["b"]),
}

print(model.predict(dummy_input, deterministic=True))

th.onnx.export(
    onnx_policy,
    args=(dummy_input_tensor, {}),
    f="my_ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)

##### Load and test with onnx

onnx_path = "my_ppo_model.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

observation = dummy_input.copy()
ort_sess = ort.InferenceSession(onnx_path)

# print(ort_sess.get_inputs()[0].name)
# print(ort_sess.get_inputs())

output = ort_sess.run(None, {"input": observation})

print(output)

# Check that the predictions are the same
# with th.no_grad():
#     print(model.policy(th.as_tensor(observation), deterministic=True))
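One likely reason the loading step fails: onnxruntime's InferenceSession.run expects its input feed to map each graph input name to an array, so a nested dict such as {"input": observation} cannot be passed as-is. A minimal sketch, assuming the model was exported with one float32 graph input per observation key, named after that key (for example via input_names); make_ort_feed is a hypothetical helper that builds such a feed and adds the batch dimension that the dummy input above lacks.

```python
import numpy as np


def make_ort_feed(obs, input_names):
    """Build an onnxruntime input feed from a dict observation.

    Assumes (hypothetically) one float32 graph input per observation key,
    named after the key, with a leading batch dimension.
    """
    feed = {}
    for name in input_names:
        array = np.asarray(obs[name], dtype=np.float32)
        feed[name] = array[np.newaxis, ...]  # e.g. shape (3,) -> (1, 3)
    return feed


obs = {"a": np.array([0.1, -0.2, 0.3]), "b": np.array([0.5, 0.5])}
feed = make_ort_feed(obs, ["a", "b"])
print({name: array.shape for name, array in feed.items()})

# With a real session, the names would come from the exported graph:
# input_names = [node.name for node in ort_sess.get_inputs()]
# output = ort_sess.run(None, make_ort_feed(obs, input_names))
```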
araffin commented 3 months ago

" Due to design differences, input/output format between PyTorch model and exported ONNX model are often not the same. E.g., None is allowed for PyTorch model, but are not supported by ONNX. Nested constructs of tensors are allowed for PyTorch model, but only flattened tensors are supported by ONNX, etc."

from https://pytorch.org/docs/stable/onnx_dynamo.html#torch.onnx.ONNXProgram.adapt_torch_inputs_to_onnx
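Following that note, one common workaround is to give the exported module flat inputs: accept one tensor per observation key and rebuild the dict inside forward, so ONNX only ever sees plain tensors. A minimal sketch; FlatInputs and DictNet are hypothetical names, and DictNet merely stands in for any module that consumes a dict observation (such as a policy wrapper).

```python
import torch as th
import torch.nn as nn


class FlatInputs(nn.Module):
    """Wraps a module that expects a dict observation so it can be called
    (and exported) with one plain tensor per key instead."""

    def __init__(self, net, keys):
        super().__init__()
        self.net = net
        self.keys = list(keys)  # fixed order = order of the graph inputs

    def forward(self, *tensors):
        # Rebuild the nested structure inside the graph
        return self.net(dict(zip(self.keys, tensors)))


class DictNet(nn.Module):
    """Hypothetical stand-in for a module consuming {"a": (B, 3), "b": (B, 2)}."""

    def __init__(self):
        super().__init__()
        self.head = nn.Linear(3 + 2, 2)

    def forward(self, obs):
        return self.head(th.cat([obs["a"], obs["b"]], dim=1))


keys = ["a", "b"]
net = DictNet()
wrapped = FlatInputs(net, keys)
obs = {"a": th.rand(1, 3), "b": th.rand(1, 2)}
flat_args = tuple(obs[k] for k in keys)  # positional inputs, one per key

with th.no_grad():
    print(wrapped(*flat_args))
```

Assuming the TorchScript-based exporter used elsewhere in this thread, the flat tensors would then be passed as the positional args with matching names, e.g. th.onnx.export(wrapped, flat_args, "model.onnx", input_names=keys, opset_version=17). If the varargs forward causes trouble with a given exporter, spelling out the parameters (forward(self, a, b)) is a simple alternative.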

NickLucche commented 2 months ago

Hi all, I wouldn't export the sampling procedure to ONNX here (self.policy._predict(observation, deterministic=True)); I would rather have the network output the raw logits and implement the sampling as a postprocessing step. A consistent export procedure would be a nice feature to add to the framework :)
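Following this suggestion, the exported graph would stop at the action_net logits (for a Discrete action space), and action selection would happen outside ONNX. A minimal NumPy sketch of that postprocessing, assuming logits of shape (batch, n_actions): deterministic mode takes the argmax (the mode of the categorical distribution), while stochastic mode samples from the softmax probabilities. logits_to_action is a hypothetical helper, not part of SB3.

```python
import numpy as np


def logits_to_action(logits, deterministic=True, rng=None):
    """Select Discrete actions from raw logits of shape (batch, n_actions)."""
    logits = np.asarray(logits, dtype=np.float64)
    if deterministic:
        return logits.argmax(axis=-1)  # most likely action per row
    rng = np.random.default_rng() if rng is None else rng
    # Numerically stable softmax over the action dimension
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Inverse-CDF sampling: one uniform draw per row
    cdf = probs.cumsum(axis=-1)
    u = rng.random((logits.shape[0], 1))
    actions = (u > cdf).sum(axis=-1)
    # Guard against floating-point rounding at the top of the CDF
    return np.minimum(actions, logits.shape[-1] - 1)


logits = np.array([[0.1, 2.0], [3.0, -1.0]])
print(logits_to_action(logits))  # -> [1 0]
print(logits_to_action(logits, deterministic=False, rng=np.random.default_rng(0)))
```

Keeping sampling outside the graph also means the same exported model serves both deterministic and stochastic use, and the random number generator stays under the caller's control.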