openai / baselines

OpenAI Baselines: high-quality implementations of reinforcement learning algorithms
MIT License

cannot train LSTM policy by PPO2 when mujoco env is selected #579

Open takerfume opened 6 years ago

takerfume commented 6 years ago

Hi, I think I have found a bug: training an LSTM policy with PPO2 fails when a MuJoCo environment is selected.

I ran this command:

python -m baselines.run --alg=ppo2 --env=Reacher-v2 --num_timesteps=1e6 --network=lstm --nminibatches=2 --num_env=4

and got this error:

Traceback (most recent call last):
  File "/home/isi/yoshida/anaconda3/envs/baselines/lib/python3.5/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/isi/yoshida/anaconda3/envs/baselines/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/isi/yoshida/baselines/baselines/run.py", line 235, in <module>
    main()
  File "/home/isi/yoshida/baselines/baselines/run.py", line 214, in main
    model, = train(args, extra_args)
  File "/home/isi/yoshida/baselines/baselines/run.py", line 69, in train
    **alg_kwargs
  File "/home/isi/yoshida/baselines/baselines/ppo2/ppo2.py", line 245, in learn
    obs, returns, masks, actions, values, neglogpacs, states, epinfos = runner.run() #pylint: disable=E0632
  File "/home/isi/yoshida/baselines/baselines/ppo2/ppo2.py", line 104, in run
    actions, values, self.states, neglogpacs = self.model.step(self.obs, S=self.states, M=self.dones)
  File "/home/isi/yoshida/baselines/baselines/common/policies.py", line 89, in step
    a, v, state, neglogp = self._evaluate([self.action, self.vf, self.state, self.neglogp], observation, **extra_feed)
  File "/home/isi/yoshida/baselines/baselines/common/policies.py", line 71, in _evaluate
    return sess.run(variables, feed_dict)
  File "/home/isi/yoshida/anaconda3/envs/baselines/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 877, in run
    run_metadata_ptr)
  File "/home/isi/yoshida/anaconda3/envs/baselines/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1100, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/isi/yoshida/anaconda3/envs/baselines/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1272, in _do_run
    run_metadata)
  File "/home/isi/yoshida/anaconda3/envs/baselines/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1291, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'ppo2_model/vf/Placeholder_1' with dtype float and shape [1,256]
	 [[Node: ppo2_model/vf/Placeholder_1 = Placeholder[dtype=DT_FLOAT, shape=[1,256], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'ppo2_model/vf/Placeholder_1', defined at:
  File "/home/isi/yoshida/anaconda3/envs/baselines/lib/python3.5/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/isi/yoshida/anaconda3/envs/baselines/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/isi/yoshida/baselines/baselines/run.py", line 235, in <module>
    main()
  File "/home/isi/yoshida/baselines/baselines/run.py", line 214, in main
    model, = train(args, extra_args)
  File "/home/isi/yoshida/baselines/baselines/run.py", line 69, in train
    **alg_kwargs
  File "/home/isi/yoshida/baselines/baselines/ppo2/ppo2.py", line 230, in learn
    model = make_model()
  File "/home/isi/yoshida/baselines/baselines/ppo2/ppo2.py", line 229, in <lambda>
    max_grad_norm=max_grad_norm)
  File "/home/isi/yoshida/baselines/baselines/ppo2/ppo2.py", line 25, in __init__
    act_model = policy(nbatch_act, 1, sess)
  File "/home/isi/yoshida/baselines/baselines/common/policies.py", line 159, in policy_fn
    vf_latent, _ = _v_net(encoded_x)
  File "/home/isi/yoshida/baselines/baselines/common/models.py", line 105, in network_fn
    S = tf.placeholder(tf.float32, [nenv, 2*nlstm]) #states
  File "/home/isi/yoshida/anaconda3/envs/baselines/lib/python3.5/site-packages/tensorflow/python/ops/array_ops.py", line 1735, in placeholder
    return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
  File "/home/isi/yoshida/anaconda3/envs/baselines/lib/python3.5/site-packages/tensorflow/python/ops/gen_array_ops.py", line 4925, in placeholder
    "Placeholder", dtype=dtype, shape=shape, name=name)
  File "/home/isi/yoshida/anaconda3/envs/baselines/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "/home/isi/yoshida/anaconda3/envs/baselines/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py", line 454, in new_func
    return func(*args, **kwargs)
  File "/home/isi/yoshida/anaconda3/envs/baselines/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3155, in create_op
    op_def=op_def)
  File "/home/isi/yoshida/anaconda3/envs/baselines/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1717, in __init__
    self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'ppo2_model/vf/Placeholder_1' with dtype float and shape [1,256]
	 [[Node: ppo2_model/vf/Placeholder_1 = Placeholder[dtype=DT_FLOAT, shape=[1,256], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

How can I train an LSTM policy with PPO2 in MuJoCo?

For reference, I can successfully train an LSTM policy with PPO2 on PongNoFrameskip-v4.

pzhokhov commented 6 years ago

This is not related to MuJoCo per se, but rather to the fact that MuJoCo environments use value_network='copy' by default. When a copy of the network is created, a new set of placeholders is created for the LSTM state and mask, and those placeholders are never fed. As a workaround, I'd suggest using the --value_network=shared flag (this way, the policy and value networks share one LSTM cell with the same placeholders). I am looking into solving this issue in a more principled way.
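To make the mechanism concrete, here is a framework-free Python sketch of what the 'copy' setting does to a recurrent policy. It does not use TensorFlow or baselines: Graph, lstm_network, build_policy and step are hypothetical stand-ins, and the toy graph requires every placeholder to be fed, which simplifies sess.run (TensorFlow only requires the placeholders that the fetched tensors depend on; here the fetched value output depends on the copied network's placeholders).

```python
# Hypothetical, framework-free sketch of the 'copy' vs 'shared' placeholder issue.
# Graph, lstm_network, build_policy and step are illustrative names, not baselines APIs.


class Graph:
    """Records placeholders, loosely imitating a TF1 graph."""

    def __init__(self):
        self.placeholders = []

    def placeholder(self, scope, name):
        # Each call registers a brand-new placeholder, like tf.placeholder.
        ph = f"{scope}/{name}"
        self.placeholders.append(ph)
        return ph

    def run(self, feed):
        # Simplification of sess.run: every placeholder must receive a value.
        missing = [ph for ph in self.placeholders if ph not in feed]
        if missing:
            raise KeyError(f"You must feed a value for placeholder(s) {missing}")
        return "ok"


def lstm_network(graph, scope):
    # Each call builds its own mask (M) and LSTM state (S) placeholders.
    return {"M": graph.placeholder(scope, "M"), "S": graph.placeholder(scope, "S")}


def build_policy(graph, value_network):
    pi = lstm_network(graph, "pi")
    if value_network == "shared":
        vf = pi  # value head reuses the policy's LSTM and its placeholders
    elif value_network == "copy":
        vf = lstm_network(graph, "vf")  # second call: a second set of placeholders
    else:
        raise ValueError(f"unknown value_network: {value_network!r}")
    return pi, vf


def step(graph, pi, state, mask):
    # The runner tracks a single recurrent state, so it feeds only the policy's placeholders.
    return graph.run({pi["S"]: state, pi["M"]: mask})
```

With build_policy(g, "shared"), step returns "ok"; with "copy", it raises a KeyError listing vf/M and vf/S. That mirrors the unfed ppo2_model/vf/Placeholder_1 in the traceback, which is the S placeholder created in models.py.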

takerfume commented 6 years ago

Thank you! I understand now: the error means that no value was fed to the placeholder of the value network, which is created by copying the policy network.

I ran this command and successfully trained the LSTM policy!

python -m baselines.run --alg=ppo2 --network=lstm --num_timesteps=1e6 --env=Reacher-v2 --num_env=4 --nminibatches=2 --value_network=shared

zacwellmer commented 6 years ago

@pzhokhov Have you found that the 'copy' value network (not sharing parameters) produces better results on MuJoCo? Do you have any guess as to why that would be the case?

pzhokhov commented 6 years ago

Generally, not sharing parameters makes training more stable (less sensitive to hyperparameters such as the value function coefficient in the training objective or the learning rate), because the two objectives do not compete with each other, whereas sharing parameters allows faster learning (when it works). For image-based observations (and convolutional layers) we use parameter sharing, because otherwise both the value function approximator and the policy would have to learn good visual features, which may take too many samples. MuJoCo has simulator-state observations that do not require much feature learning, and not sharing parameters gives us training that works decently on all environments without much hyperparameter tuning.
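The "competing objectives" point can be written out with the combined loss that PPO2 minimizes. This is a sketch assuming the usual form: policy-gradient loss, plus vf_coef times the value loss, minus ent_coef times the entropy bonus. The symbols phi (a shared trunk) and theta_pi, theta_v (two separate networks) are introduced here only for illustration.

```latex
% Combined PPO2 objective (minimized); c_vf = vf_coef, c_ent = ent_coef
\mathcal{L}(\theta) = \mathcal{L}^{\mathrm{PG}}(\theta) + c_{\mathrm{vf}}\,\mathcal{L}^{\mathrm{VF}}(\theta) - c_{\mathrm{ent}}\,\mathcal{H}(\theta)

% value_network='shared': both heads sit on one trunk with parameters \phi
\nabla_{\phi}\mathcal{L} = \nabla_{\phi}\mathcal{L}^{\mathrm{PG}} + c_{\mathrm{vf}}\,\nabla_{\phi}\mathcal{L}^{\mathrm{VF}} - c_{\mathrm{ent}}\,\nabla_{\phi}\mathcal{H}

% value_network='copy': \theta = (\theta_\pi, \theta_v) with no parameters in common
\nabla_{\theta_\pi}\mathcal{L} = \nabla_{\theta_\pi}\mathcal{L}^{\mathrm{PG}} - c_{\mathrm{ent}}\,\nabla_{\theta_\pi}\mathcal{H},
\qquad
\nabla_{\theta_v}\mathcal{L} = c_{\mathrm{vf}}\,\nabla_{\theta_v}\mathcal{L}^{\mathrm{VF}}
```

In the shared case, c_vf decides how strongly value fitting pulls on the same weights that the policy gradient is updating, so it has to be tuned per environment; in the copy case, the two gradients touch disjoint parameters and c_vf only rescales the value network's own gradient.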

eastskykang commented 5 years ago

@pzhokhov Is there any update on this? The 'copy' value network is still not supported for LSTM policies with PPO2.