nikhilbarhate99 / PPO-PyTorch

Minimal implementation of clipped objective Proximal Policy Optimization (PPO) in PyTorch
MIT License

Export as ONNX Model #25

Closed: CesMak closed this issue 4 years ago

CesMak commented 4 years ago

Hey,

Thanks for sharing this awesome code!

I would like to also export my trained policy as an ONNX model. However, I'm not sure how to use it afterwards, and so far it hasn't worked for me:

This is how I export it:

    torch_out = torch.onnx._export(ppo.policy, input_vector, path+".onnx",  export_params=True)

To get this to work, I also had to implement a forward method:

    import numpy as np
    import torch
    import torch.nn as nn
    from torch.distributions import Categorical

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    class ActorCritic(nn.Module):
        def __init__(self, state_dim, action_dim, n_latent_var):
            super(ActorCritic, self).__init__()
            #... same as your code

        def forward(self, state_input):
            # called by the ONNX export; returns the sampled action as a tensor
            return torch.tensor(self.act(state_input, None))

        def act(self, state, memory):
            if type(state) is np.ndarray:
                state = torch.from_numpy(state).float().to(device)
            action_probs = self.action_layer(state)
            # here: filter so that only legal actions remain
            #probs = probs * memory.leagalCards
            dist = Categorical(action_probs)

            action = dist.sample()

            if memory is not None:
                memory.states.append(state)
                memory.actions.append(action)
                memory.logprobs.append(dist.log_prob(action))

            return action.item()

I then tried to run the ONNX model like this, but it always returns the same action :(

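    import numpy as np
    import onnxruntime

    def getOnnxAction(path, x):
        '''Input:
        x:      list of 240 binary values (the state)
        path:   path to the *.onnx file (with correct model)'''
        # explicit CPU provider; newer onnxruntime builds may require a providers list
        ort_session = onnxruntime.InferenceSession(path, providers=["CPUExecutionProvider"])
        ort_inputs  = {ort_session.get_inputs()[0].name: np.asarray(x, dtype=np.float32)}
        ort_outs    = ort_session.run(None, ort_inputs)
        return np.asarray(ort_outs)[0]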

Any ideas what is going wrong here?
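A likely cause, stated as an assumption about the trace-based exporter behind `torch.onnx.export` (and the private `torch.onnx._export` used above): the exporter runs `forward` once on the example input and records only tensor operations. `act` ends with `action.item()`, which turns the sampled action into a plain Python int, and `torch.tensor(...)` then stores that int in the graph as a constant. The exported model would therefore return the action sampled during export for every input. The sketch below reproduces the effect with `torch.jit.trace`, which relies on the same tracing mechanism, on a hypothetical toy policy (`SampledActionPolicy`, 4 state features, 3 actions) instead of the real `ActorCritic`. Expect TracerWarnings about the Python conversions.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


class SampledActionPolicy(nn.Module):
    """Hypothetical toy stand-in for ActorCritic: 4 state features, 3 actions."""

    def __init__(self, state_dim=4, action_dim=3):
        super().__init__()
        self.action_layer = nn.Sequential(
            nn.Linear(state_dim, action_dim), nn.Softmax(dim=-1)
        )

    def forward(self, state):
        dist = Categorical(self.action_layer(state))
        action = dist.sample()
        # .item() leaves the tensor world, so the tracer cannot follow it;
        # torch.tensor(...) then records the int as a constant in the graph.
        return torch.tensor(action.item())


torch.manual_seed(0)
policy = SampledActionPolicy()
# check_trace=False: re-running a randomly sampling model could trip
# the tracer's consistency check.
traced = torch.jit.trace(policy, torch.rand(4), check_trace=False)

# Feed 20 different random states to the traced model.
actions = {int(traced(torch.rand(4))) for _ in range(20)}
print(actions)  # a single action: the one sampled while tracing
```

Returning `action_probs` from `forward` instead, or exporting `action_layer` directly as in the solution below, keeps the exported output a function of the input.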

CesMak commented 4 years ago

I found the solution.

    torch.onnx.export(ppo_test.policy_old.action_layer, torch.rand(240), path + ".onnx")
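Because `act` passes the output of `action_layer` straight into `Categorical`, a model exported this way returns action probabilities rather than an action, so the choice has to be made outside ONNX. Below is a minimal numpy sketch, assuming a single state so that the array returned by `getOnnxAction` is 1-D. `select_action` and `legal_mask` are hypothetical names; the mask mirrors the commented-out `leagalCards` filter in `act`.

```python
import numpy as np


def select_action(probs, legal_mask=None, greedy=True, rng=None):
    """Pick an action index from the probabilities output by action_layer.

    Hypothetical helper, not part of PPO-PyTorch.
    probs:      1-D array-like, one probability per action
    legal_mask: optional 1-D array-like of 0/1 flags (1 = action allowed)
    greedy:     True -> most likely action, False -> sample from probs
    rng:        optional numpy Generator, used only when sampling
    """
    p = np.asarray(probs, dtype=np.float64).ravel()
    if legal_mask is not None:
        # Zero out illegal actions (like the commented-out leagalCards filter).
        p = p * np.asarray(legal_mask, dtype=np.float64).ravel()
    total = p.sum()
    if total <= 0:
        raise ValueError("no allowed action has non-zero probability")
    # Renormalise: masking removes probability mass, and float32 softmax
    # output may not sum to exactly 1, which rng.choice would reject.
    p = p / total
    if greedy:
        return int(np.argmax(p))
    rng = np.random.default_rng() if rng is None else rng
    return int(rng.choice(len(p), p=p))


# Example with a made-up probability vector for 3 actions:
print(select_action([0.1, 0.6, 0.3]))                        # 1
print(select_action([0.1, 0.6, 0.3], legal_mask=[1, 0, 1]))  # 2
```

Greedy `argmax` gives deterministic play, which is often what you want for evaluation; `greedy=False` samples the way `dist.sample()` did during training.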

You can close this issue.