cbfinn / gps

Guided Policy Search
http://rll.berkeley.edu/gps/

Caffe to TF Hyper parameter option #52

Closed: ashwinreddy closed this 7 years ago

ashwinreddy commented 7 years ago

I had difficulty getting Caffe to work with the MuJoCo generalization example. I found a pull request from @wmontgomery4 that switches the hyperparams file to TensorFlow, so I added another parameter that makes it easy to switch between the two libraries. I've also added a line to the MJC examples documentation explaining how to use this option.
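A minimal sketch of what such a switch could look like in a hyperparams file. It assumes the policy optimizer is chosen through a `'type'` entry in `algorithm['policy_opt']` and that the TensorFlow variant also takes a `'network_model'` builder; the flag `USE_TF`, the helper `make_policy_opt`, and the placeholder classes below are hypothetical stand-ins so the snippet runs without the gps package, not the code from this pull request.

```python
# Hypothetical placeholders so this sketch runs without the gps package.
class PolicyOptCaffe(object):
    """Stand-in for the Caffe-backed policy optimizer (assumed name)."""


class PolicyOptTf(object):
    """Stand-in for the TensorFlow-backed policy optimizer (assumed name)."""


def example_tf_network(*args, **kwargs):
    """Stand-in for the network builder in the example_tf_network file."""
    raise NotImplementedError


USE_TF = True  # hypothetical switch: False falls back to Caffe


def make_policy_opt(use_tf):
    """Build the policy_opt entry of the hyperparams for the chosen backend."""
    if use_tf:
        # Assumed: the TF optimizer needs a function that builds the network.
        return {'type': PolicyOptTf, 'network_model': example_tf_network}
    return {'type': PolicyOptCaffe}


algorithm = {}
algorithm['policy_opt'] = make_policy_opt(USE_TF)
```

Keeping the backend choice in a single flag means the rest of the hyperparams file stays identical for both libraries.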

ashwinreddy commented 7 years ago

Hello,

I tried running the second MJC example again, but TensorFlow raised an error: ValueError: Shapes (26, 40) and (40, 7) are not compatible.

To rule out a setup error on my part, I checked out earlier commits to find where the error first appears. I believe it was introduced in commit a624a7227a4c29743a23f8fa6b6bb0779bf79ace, most likely by the network changes in the example_tf_network file.
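Searching the history this way is a bisection, which is what `git bisect` automates. A minimal Python sketch of the idea, assuming commits are ordered oldest to newest and that once a commit is bad every later one is bad too; `is_bad` is a hypothetical callback that would check out a commit and rerun the example.

```python
def first_bad_commit(commits, is_bad):
    """Return the earliest commit for which is_bad(commit) is True.

    Assumes commits run oldest to newest, the newest one is bad, and
    badness persists once introduced. Calls is_bad O(log n) times.
    """
    if not commits or not is_bad(commits[-1]):
        raise ValueError('the newest commit must be bad')
    lo, hi = 0, len(commits) - 1
    while lo < hi:
        mid = (lo + hi) // 2
        if is_bad(commits[mid]):
            hi = mid        # the first bad commit is at mid or earlier
        else:
            lo = mid + 1    # the first bad commit is after mid
    return commits[lo]


# Usage with made-up commit names; here the bug appears at 'c6'.
history = ['c%d' % i for i in range(10)]
print(first_bad_commit(history, lambda c: int(c[1:]) >= 6))  # c6
```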

However, it's not clear to me what to change in the network parameters in order to fix this error.
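One plausible reading of the error, offered as an inference rather than something confirmed in the thread: (26, 40) and (40, 7) are the first- and last-layer weight shapes of a fully connected network with 26 inputs, one hidden layer of 40 units, and 7 outputs. TensorFlow's "are not compatible" message comes from shape merging, which requires sizes to be equal, not merely chainable as in a matrix product, so a weight of one shape was probably used where the other was expected. The helpers `fc_weight_shapes` and `merge_shapes` below are hypothetical and only mimic that check in plain Python.

```python
def fc_weight_shapes(dim_input, hidden_sizes, dim_output):
    """Shapes of a fully connected net's weight matrices, input to output."""
    dims = [dim_input] + list(hidden_sizes) + [dim_output]
    return [(dims[i], dims[i + 1]) for i in range(len(dims) - 1)]


def merge_shapes(a, b):
    """Simplified shape merge: ranks and sizes must match exactly.

    None marks an unknown size that matches anything.
    """
    a, b = tuple(a), tuple(b)
    if len(a) != len(b):
        raise ValueError('Shapes %s and %s must have the same rank' % (a, b))
    merged = []
    for x, y in zip(a, b):
        if x is not None and y is not None and x != y:
            raise ValueError('Shapes %s and %s are not compatible' % (a, b))
        merged.append(y if x is None else x)
    return tuple(merged)


shapes = fc_weight_shapes(26, [40], 7)  # hypothetical layer sizes
print(shapes)  # [(26, 40), (40, 7)]
try:
    merge_shapes(shapes[0], shapes[1])
except ValueError as err:
    print(err)  # Shapes (26, 40) and (40, 7) are not compatible
```

Under this reading, the fix is to make sure each layer's weights are created and used with the same dimensions, rather than to change the layer sizes themselves.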

cbfinn commented 7 years ago

Sorry about that. I just pushed a change that should fix the issue. Let me know if it works.

ashwinreddy commented 7 years ago

It now at least interacts with the MJC environment, but I've hit another error related to the hyperparams:

Traceback (most recent call last):
  File "/usr/lib/python2.7/threading.py", line 810, in __bootstrap_inner
    self.run()
  File "/usr/lib/python2.7/threading.py", line 763, in run
    self.__target(*self.__args, **self.__kwargs)
  File "python/gps/gps_main.py", line 398, in <lambda>
    target=lambda: gps.run(itr_load=resume_training_itr)
  File "python/gps/gps_main.py", line 76, in run
    self._take_iteration(itr, traj_sample_lists)
  File "python/gps/gps_main.py", line 209, in _take_iteration
    self.algorithm.iteration(sample_lists)
  File "python/gps/algorithm/algorithm_badmm.py", line 60, in iteration
    self._update_policy(inner_itr)
  File "python/gps/algorithm/algorithm_badmm.py", line 143, in _update_policy
    self.policy_opt.update(obs_data, tgt_mu, tgt_prc, tgt_wt)
  File "python/gps/algorithm/policy_opt/policy_opt_tf.py", line 155, in update
    if self._hyperparams['fc_only_iterations'] > 0:
KeyError: 'fc_only_iterations'

It looks like fc_only_iterations isn't defined in the hyperparams file. What should its value be?

cbfinn commented 7 years ago

Just pushed a fix to add a default value of 0. Thanks for catching that!
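A common way to provide such a default is to start from a defaults dictionary and overlay the user's hyperparams, so a missing key falls back instead of raising KeyError. The sketch below assumes that pattern; `POLICY_OPT_DEFAULTS` and `PolicyOptSketch` are hypothetical names, not the actual fix that was pushed.

```python
import copy

# Hypothetical defaults table; 0 is the default value mentioned above.
POLICY_OPT_DEFAULTS = {
    'fc_only_iterations': 0,
}


class PolicyOptSketch(object):
    """Illustrative optimizer that merges user hyperparams over defaults."""

    def __init__(self, hyperparams):
        # Copy so updating the config never mutates the shared defaults.
        config = copy.deepcopy(POLICY_OPT_DEFAULTS)
        config.update(hyperparams)
        self._hyperparams = config

    def uses_fc_only_phase(self):
        # With the default of 0 this is False instead of a KeyError.
        return self._hyperparams['fc_only_iterations'] > 0


print(PolicyOptSketch({}).uses_fc_only_phase())  # False
```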

ashwinreddy commented 7 years ago

After the 3000 training iterations, TF raises the following error:

Traceback (most recent call last):
  File "/usr/lib/python2.7/threading.py", line 810, in __bootstrap_inner
    self.run()
  File "/usr/lib/python2.7/threading.py", line 763, in run
    self.__target(*self.__args, **self.__kwargs)
  File "python/gps/gps_main.py", line 398, in <lambda>
    target=lambda: gps.run(itr_load=resume_training_itr)
  File "python/gps/gps_main.py", line 76, in run
    self._take_iteration(itr, traj_sample_lists)
  File "python/gps/gps_main.py", line 209, in _take_iteration
    self.algorithm.iteration(sample_lists)
  File "python/gps/algorithm/algorithm_badmm.py", line 60, in iteration
    self._update_policy(inner_itr)
  File "python/gps/algorithm/algorithm_badmm.py", line 143, in _update_policy
    self.policy_opt.update(obs_data, tgt_mu, tgt_prc, tgt_wt)
  File "python/gps/algorithm/policy_opt/policy_opt_tf.py", line 194, in update
    self.feat_vals = self.solver.get_var_values(self.sess, self.feat_op, feed_dict, num_values, self.batch_size)
  File "python/gps/algorithm/policy_opt/tf_utils.py", line 145, in get_var_values
    batch_vals = sess.run(var, batch_dict)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 372, in run
    run_metadata_ptr)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 584, in _run
    processed_fetches = self._process_fetches(fetches)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 540, in _process_fetches
    % (subfetch, fetch, type(subfetch), str(e)))
TypeError: Fetch argument None of None has invalid type <type 'NoneType'>, must be a string or Tensor. (Can not convert a NoneType into a Tensor or Operation.)

Not quite sure how to fix this one.
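The traceback shows self.feat_op being None when it is passed to sess.run, which only accepts fetchable objects such as tensors, operations, or their names. The thread does not say how this was fixed; one possible guard is to skip the batched evaluation when there is nothing to fetch. In the sketch below, `get_var_values_sketch` is a hypothetical helper with its own signature (not the one in tf_utils.py), and `run` stands in for a session call.

```python
def get_var_values_sketch(run, fetch, inputs, batch_size):
    """Evaluate `fetch` over `inputs` in minibatches via run(fetch, batch).

    Returns None without calling `run` when `fetch` is None, since
    passing a None fetch to a TensorFlow session raises a TypeError.
    """
    if fetch is None:
        return None
    values = []
    for start in range(0, len(inputs), batch_size):
        batch = inputs[start:start + batch_size]
        values.extend(run(fetch, batch))  # collect each batch's results in order
    return values


# Usage with a fake "session" that doubles its inputs.
fake_run = lambda fetch, batch: [2 * x for x in batch]
print(get_var_values_sketch(fake_run, 'feat', [1, 2, 3, 4, 5], 2))  # [2, 4, 6, 8, 10]
print(get_var_values_sketch(fake_run, None, [1, 2], 2))  # None
```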

cbfinn commented 7 years ago

Just pushed a fix. Sorry for the bugs.

ashwinreddy commented 7 years ago

Great; looks like it works now. Thanks again!

cbfinn commented 7 years ago

Great!