Open ShibiHe opened 8 years ago
The indexing in build_functions is wrong. You can run the code below to verify that VS[:, A] does not select one value per row.
The indexing should instead be written as in lines 51 and 52 of https://github.com/ShibiHe/DQN_OpenAI_keras/blob/master/agents.py or as in https://github.com/spragunr/deep_q_rl/blob/master/deep_q_rl/q_network.py
def build_functions(self):
    S = Input(shape=self.state_size)
    NS = Input(shape=self.state_size)
    A = Input(shape=(1,), dtype='int32')
    R = Input(shape=(1,), dtype='float32')
    T = Input(shape=(1,), dtype='int32')
    self.build_model()
    self.value_fn = K.function([S], self.model(S))

    VS = self.model(S)
    VNS = disconnected_grad(self.model(NS))
    future_value = (1-T) * VNS.max(axis=1, keepdims=True)
    discounted_future_value = self.discount * future_value
    target = R + discounted_future_value
    cost0 = VS[:, A] - target
    cost = ((VS[:, A] - target)**2).mean()
    opt = RMSprop(0.0001)
    params = self.model.trainable_weights
    updates = opt.get_updates(params, [], cost)
    self.train_fn = K.function([S, NS, A, R, T], [cost, cost0, target, A], updates=updates)

    # import numpy as np
    # t = self.train_fn([np.random.rand(10, *self.state_size),
    #                    np.random.rand(10, *self.state_size),
    #                    np.ones((10, 1)), np.ones((10, 1)), np.zeros((10, 1))])
    # print('cost=', t[0])
    # print('cost0=', t[1])
    # print('target=', t[2])
    # print('A=', t[3])
    # raw_input()
Hi @ShibiHe, thanks for your comment. You're right, this is a bug. I should be using np.arange(n) instead of : as the row index.
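
To make the difference concrete, here is a minimal NumPy sketch of the two indexing forms. It is an illustration, not the repository's code: the batch size, number of actions, and array values are made up, and it assumes the backend follows NumPy's advanced-indexing semantics for VS[:, A].

```python
import numpy as np

n, num_actions = 4, 3  # hypothetical batch size and action count

# Stand-in for the network output VS: one row of Q-values per sample.
VS = np.arange(n * num_actions, dtype=np.float32).reshape(n, num_actions)
# Actions arrive with shape (n, 1), matching Input(shape=(1,)).
A = np.array([[2], [0], [1], [2]], dtype=np.int32)

# Slice + index array: every row is paired with every action in A.
wrong = VS[:, A]
print(wrong.shape)  # (4, 4, 1)

# Row index + flattened action index: one Q-value per sample.
right = VS[np.arange(n), A.ravel()]
print(right.shape)  # (4,)

# Restore the column shape so it lines up with an (n, 1) target.
right_col = right[:, None]
print(right_col.shape)  # (4, 1)
```

With the slice form, subtracting an (n, 1) target broadcasts against an (n, n, 1) array, so the mean squared error averages over mismatched sample/action pairs rather than over the n actions actually taken.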