MouseHu / GEM


problem when running ddq6: update_func #2

Open lizhuo-1994 opened 2 years ago

lizhuo-1994 commented 2 years ago

Hi, thanks for open-sourcing the code; I appreciate your work on episodic RL. I ran your code locally, but it failed as follows:

    Traceback (most recent call last):
      File "train.py", line 10, in <module>
        env, agent, buffer, learner, tester = experiment_setup(args)
      File "/home/lizhuo/workspace/GEM/gem_atari/common.py", line 198, in experiment_setup
        args.agent = agent = create_agent(args)
      File "/home/lizhuo/workspace/GEM/gem_atari/algorithm/__init__.py", line 31, in create_agent
        return algorithm_collection[args.alg]0
      File "/home/lizhuo/workspace/GEM/gem_atari/algorithm/ddq.py", line 42, in __init__
        self.args.buffer.update_func(self)
    AttributeError: 'ReplayBuffer_FrameStack' object has no attribute 'update_func'

This is the code in ddq6.py:

    class DDQ6(BaseLearner):
        def __init__(self, args, flags={}):
            super(DDQ6, self).__init__()
            self.args = args
            self.gpu = args.gpu
            self.flags = flags
            self.acts_num = args.acts_dims[0]
            self.inner_q_type = args.inner_q_type

            self.num_q = self.args.num_q
            self.tau = self.args.tau
            self.alpha = self.args.alpha
            self.beta = self.args.beta

            self.q_funcs = []
            self.q_pi_funcs = []
            self.target_q_funcs = []
            self.target_q_pi_funcs = []

            self.meta_q_funcs = []
            self.meta_q_pi_funcs = []
            self.target_meta_q_funcs = []
            self.target_meta_q_pi_funcs = []

            self.target_qs = None
            self.qs = None
            self.meta_q_funcs_stack = None
            self.meta_q_pi = None

            self.meta_target_check_range = None
            self.create_model()

            self.train_info = {
                'Q_loss': self.q_loss,
                'Meta_Q_loss': self.meta_q_loss,
                # 'Q_target_0': self.q_step_target[:, 0],
                # 'Q_target_1': self.q_step_target[:, 1],
                'difference': self.buffer_target_diffence,
                'target_range': self.target_check_range,
                'meta_target_range': self.meta_target_check_range,
                'regression_target': self.qvalues_ph,
                'true_return': self.true_rews_ph,
            }
            self.step_info = {
                'Q_average': self.meta_q_pi,
                'sub_Q_average': self.q_pi,
            }

            self.args.buffer.update_func(self)

However, I cannot find the function `update_func()` in replay_buffer.py.
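The error is an interface mismatch: `DDQ6.__init__` ends with `self.args.buffer.update_func(self)`, so whatever object is stored in `args.buffer` must define `update_func`, and `ReplayBuffer_FrameStack` apparently does not. A minimal sketch that reproduces this, assuming the hook simply lets the buffer keep a reference to the learner (what the real hook does in GEM is not stated here). All class names (`BufferWithoutHook`, `BufferWithHook`, `Args`, `Learner`) are hypothetical stand-ins, not GEM code:

```python
class BufferWithoutHook:
    """Hypothetical buffer with no update_func, like the one in the traceback."""

    def __init__(self):
        self.samples = []

    def store(self, transition):
        self.samples.append(transition)


class BufferWithHook(BufferWithoutHook):
    """Hypothetical buffer that exposes the hook the learner expects."""

    def __init__(self):
        super().__init__()
        self.learner = None

    def update_func(self, learner):
        # Assumed behavior: remember the learner so the buffer can query it later.
        self.learner = learner


class Args:
    """Hypothetical stand-in for the parsed args object."""

    def __init__(self, buffer):
        self.buffer = buffer


class Learner:
    """Mimics the last line of DDQ6.__init__: register with the buffer."""

    def __init__(self, args):
        self.args = args
        self.args.buffer.update_func(self)


def try_create(buffer):
    """Build a Learner on the given buffer and report success or the error text."""
    try:
        Learner(Args(buffer))
        return "ok"
    except AttributeError as err:
        return f"AttributeError: {err}"


print(try_create(BufferWithoutHook()))  # fails: no attribute 'update_func'
print(try_create(BufferWithHook()))     # succeeds
```

In this sketch the learner only works when it is paired with a buffer class that provides the hook.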

Could you please suggest a solution, or let me know if I made a mistake?

lizhuo-1994 commented 2 years ago

Sorry, solved!