rlworkgroup / garage

A toolkit for reproducible reinforcement learning research.
MIT License
1.84k stars 309 forks source link

In the example, an error is reported when changing to GRUPolicy #2288

Closed shallowdream48 closed 3 years ago

shallowdream48 commented 3 years ago

In the example "trpo_gym_tf_cartpole", an error is reported when changing to GRUPolicy. How can I solve this problem?

2021-06-17 15:35:32 | [trpo_gym_tf_cartpole] epoch #0 | computing descent direction

InvalidArgumentError Traceback (most recent call last) ~/anaconda3/envs/garage/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, args) 1374 try: -> 1375 return fn(args) 1376 except errors.OpError as e:

~/anaconda3/envs/garage/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata) 1357 # Ensure any changes to the graph are reflected in the runtime. -> 1358 self._extend_graph() 1359 return self._call_tf_sessionrun(options, feed_dict, fetch_list,

~/anaconda3/envs/garage/lib/python3.6/site-packages/tensorflow/python/client/session.py in _extend_graph(self) 1397 with self._graph._session_run_lock(): # pylint: disable=protected-access -> 1398 tf_session.ExtendSession(self._session) 1399

InvalidArgumentError: Node 'optimize/hx_plain/gradients_hx_plain/ConjugateGradientOptimizer/update_opt_mean_kl/gradients_constraint/policy_1/gru/rnn_2/while_grad/policy_1/gru/rnn_2/while_grad_grad/ConjugateGradientOptimizer/update_opt_mean_kl/gradients_constraint/policy_1/gru/rnn_2/while_grad/policy_1/gru/rnn_2/while_grad_grad': Connecting to invalid output 78 of source node ConjugateGradientOptimizer/update_opt_mean_kl/gradients_constraint/policy_1/gru/rnn_2/while_grad/policy_1/gru/rnn_2/while_grad which has 78 outputs. Try using tf.compat.v1.experimental.output_all_intermediates(True).

During handling of the above exception, another exception occurred:

InvalidArgumentError Traceback (most recent call last)

