devsisters / DQN-tensorflow

Tensorflow implementation of Human-Level Control through Deep Reinforcement Learning
MIT License

[!] Load FAILED error #14

Open SumbeeLei opened 7 years ago

SumbeeLei commented 7 years ago

Hi, I was trying to run the DQN code. After 50000 steps, the error [!] Load FAILED occurred. According to the error message, the CPU only supports the data format "NHWC", but the code runs on the GPU with the data format "NCHW". How can I run on the GPU with "NCHW" while saving and loading on the CPU with "NHWC", so that this error does not occur? Thanks!
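To make the two layouts concrete, here is a minimal sketch of how the same 4-D input shape is ordered under "NHWC" and "NCHW". The helper names pick_cnn_format and permute_shape are hypothetical and not part of DQN-tensorflow, and the shape values are only illustrative.

```python
# Hypothetical helpers for illustration; not part of DQN-tensorflow.

def pick_cnn_format(use_gpu):
    """Return "NCHW" when running on a GPU, otherwise "NHWC"."""
    return "NCHW" if use_gpu else "NHWC"


def permute_shape(shape, src, dst):
    """Reorder a shape tuple from layout `src` to layout `dst`."""
    if sorted(src) != sorted(dst) or len(shape) != len(src):
        raise ValueError("layouts must name the same axes as the shape")
    # For each axis letter in the target layout, look up its position in the source.
    return tuple(shape[src.index(axis)] for axis in dst)


# Illustrative values: batch of 32, 84x84 frames, 4 stacked frames as channels.
nhwc_shape = (32, 84, 84, 4)
nchw_shape = permute_shape(nhwc_shape, "NHWC", "NCHW")

print(pick_cnn_format(use_gpu=False))  # NHWC
print(nchw_shape)                      # (32, 4, 84, 84)
```

Only the axis order of the data changes between the two layouts; the number of elements stays the same.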

My modified code is below (it does not work):

  def save_model(self, step=None):
    print(" [*] Saving checkpoints...")
    self.config.cnn_format = "NHWC"  # added
    print("******** save begin data_format %s" % self.config.cnn_format)
    model_name = type(self).__name__

    if not os.path.exists(self.checkpoint_dir):
      os.makedirs(self.checkpoint_dir)
    self.saver.save(self.sess, self.checkpoint_dir, global_step=step)
    self.config.cnn_format = "NCHW"  # added
    print("******** save end data_format %s" % self.config.cnn_format)

  def load_model(self):
    print(" [*] Loading checkpoints...")
    self.config.cnn_format = "NHWC"  # added
    print("******** load begin data_format %s" % self.config.cnn_format)

    ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
      fname = os.path.join(self.checkpoint_dir, ckpt_name)
      self.saver.restore(self.sess, fname)
      print(" [*] Load SUCCESS: %s" % fname)
      self.config.cnn_format = "NCHW"  # added
      print("******** load end data_format %s" % self.config.cnn_format)
      return True
    else:
      print(" [!] Load FAILED: %s" % self.checkpoint_dir)
      self.config.cnn_format = "NCHW"  # added
      print("******** load end data_format %s" % self.config.cnn_format)
      return False
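One possible reason the change above has no effect: in TensorFlow 1.x graph mode, the data format is fixed on each convolution op when the graph is built, so changing a config attribute afterwards does not alter ops that already exist. The toy sketch below imitates that behaviour in plain Python. It assumes the network is built once before save_model and load_model are called; Config and ToyNetwork are hypothetical stand-ins, not the repository's classes.

```python
# Hypothetical stand-ins for illustration; not the repository's classes.

class Config(object):
    cnn_format = "NCHW"


class ToyNetwork(object):
    """Reads the layout once at build time, like a graph built from the config."""

    def __init__(self, config):
        self.config = config
        # The layout is captured here and never re-read from the config.
        self.built_format = config.cnn_format


net = ToyNetwork(Config())

# Mimics the patch above: flip the config around a save or load call.
net.config.cnn_format = "NHWC"

print(net.config.cnn_format)  # NHWC
print(net.built_format)       # NCHW
```

If this assumption holds for the real code, the layout would need to be chosen before the network is built, for example from whether a GPU is in use, rather than toggled inside save_model and load_model.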