cbfinn / gps

Guided Policy Search
http://rll.berkeley.edu/gps/
Other
593 stars 239 forks source link

some problems when i use `--resume iter_N` #123

Open dujinyu opened 4 years ago

dujinyu commented 4 years ago

when i use python python/gps/gps_main.py box2d_pointmass_pigps_example --resume 5, some errors occur.

when i check the implementation, i found the reason in policy_opt_tf.py.

    def save_model(self, fname):
        print "self.sess:", self.sess
        print "fname generated by NamedTemporaryFile:", fname 
        LOGGER.debug('Saving model to: %s', fname)
        path = self.saver.save(self.sess, fname, write_meta_graph=False) 
        print "file path to the saver:", path
        # sys.exit(("exit from policy_opt_tf.py/save_model"))

    def restore_model(self, fname):
        self.saver.restore(self.sess, fname)
        LOGGER.debug('Restoring model from: %s', fname)

    # For pickling.
    def __getstate__(self):
        with tempfile.NamedTemporaryFile('w+b', delete=True) as f:
            self.save_model(f.name) # TODO - is this implemented.
            f.seek(0)
            with open(f.name, 'r') as f2:
                wts = f2.read()    # this is an empty file
                print "wts:", wts
                # import sys
                sys.exit()

        return {
            'hyperparams': self._hyperparams,
            'dO': self._dO,
            'dU': self._dU,
            'scale': self.policy.scale,
            'bias': self.policy.bias,
            'tf_iter': self.tf_iter,
            'x_idx': self.policy.x_idx,
            'chol_pol_covar': self.policy.chol_pol_covar,
            'wts': wts,
        }

    # For unpickling.
    def __setstate__(self, state):
        print "state:", state
        from tensorflow.python.framework import ops
        ops.reset_default_graph()  # we need to destroy the default graph before re_init or checkpoint won't restore.
        self.__init__(state['hyperparams'], state['dO'], state['dU'])
        self.policy.scale = state['scale']
        self.policy.bias = state['bias']
        self.policy.x_idx = state['x_idx']
        self.policy.chol_pol_covar = state['chol_pol_covar']
        self.tf_iter = state['tf_iter']

        with tempfile.NamedTemporaryFile('w+b', delete=True) as f:
            print "we restore tf networks in file:", f.name
            print "state['wts']:", state['wts']
            f.write(state['wts'])
            f.seek(0)
            self.restore_model(f.name)

when i print wts, i found that wts is None. three files will be generated when we use saver.save(),they are checkpoint, *.index and *.data-*. so the file f2 is empty. how can we fix this bug?