PaddlePaddle / PARL

A high-performance distributed training framework for Reinforcement Learning
https://parl.readthedocs.io/
Apache License 2.0

[PARL] Parallelizing the PPO example fails with TypeError: cannot pickle 'ParamBase' object #832

Open · luoxiao21024 opened this issue 2 years ago

luoxiao21024 commented 2 years ago

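What should I watch out for when it comes to serialization? How are complex objects usually handled? Simply adding the decorator directly to the agent gave me this: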

TomorrowIsAnOtherDay commented 2 years ago

I'd suggest building your data out of Python's native data types. Could you post the code for your "complex object"? We'll take a look.
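To illustrate why native data types matter, here is a minimal, pure-Python sketch (it does not use PARL). `ParamLike` is a hypothetical stand-in for a framework object such as `ParamBase`; it holds a `threading.Lock`, which the standard `pickle` module refuses to serialize, producing the same kind of `TypeError` as in the issue. The `config` values are purely illustrative.

```python
import pickle
import threading


class ParamLike:
    """Hypothetical stand-in for a framework parameter object such as ParamBase.

    The lock makes instances unpicklable, reproducing the kind of TypeError
    seen in the issue.
    """

    def __init__(self, values):
        self.values = values
        self._lock = threading.Lock()


def is_picklable(obj):
    """Return True if `obj` survives a pickle round trip."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except TypeError:
        return False


# Plain dicts/lists of numbers serialize without trouble.
config = {"obs_dim": 17, "act_dim": 6, "lr": 3e-4}
print(is_picklable(config))  # True

# An object carrying unpicklable internal state does not.
param = ParamLike([0.1, 0.2])
print(is_picklable(param))  # False

# Pulling the plain data out of it makes it serializable again.
print(is_picklable({"values": param.values}))  # True
```

The same reasoning applies to anything handed across a process boundary: plain containers of numbers and strings are safe, while objects that wrap framework or OS resources usually are not.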

luoxiao21024 commented 2 years ago

Many thanks! I just took one of the examples and tried to parallelize it.

mujoco_agent

@parl.remote_class
class MujocoAgent(parl.Agent):

I simply added the decorator, and in https://github.com/PaddlePaddle/PARL/blob/develop/examples/PPO/train.py:

parl.connect("localhost:6006")
agent = MujocoAgent(algorithm)

Traceback (most recent call last):
  File "D:\PycharmProjects\PPO\train.py", line 168, in <module>
    main()
  File "D:\PycharmProjects\PPO\train.py", line 84, in main
    agent = MujocoAgent(algorithm)
  File "D:\ProgramData\Anaconda3\lib\site-packages\parl\remote\proxy_wrapper.py", line 64, in __init__
    self._xparl_remote_wrapper_obj = remote_wrapper(*args, **kwargs)
  File "D:\ProgramData\Anaconda3\lib\site-packages\parl\remote\remote_wrapper.py", line 98, in __init__
    cloudpickle.dumps([args, kwargs]),
  File "D:\ProgramData\Anaconda3\lib\site-packages\cloudpickle\cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "D:\ProgramData\Anaconda3\lib\site-packages\cloudpickle\cloudpickle_fast.py", line 563, in dump
    return Pickler.dump(self, obj)
TypeError: cannot pickle 'ParamBase' object

What do I need to change?

TomorrowIsAnOtherDay commented 2 years ago

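Here's what's happening: the agent's initialization depends on the algorithm. Once you add the decorator to the agent, it becomes a virtual agent (it no longer runs locally), but your algorithm is still initialized locally. The algorithm is currently a fairly complex class, and serializing it directly is not supported yet. You can follow our A2C example instead: write an actor class and initialize the model, algorithm, and agent inside that class.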
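The sketch below simulates the problem and the suggested layout without PARL. `simulated_remote_class` is a hypothetical decorator that mimics only one step visible in the traceback above, namely serializing the constructor arguments (`cloudpickle.dumps([args, kwargs])`) before the object is built; it is not PARL's implementation. `Algorithm`, `Agent`, and `Actor` are simplified stand-ins (the model is omitted), and the lock plays the role of unpicklable framework state.

```python
import pickle
import threading


def simulated_remote_class(cls):
    """Hypothetical stand-in for a remote-class decorator (NOT PARL's code).

    It mimics only one step seen in the traceback: the constructor arguments
    are serialized before the object is created on the "remote" side.
    """

    def factory(*args, **kwargs):
        payload = pickle.dumps([args, kwargs])  # TypeError if an argument can't be pickled
        remote_args, remote_kwargs = pickle.loads(payload)  # the "remote" copy
        return cls(*remote_args, **remote_kwargs)

    return factory


class Algorithm:
    """Simplified algorithm; the lock stands in for unpicklable state like ParamBase."""

    def __init__(self, obs_dim, act_dim):
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self._params = threading.Lock()


class Agent:
    def __init__(self, algorithm):
        self.alg = algorithm


# Layout from the issue: decorate the agent, but build the algorithm locally.
RemoteAgent = simulated_remote_class(Agent)


# Suggested layout: the decorated class receives only plain values and
# builds the algorithm and agent itself, on the remote side.
@simulated_remote_class
class Actor:
    def __init__(self, obs_dim, act_dim):
        self.agent = Agent(Algorithm(obs_dim, act_dim))

    def action_dim(self):
        return self.agent.alg.act_dim


try:
    RemoteAgent(Algorithm(17, 6))
except TypeError as err:
    print("decorated agent fails:", err)

actor = Actor(17, 6)  # only two ints cross the simulated boundary
print("actor-style layout works, action_dim =", actor.action_dim())
```

The design point is that only plain values cross the boundary; everything heavyweight is constructed where it will run.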

luoxiao21024 commented 2 years ago

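Thanks, that makes sense. Is this because the algorithm is implemented with torch? Also, if I only want to parallelize the model part, without touching the PPO algorithm, is that possible? After a simple modification, for example: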

https://github.com/PaddlePaddle/PARL/blob/develop/examples/PPO/train.py#L74

parl.connect("localhost:6006")
actor = Actor(env.observation_space.shape[0], env.action_space.n)
critic = Critic(env.observation_space.shape[0])
model = MujocoModel(env.observation_space.shape[0], env.action_space.n, actor, critic)

https://github.com/PaddlePaddle/PARL/blob/develop/examples/PPO/mujoco_agent.py

@parl.remote_class
class MujocoModel(parl.Model):

@parl.remote_class
class Actor(parl.Model):

@parl.remote_class
class Critic(parl.Model):

This still seems to raise the same error: cannot pickle 'ParamBase' object

TomorrowIsAnOtherDay commented 2 years ago

Could you try decorating only MujocoModel?

luoxiao21024 commented 2 years ago

With only the MujocoModel class decorated, I get this error instead: NameError: name 'Actor' is not defined

Moving Actor and Critic into a separate file fixed that, but there are still some problems when using it. Could you take a look?

Traceback (most recent call last):
  File "D:\ProgramData\Anaconda3\lib\site-packages\parl\remote\communication.py", line 71, in dumps_argument
    ret = serialize([args, kwargs])
  File "D:\ProgramData\Anaconda3\lib\site-packages\parl\remote\communication.py", line 51, in <lambda>
    serialize = lambda data: pyarrow.serialize(data, context=context).to_buffer()
  File "pyarrow\serialization.pxi", line 388, in pyarrow.lib.serialize
  File "pyarrow\serialization.pxi", line 399, in pyarrow.lib._serialize
  File "pyarrow\serialization.pxi", line 162, in pyarrow.lib.SerializationContext._serialize_callback
  File "D:\ProgramData\Anaconda3\lib\site-packages\parl\remote\communication.py", line 31, in _serialize_serializable
    return {"type": type(obj), "data": obj.__dict__}
AttributeError: 'Tensor' object has no attribute '__dict__'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "D:\PycharmProjects\PPO\train.py", line 174, in <module>
    main()
  File "D:\PycharmProjects\PPO\train.py", line 104, in main
    value, action, action_log_prob = agent.sample(rollouts.obs[step])
  File "D:\PycharmProjects\PPO\mujoco_agent.py", line 47, in sample
    value, action, action_log_probs = self.alg.sample(obs)
  File "D:\ProgramData\Anaconda3\lib\site-packages\parl\algorithms\paddle\ppo.py", line 108, in sample
    value = self.model.value(obs)
  File "D:\ProgramData\Anaconda3\lib\site-packages\parl\remote\remote_wrapper.py", line 190, in wrapper
    data = dumps_argument(*args, **kwargs)
  File "D:\ProgramData\Anaconda3\lib\site-packages\parl\remote\communication.py", line 73, in dumps_argument
    raise SerializeError(e)
parl.utils.exceptions.SerializeError: [PARL Utils Error]: Serialize error, you may have provided an object that cannot be serialized by pyarrow. Detailed error: 'Tensor' object has no attribute '__dict__'
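The second traceback shows that every argument passed to a method of a remote object goes through `dumps_argument`, and a Paddle `Tensor` is rejected there. In line with the earlier advice to use native data types, one option where you control the call site is to convert tensors to plain Python data before the call. The sketch below is hedged: `to_native` is a hypothetical helper, treating anything with a `.numpy()` method as a tensor is an assumption, and `FakeTensor` is a test double so the code runs without Paddle. The remote side would have to rebuild tensors itself (for example with `paddle.to_tensor`). Here, though, the failing call `self.model.value(obs)` is made inside PARL's `ppo.py`, so this alone would not fix the issue without changing the algorithm code; wrapping the whole agent in an actor class avoids the problem.

```python
def to_native(obj):
    """Recursively convert tensor-like objects into plain Python data.

    Assumption: anything with a `.numpy()` method (Paddle tensors, and CPU
    PyTorch tensors, have one) is treated as a tensor; containers are walked.
    """
    if hasattr(obj, "numpy"):
        array = obj.numpy()
        # numpy arrays expose .tolist(); otherwise keep the value as returned
        return array.tolist() if hasattr(array, "tolist") else array
    if isinstance(obj, dict):
        return {key: to_native(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        converted = [to_native(item) for item in obj]
        return converted if isinstance(obj, list) else tuple(converted)
    return obj


class FakeTensor:
    """Test double with a .numpy() method, so the sketch runs without Paddle."""

    def __init__(self, data):
        self._data = data

    def numpy(self):
        return self._data


batch = {"obs": FakeTensor([[0.0, 1.0], [2.0, 3.0]]), "step": 5}
print(to_native(batch))  # {'obs': [[0.0, 1.0], [2.0, 3.0]], 'step': 5}
```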

TomorrowIsAnOtherDay commented 2 years ago

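I think I roughly see where the problem is. It's related to how PARL currently distributes code: it ships all of the code files, but the file that contains the decorator gets special handling, and only the decorated classes from that file are distributed (I believe the reason was faster initialization, but I don't remember exactly). We'll fix this later; thanks for pointing it out. For now, I'd recommend distributing the agent the way actor.py does.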
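Under the behaviour described in that reply, the earlier `NameError: name 'Actor' is not defined` follows naturally. The sketch below is a hypothetical, pure-Python simulation: `keep_only_decorated` is not PARL's code, and matching decorators by the name `remote_class` is an assumption. It parses a file that defines both an undecorated `Actor` and a decorated `MujocoModel`, keeps only the decorated class, and runs the result in a fresh namespace.

```python
import ast

# One file holding both an undecorated helper and a decorated class, like the
# layout in the issue (names taken from the thread, bodies simplified).
SOURCE = """
import parl

class Actor:
    pass

@parl.remote_class
class MujocoModel:
    def __init__(self):
        self.actor = Actor()
"""


def keep_only_decorated(source, decorator="remote_class"):
    """Hypothetical simulation (NOT PARL's code): keep only decorated classes.

    Mimics the described behaviour where only the decorated classes from the
    decorator's own file are shipped to the remote side.
    """
    kept = []
    for node in ast.parse(source).body:
        if not isinstance(node, ast.ClassDef):
            continue
        names = [
            d.attr if isinstance(d, ast.Attribute) else getattr(d, "id", None)
            for d in node.decorator_list
        ]
        if decorator in names:
            node.decorator_list = []  # drop the decorator so this runs without parl
            kept.append(node)
    module = ast.Module(body=kept, type_ignores=[])
    return ast.fix_missing_locations(module)


namespace = {}
exec(compile(keep_only_decorated(SOURCE), "<remote>", "exec"), namespace)
try:
    namespace["MujocoModel"]()
except NameError as err:
    print("remote side:", err)
```

This is consistent with luoxiao21024's workaround: once `Actor` and `Critic` live in a file without the decorator, that file is shipped whole and the names resolve.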