in 35 36 ---> 37 trpo_gym_tf_cartpole() ~/anaconda3/envs/garage/lib/python3.6/site-packages/garage/experiment/experiment.py in __call__(self, *args, **kwargs) 367 else: 368 ctxt = self._make_context(self._get_options(*args), **kwargs) --> 369 result = self.function(ctxt, **kwargs) 370 logger.remove_all() 371 logger.pop_prefix() in trpo_gym_tf_cartpole(ctxt, seed) 32 33 trainer.setup(algo, env) ---> 34 trainer.train(n_epochs=120, batch_size=4000) 35 36 ~/anaconda3/envs/garage/lib/python3.6/site-packages/garage/trainer.py in train(self, n_epochs, batch_size, plot, store_episodes, pause_for_plot) 400 dump_json(summary_file, self) 401 --> 402 average_return = self._algo.train(self) 403 self._shutdown_worker() 404 ~/anaconda3/envs/garage/lib/python3.6/site-packages/garage/tf/algos/npo.py in train(self, trainer) 177 trainer.step_episode = trainer.obtain_episodes(trainer.step_itr) 178 last_return = self._train_once(trainer.step_itr, --> 179 trainer.step_episode) 180 trainer.step_itr += 1 181 ~/anaconda3/envs/garage/lib/python3.6/site-packages/garage/tf/algos/npo.py in _train_once(self, itr, episodes) 210 211 logger.log('Optimizing policy...') --> 212 self._optimize_policy(episodes, baselines) 213 214 return np.mean(undiscounted_returns) ~/anaconda3/envs/garage/lib/python3.6/site-packages/garage/tf/algos/npo.py in _optimize_policy(self, episodes, baselines) 229 policy_kl_before = self._f_policy_kl(*policy_opt_input_values) 230 logger.log('Optimizing') --> 231 self._optimizer.optimize(policy_opt_input_values) 232 logger.log('Computing KL after') 233 policy_kl = self._f_policy_kl(*policy_opt_input_values) ~/anaconda3/envs/garage/lib/python3.6/site-packages/garage/tf/optimizers/conjugate_gradient_optimizer.py in optimize(self, inputs, extra_inputs, subsample_grouped_inputs, name) 469 logger.log('computing descent direction') 470 hx = self._hvp_approach.build_eval(subsample_inputs + extra_inputs) --> 471 descent_direction = _cg(hx, flat_g, cg_iters=self._cg_iters) 472 473 
initial_step_size = np.sqrt( ~/anaconda3/envs/garage/lib/python3.6/site-packages/garage/tf/optimizers/conjugate_gradient_optimizer.py in _cg(f_Ax, b, cg_iters, residual_tol) 543 544 for _ in range(cg_iters): --> 545 z = f_Ax(p) 546 v = rdotr / p.dot(z) 547 x += v * p ~/anaconda3/envs/garage/lib/python3.6/site-packages/garage/tf/optimizers/conjugate_gradient_optimizer.py in _eval(v) 75 xs = tuple(self._target.flat_to_params(v)) 76 ret = _sliced_fn(self._hvp_fun['f_hx_plain'], self._num_slices)( ---> 77 inputs, xs) + self._reg_coeff * v 78 return ret 79 ~/anaconda3/envs/garage/lib/python3.6/site-packages/garage/tf/optimizers/conjugate_gradient_optimizer.py in _sliced_f(sliced_inputs, non_sliced_inputs) 587 for start in range(0, n_paths, slice_size): 588 inputs_slice = [v[start:start + slice_size] for v in sliced_inputs] --> 589 slice_ret_vals = f(*(inputs_slice + non_sliced_inputs)) 590 591 if not isinstance(slice_ret_vals, (tuple, list)): ~/anaconda3/envs/garage/lib/python3.6/site-packages/garage/tf/_functions.py in _run(*input_vals) 22 # pylint: disable=missing-return-doc, missing-return-type-doc 23 sess = tf.compat.v1.get_default_session() ---> 24 return sess.run(outputs, feed_dict=dict(list(zip(inputs, input_vals)))) 25 26 return _run ~/anaconda3/envs/garage/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata) 966 try: 967 result = self._run(None, fetches, feed_dict, options_ptr, --> 968 run_metadata_ptr) 969 if run_metadata: 970 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr) ~/anaconda3/envs/garage/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata) 1189 if final_fetches or final_targets or (handle and feed_dict_tensor): 1190 results = self._do_run(handle, final_targets, final_fetches, -> 1191 feed_dict_tensor, options, run_metadata) 1192 else: 1193 results = [] 
~/anaconda3/envs/garage/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata) 1367 if handle is None: 1368 return self._do_call(_run_fn, feeds, fetches, targets, options, -> 1369 run_metadata) 1370 else: 1371 return self._do_call(_prun_fn, handle, feeds, fetches) ~/anaconda3/envs/garage/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args) 1392 '\nsession_config.graph_options.rewrite_options.' 1393 'disable_meta_optimizer = True') -> 1394 raise type(e)(node_def, op, message) 1395 1396 def _extend_graph(self): InvalidArgumentError: Node 'optimize/hx_plain/gradients_hx_plain/ConjugateGradientOptimizer/update_opt_mean_kl/gradients_constraint/policy_1/gru/rnn_2/while_grad/policy_1/gru/rnn_2/while_grad_grad/ConjugateGradientOptimizer/update_opt_mean_kl/gradients_constraint/policy_1/gru/rnn_2/while_grad/policy_1/gru/rnn_2/while_grad_grad': Connecting to invalid output 78 of source node ConjugateGradientOptimizer/update_opt_mean_kl/gradients_constraint/policy_1/gru/rnn_2/while_grad/policy_1/gru/rnn_2/while_grad which has 78 outputs. Try using tf.compat.v1.experimental.output_all_intermediates(True).
krzentner commented 3 years ago

Were you able to solve this issue?

shallowdream48 commented 3 years ago

Yes, I have solved this issue.