PaddlePaddle / PARL

A high-performance distributed training framework for Reinforcement Learning
https://parl.readthedocs.io/
Apache License 2.0
3.22k stars 817 forks source link

在examples/PPO中,如何保存模型,用于推理? #1086

Closed wangyexiang closed 1 year ago

wangyexiang commented 1 year ago

在跑例子中的PPO代码时,跑完后如何保存模型,用于推理? 参照例子中的DDPG代码,尝试了如下方式,报错AssertionError: model needs to implement forward method. 第一种:

save_inference_path = './inference_model'
input_shapes = [[None, obs_space.shape[0]]]
input_dtypes = ['float32']
agent.save_inference_model(save_inference_path, input_shapes, input_dtypes)

第二种:

save_inference_path = './inference_model'
input_shapes = [[None, obs_space.shape[0]]]
input_dtypes = ['float32']
agent.save_inference_model(save_inference_path, input_shapes, input_dtypes, model)
rical730 commented 1 year ago

Hi,如报错提示所示,被保存的模型需要实现 forward 方法,因为PPO example是针对训练设计的,model仅提供了valuepolicy方法,不是针对评估推理设计的,没有forward函数,需要用户自定义选择推理流程。

因此,你需要在PPO的model中新增一个forward方法,比如你想保存policy的推理流程,可以增加如下代码

    # 新增 forward 方法,用于指定想要保存的推理过程
    def forward(self, obs):
        return self.policy(obs)

还有一个issue的问题和你的问题类似,可供参考:https://github.com/PaddlePaddle/PARL/issues/1028

wangyexiang commented 1 year ago

Hi,如报错提示所示,被保存的模型需要实现 forward 方法,因为PPO example是针对训练设计的,model仅提供了valuepolicy方法,不是针对评估推理设计的,没有forward函数,需要用户自定义选择推理流程。

因此,你需要在PPO的model中新增一个forward方法,比如你想保存policy的推理流程,可以增加如下代码

    # 新增 forward 方法,用于指定想要保存的推理过程
    def forward(self, obs):
        return self.policy(obs)

还有一个issue的问题和你的问题类似,可供参考:#1028

感谢您的解答,祝您工作顺利