thu-ml / tianshou

An elegant PyTorch deep reinforcement learning library.
https://tianshou.org
MIT License

[Feature Request] Export policy in ONNX/TF format #314

Closed — zhujl1991 closed this 3 years ago

zhujl1991 commented 3 years ago

In order to serve the policy in our serving infrastructure, which only supports ONNX/TF, I'm trying to export the policy in ONNX with torch.onnx.export(). But I got this error: RuntimeError: Only tuples, lists and Variables supported as JIT inputs/outputs. Dictionaries and strings are also accepted but their usage is not recommended. But got unsupported type Batch.

Looks like the reason is that the args argument of torch.onnx.export() only supports Tensor-like inputs (https://discuss.pytorch.org/t/torch-onnx-export-fails/72418), but the input/output of the policies is of type Batch. This may need some non-trivial changes to Tianshou.
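A minimal sketch of the failure mode, using plain PyTorch (the `FakeBatch` class here is an illustrative stand-in for tianshou's `Batch`, not the real implementation):

```python
import torch
import torch.nn as nn

class FakeBatch:
    """Stand-in for tianshou.data.Batch (illustrative, not the real class)."""
    def __init__(self, obs: torch.Tensor):
        self.obs = obs

class PolicyLike(nn.Module):
    def forward(self, batch: FakeBatch) -> torch.Tensor:
        # Works fine in eager mode, but the JIT tracer behind
        # torch.onnx.export cannot handle a custom object as input.
        return batch.obs * 2

try:
    torch.onnx.export(PolicyLike(), FakeBatch(torch.ones(1, 4)), "policy.onnx")
except Exception:
    print("export rejected non-tensor input")
```

The exact error message varies across PyTorch versions, but the tracer only flattens tensors, tuples, lists, and dicts of tensors, so any custom container like Batch is rejected.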

I think support for exporting policies in ONNX/TF format is extremely important, given that most serving infrastructures in the industry are designed for ONNX/TF models (in fact, to convert a PyTorch model to a TF model, we usually need to convert it to ONNX first). This feature would allow Tianshou to be much more widely used in industry.

Do we have any plans to add this feature?

Trinkle23897 commented 3 years ago

Great request! I'll definitely try to support this feature.

zhujl1991 commented 3 years ago

> Great request! I'll definitely try to support this feature.

Do you have any idea about how to do this? I'm also looking into it. Looks like we have to replace the Batch-typed argument with a Tensor, or find a way to wrap the Batch in a Tensor somehow, in the forward() function of each policy.

Trinkle23897 commented 3 years ago

I'm thinking: is it possible to directly replace Batch with dict for ONNX deployment?

zhujl1991 commented 3 years ago

> I'm thinking: is it possible to directly replace Batch with dict for ONNX deployment?

What do you mean by "onnx deployment"? You mean pass in a dict instead of Batch into forward() so that we can call torch.onnx.export()?

Trinkle23897 commented 3 years ago

No, I mean: can we keep the original parameters, just change Batch to dict in the code, and test whether it works?

Trinkle23897 commented 3 years ago

On second thought, I think the easiest way is to directly call torch.onnx.export(self.model) (if you use DQN) or torch.onnx.export(self.actor) (if you use other algos). There's no need to use the training code in policy class, so basically you only need to write (copy) the forward function's logic.

AlessandroZavoli commented 2 years ago

Could you provide a minimal working example? I'm experiencing the same issue and I can't find any solution

Trinkle23897 commented 2 years ago

You need some manual work, something like:

policy = torch.load(...)
dummy_obs = torch.randn(1, obs_dim)  # example observation with the right shape
torch.onnx.export(policy.actor, dummy_obs, "actor.onnx")

Once you've done that, you only need to port the network-related code (instead of the whole policy) to work with the exported ONNX model.

arita37 commented 1 year ago

PyTorch 2.0 may simplify the compilation process?

arita37 commented 1 year ago

Yes, separating the training part from the forward pass is a very good design idea.