NVlabs / GA3C

Hybrid CPU/GPU implementation of the A3C algorithm for deep reinforcement learning.
BSD 3-Clause "New" or "Revised" License

LSTM version #3

Open markovyao opened 7 years ago

markovyao commented 7 years ago

This is great work. Is there any plan to develop an LSTM version?

ifrosio commented 7 years ago

Not immediately, but it shouldn't be hard to implement it in TF. If you have any version with LSTM, please let us know.

ieow commented 7 years ago

Is it possible to implement an LSTM in this GA3C architecture? An RNN (LSTM) requires serialized (sequential) input, but in this GA3C architecture multiple agents push their experiences to a shared queue, so the experiences do not arrive as a serial input. The batches in the training thread would therefore be mixed and could not be used as RNN training input. Correct me if I am wrong. Thanks

adi-sharma commented 7 years ago

Should be straightforward, as the state for Atari games is already defined as 4 frames together (see section 4.1 of the original DQN paper, https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf), and that is what GA3C uses. If you supply those frames serially, the LSTM version of GA3C will work.
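For reference, a minimal sketch of what "4 frames together" means for the agent state: a sliding window over the most recent frames. This is an illustrative NumPy version, not GA3C's own code; the class name `FrameStacker` is hypothetical, and filling the window with the first frame at reset is one common choice, not necessarily the one GA3C makes.

```python
from collections import deque

import numpy as np


class FrameStacker:
    """Hypothetical helper: the state is the last k frames stacked on the last axis."""

    def __init__(self, k=4):
        self.frames = deque(maxlen=k)

    def reset(self, first_frame):
        # At the start of an episode, fill the window with the first frame
        for _ in range(self.frames.maxlen):
            self.frames.append(first_frame)
        return self.state()

    def step(self, frame):
        # Appending to a full deque drops the oldest frame automatically
        self.frames.append(frame)
        return self.state()

    def state(self):
        return np.stack(self.frames, axis=-1)  # shape (H, W, k)


stacker = FrameStacker(k=4)
s0 = stacker.reset(np.zeros((2, 2)))
s1 = stacker.step(np.ones((2, 2)))
```

The window only gives the network a fixed, short history (k frames); an LSTM instead carries a state across the whole sequence, which is why the rest of this thread is about where that state lives.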

mbz commented 7 years ago

Whether the LSTM version can be implemented without many code changes depends on how long the training sequences need to be. If the sequences are as long as TMAX frames (which I think is the case), the current architecture works, since Trainers receive sequences of TMAX frames. But if the training sequences need to be longer (i.e. multiple TMAX segments merged together), it becomes more complicated.

etienne87 commented 7 years ago

In the case of an LSTM, shouldn't the batch be organized in (N, T, C, H, W) format?

mbz commented 7 years ago

@etienne87 you are correct. But please look here: what the Trainer receives is in (N, T, C, H, W) format, but it merges the T dimension into the batch dimension so that the data is in (N, C, H, W) format. In a recurrent model these concatenations are unnecessary.
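To make the shape bookkeeping concrete, here is a small NumPy sketch (not GA3C code; the sizes are arbitrary) of merging the N and T axes the way a feed-forward trainer can, and of restoring the (N, T, ...) layout a recurrent model needs. It assumes every sequence has exactly T frames.

```python
import numpy as np

# Illustrative sizes: N sequences of T frames, each frame C x H x W
N, T, C, H, W = 3, 5, 4, 2, 2
batch = np.arange(N * T * C * H * W, dtype=np.float32).reshape(N, T, C, H, W)

# Feed-forward model: every frame is an independent sample, so N and T are merged
flat = batch.reshape(N * T, C, H, W)

# Recurrent model: the time axis must be recovered; this only works because
# every sequence has exactly T frames (variable lengths need padding instead)
restored = flat.reshape(N, T, C, H, W)

# Frame t of sequence n sits at row n * T + t of the merged array
n, t = 2, 3
assert np.array_equal(flat[n * T + t], batch[n, t])
print(flat.shape, restored.shape)  # (15, 4, 2, 2) (3, 5, 4, 2, 2)
```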

etienne87 commented 7 years ago

@mbz, thanks for pointing this out. Now I am super confused by this part of the code! Can you take a look at #6? I don't see how these concatenations work at all!

I would suggest modifying ThreadTrainer.py as follows:

batch_size = 0
if self.server.model.rnn:
    print('todo')  # recurrent case: the T dimension must be kept, not merged
else:
    while batch_size <= Config.TRAINING_MIN_BATCH_SIZE:
        x_, r_, a_ = self.server.training_q.get()
        if batch_size == 0:
            x__ = x_; r__ = r_; a__ = a_
        else:
            x__ = np.concatenate((x__, x_))
            r__ = np.concatenate((r__, r_))
            a__ = np.concatenate((a__, a_))
        batch_size += x_.shape[0]

etienne87 commented 7 years ago

An LSTM would require a reset_state function that can address a specific row of the batch, right?

class NetworkVP:
    # [...]
    def reset_state(self, idx):
        # zero the LSTM cell state (c) and hidden state (h) of batch row(s) idx
        self.lstm_state_c[idx, ...] = 0
        self.lstm_state_h[idx, ...] = 0

Sorry for the pseudocode; I'm not an expert with TF.
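A runnable NumPy version of this reset idea, assuming (hypothetically) that one row of c and h is kept per agent in arrays of shape (num_agents, NCELLS). The class name `LSTMStateStore` and the `done` mask are illustrative, not part of GA3C.

```python
import numpy as np


class LSTMStateStore:
    """Hypothetical per-agent LSTM state holder: one row of c and h per agent."""

    def __init__(self, num_agents, ncells):
        self.c = np.zeros((num_agents, ncells), dtype=np.float32)
        self.h = np.zeros((num_agents, ncells), dtype=np.float32)

    def reset_state(self, idx):
        # idx may be an int, a list of ints, or a boolean mask over agents
        self.c[idx, ...] = 0.0
        self.h[idx, ...] = 0.0


store = LSTMStateStore(num_agents=4, ncells=3)
store.c[:] = 1.0
store.h[:] = 2.0

# Agents 1 and 3 just finished their episodes: reset only those rows
done = np.array([False, True, False, True])
store.reset_state(done)
print(store.c[:, 0], store.h[:, 0])  # [1. 0. 1. 0.] [2. 0. 2. 0.]
```

Passing a boolean mask lets all finished agents be reset in one vectorized assignment instead of a Python loop.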

etienne87 commented 7 years ago

Another point of confusion for me (I have little experience with TF): it seems we need two graphs, one for prediction (using dynamic_rnn) and one for backprop (maybe using a static tf.rnn?) when feeding (N, T, C, H, W). Or is there a way to use a gradient applier as in myosuda's implementation? Sorry if this is not the right place to ask.

mbz commented 7 years ago

@etienne87 I'm not sure if I understand your first question about reset_state correctly. Can you please provide more details?

About having separate graphs: there are different ways of implementing the same logic in TF. We are not using separate graphs simply because it's not necessary. Can you please elaborate on why you think having two graphs is necessary?

etienne87 commented 7 years ago

@mbz OK! What I mean is: in classic A3C, it seems we can just backprop at the end of each T_MAX-step rollout by reusing the predictions already computed during the rollout. Here, on the other hand, it seems we need to recompute the predictions from the stored samples and actions. In short: should X be (N, H, W, C) in the predictor thread and (N, T, H, W, C) in the train function? Maybe I misunderstood something about TF's internal mechanism?

Also, about the reset: at the beginning of each episode you probably want to reset the c and h states of your LSTM to zero. So, as @ppwwyyxx suggests, should the LSTM state be saved inside each ProcessAgent?

ppwwyyxx commented 7 years ago

@mbz I have implemented A3C-LSTM with long sequence lengths. You don't have to send the whole sequence into the graph. What I did was maintain the current LSTM hidden state for every game simulator on the Python side, and each time feed the new frame together with that simulator's hidden state to the graph. This way the sequence length can be as long as one episode.
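A minimal sketch of this scheme, with a plain NumPy LSTM cell standing in for the TF graph. The weights are random and the names `lstm_step` and `hidden` are hypothetical; this is not @ppwwyyxx's code. The predictor stacks the stored (c, h) of every agent in the current batch, runs one time step, and hands each agent its new state back.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, c, h, W, b):
    """One standard LSTM time step for a batch: x (B, D), c and h (B, K)."""
    z = np.concatenate([x, h], axis=1) @ W + b  # (B, 4K)
    i, f, o, g = np.split(z, 4, axis=1)         # input, forget, output gates, candidate
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return c_new, h_new


rng = np.random.default_rng(0)
D, K, num_agents = 8, 16, 5
W = rng.normal(scale=0.1, size=(D + K, 4 * K))
b = np.zeros(4 * K)

# Python-side state: one (c, h) pair per game simulator
hidden = {a: (np.zeros(K), np.zeros(K)) for a in range(num_agents)}

# The predictor collects requests (agent id, new frame features) from some agents
requests = [(0, rng.normal(size=D)), (3, rng.normal(size=D)), (4, rng.normal(size=D))]
ids = [a for a, _ in requests]
x = np.stack([obs for _, obs in requests])
c = np.stack([hidden[a][0] for a in ids])
h = np.stack([hidden[a][1] for a in ids])

c_new, h_new = lstm_step(x, c, h, W, b)

# Each agent gets back its own new state, so a sequence can span a whole episode
for row, a in enumerate(ids):
    hidden[a] = (c_new[row], h_new[row])
```

Only one time step goes through the graph per prediction; the recurrence across steps lives entirely in the `hidden` bookkeeping.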

markovyao commented 7 years ago

@ppwwyyxx I have built an LSTM and stored the hidden states.
However, I have two questions: 1. Where should the stored states be reset or initialized before each episode? 2. How should the Experience class be handled in LSTM training?

ppwwyyxx commented 7 years ago

@markovyao The states are maintained in Python, inside each agent (simulator), so you can easily set them when needed (e.g. right after the agent reaches the end of an episode). Since each agent maintains its own hidden states, it can do the following on its own:

  1. keep the hidden states in its own experience history buffer and give it to the network for training
  2. send the hidden states to predictor to get the next action
  3. request the predictor to send back the next hidden state and keep it
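The three steps above, seen from a single agent, might look like the following sketch. `predict` is a hypothetical stand-in for the predictor queue that returns an action and the next state, rewards are omitted, and all sizes are illustrative. The points it shows are that the state at the start of each TIME_MAX segment is stored with the experiences (so the trainer can re-run the LSTM from it) and that the state is zeroed when the episode ends.

```python
import numpy as np

K = 4          # LSTM size (illustrative)
TIME_MAX = 3   # rollout length (illustrative)


def predict(obs, c, h):
    """Hypothetical stand-in for the predictor: returns (action, next_c, next_h)."""
    c_next = 0.5 * c + obs
    h_next = np.tanh(c_next)
    return int(h_next.sum() > 0), c_next, h_next


class Agent:
    def __init__(self):
        self.c = np.zeros(K)
        self.h = np.zeros(K)

    def run_segment(self, observations, episode_done):
        # 1. keep the state at the start of the segment in the experience
        c0, h0 = self.c.copy(), self.h.copy()
        actions = []
        for obs in observations:
            # 2. send obs + state to the predictor, 3. keep the returned state
            action, self.c, self.h = predict(obs, self.c, self.h)
            actions.append(action)
        if episode_done:
            # the next episode starts from a zero state
            self.c = np.zeros(K)
            self.h = np.zeros(K)
        return np.stack(observations), np.array(actions), c0, h0


agent = Agent()
obs = [np.ones(K) for _ in range(TIME_MAX)]
x1, a1, c0_1, h0_1 = agent.run_segment(obs, episode_done=False)
x2, a2, c0_2, h0_2 = agent.run_segment(obs, episode_done=True)
```

The second segment starts from the state left by the first one, while the zero reset only happens at the episode boundary.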
ricky1203 commented 7 years ago

@markovyao an alternative implementation to @ppwwyyxx's solution:

  1. Create matrix variables to store the LSTM hidden states.
  2. Assign every agent a unique agent_index.
  3. Use tf.gather to select the hidden states from the matrix and pass them to tf.nn.dynamic_rnn: init_state = tf.gather(_init_state, self._input_agent_indexs)
  4. Use tf.scatter_update to reset/update the LSTM hidden states according to agent_index/last_is_over, for example:
        if tc.is_training:
            # keep = 1 - is_over: zero the stored state of agents whose episode just ended
            need_reset_states = tf.reshape(tf.ones_like(self._input_is_over) - self._input_is_over, (-1, 1))
            op_updates = [tf.scatter_update(initial_rnn_states[idx], self._input_agent_indexs,
                                            rnn_output_states_array[idx] * tf.cast(need_reset_states, rnn_output_states_array[idx].dtype))
                          for idx in range(len(rnn_output_states_array))]
        else:
            # in predict mode, is_over refers to the last state
            batch_size = tf.shape(self._input_agent_indexs)[0]
            op_resets = []  # must be initialized before the loop appends to it
            op_updates = []
            for idx in range(len(initial_rnn_states)):
                shape_states = tf.shape(initial_rnn_states[idx])
                # reset op: write zeros into the rows of the selected agents
                op = tf.scatter_update(initial_rnn_states[idx], self._input_agent_indexs,
                                       tf.zeros((batch_size, shape_states[1]), dtype=initial_rnn_states[idx].dtype))
                op_resets.append(op)
                # update op: write the new LSTM state into the same rows
                op = tf.scatter_update(initial_rnn_states[idx], self._input_agent_indexs, rnn_output_states_array[idx])
                op_updates.append(op)
  5. In predict/train, run the update/reset ops when needed.

    This implementation is useful if you have many LSTM networks or modify the model frequently during development, since it only exports the update/reset ops outside the model.
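The gather/scatter bookkeeping in steps 3 to 5 can be checked outside TF with NumPy, where fancy indexing plays the role of tf.gather and indexed assignment the role of tf.scatter_update. This is an illustrative sketch of the same idea, not @ricky1203's code; `state_matrix`, `agent_idx` and `is_over` are hypothetical names, and the "new" states are made-up stand-ins for the LSTM output.

```python
import numpy as np

num_agents, K = 6, 3
# One row of LSTM state per agent (in TF this would be a non-trainable variable)
state_matrix = np.zeros((num_agents, K), dtype=np.float32)


def gather_states(agent_idx):
    # like tf.gather(state_matrix, agent_idx): initial states for this batch
    return state_matrix[agent_idx]


def scatter_states(agent_idx, new_states, is_over):
    # like tf.scatter_update: write back the new states, zeroed where the episode ended
    keep = (1.0 - is_over).reshape(-1, 1)
    state_matrix[agent_idx] = new_states * keep


agent_idx = np.array([4, 1, 2])
init = gather_states(agent_idx)  # all zeros on the first call
# stand-in for the LSTM output state: rows filled with 1, 2 and 3
new_states = init + np.arange(1, 4, dtype=np.float32).reshape(-1, 1)
is_over = np.array([0.0, 1.0, 0.0], dtype=np.float32)
scatter_states(agent_idx, new_states, is_over)
print(state_matrix[:, 0])  # [0. 0. 3. 0. 1. 0.]
```

Agent 1's episode ended, so its row is written back as zeros and its next episode starts from a clean state.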

etienne87 commented 7 years ago

@ricky1203: could you perhaps provide an example or a link in context?

ricky1203 commented 7 years ago

@etienne87 check _create_rnn_from_cell() in model.py

Note: since the hidden states are stored in the model, an agent should predict and train on the same model (GPU device) for the whole episode.

Golly commented 7 years ago

@etienne87 Have you had any success developing the LSTM version?

etienne87 commented 7 years ago

@Golly not much, to be honest. Also, I think I first need to test the idea referred to in #16; otherwise the LSTM version will need to recompute the TMAX steps before each backward pass and update.

etienne87 commented 7 years ago

Coming back to this problem with a slightly better understanding of variable-length RNNs: I think the easiest way to code the LSTM version is to keep track of the c and h states in the experience queues.

In ThreadTrainer::run:

while not self.exit_flag:
    batch_size = 0
    ids = []
    lengths = []
    while batch_size <= Config.TRAINING_MIN_BATCH_SIZE:
        idx, x_, r_, a_, c_, h_ = self.server.training_q.get()
        lengths.append(x_.shape[0])
        if batch_size == 0:
            x__ = x_; r__ = r_; a__ = a_; c__ = c_; h__ = h_
        else:
            x__ = np.concatenate((x__, x_))
            r__ = np.concatenate((r__, r_))
            a__ = np.concatenate((a__, a_))
            c__ = np.concatenate((c__, c_))
            h__ = np.concatenate((h__, h_))

        ids.append(idx)
        batch_size += x_.shape[0]

    if Config.TRAIN_MODELS:
        self.server.train_model(x__, r__, a__, c__, h__, lengths)

In NetworkVP::_create_graph:

# assumes TF 1.x with: from tensorflow.contrib import rnn
self.d1 = ...  # result of the feed-forward encoder
self.lstm = rnn.BasicLSTMCell(256, state_is_tuple=True)
# sequence lengths, given by ThreadTrainer; for prediction use np.ones((batch_predict_size,))
self.step_sizes = tf.placeholder(tf.int32, [None], name='stepsize')
batch_size = tf.shape(self.step_sizes)[0]
d1 = tf.reshape(self.d1, [batch_size, -1, 256])  # will not work for variable lengths without a special function
self.c0 = tf.placeholder(tf.float32, [None, 256])
self.h0 = tf.placeholder(tf.float32, [None, 256])
self.initial_lstm_state = rnn.LSTMStateTuple(self.c0, self.h0)
lstm_outputs, self.lstm_state = tf.nn.dynamic_rnn(self.lstm, d1,
                                                  initial_state=self.initial_lstm_state,
                                                  sequence_length=self.step_sizes,
                                                  time_major=False)
self._state = tf.reshape(lstm_outputs, [-1, 256])  # pass this vector to pi and v

In NetworkVP::predict_p_and_v:

step_sizes = np.ones((c.shape[0],),dtype=np.int32)
feed_dict = self.__get_base_feed_dict()
feed_dict.update({self.x: x, self.step_sizes:step_sizes, self.c0:c, self.h0:h})
p, v, rnn_state = self.sess.run([self.softmax_p, self.logits_v, self.lstm_state], feed_dict=feed_dict)
return p, v, rnn_state.c, rnn_state.h

In NetworkVP::train:

step_sizes = np.array(lengths)
r = np.reshape(r, (r.shape[0],))  # flatten rewards to shape (batch,) before feeding y_r
feed_dict = self.__get_base_feed_dict()
feed_dict.update({self.x: x, self.y_r: r, self.action_index: a, self.step_sizes: step_sizes, self.c0: c, self.h0: h})
self.sess.run(self.train_op, feed_dict=feed_dict)

I think the only thing I am missing is how to "unpack" the sequence of encoded states in the _create_graph method:

d1 = tf.reshape(self.d1, [batch_size, -1, 256]) will not work when the sequence lengths are variable. Does anybody know TF well enough to tell me how to use `step_sizes` to create a list of (nstep, 256) tensors?
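One way to approach this, shown as a NumPy sketch rather than TF code, is to scatter the flat encoder output into a zero-padded (N, T_max, D) array using the per-sequence lengths; dynamic_rnn's sequence_length can then skip the padded steps. Expressing the same index arithmetic with TF 1.x ops is left as an untested assumption here, and the function name `pad_sequences` is hypothetical.

```python
import numpy as np


def pad_sequences(flat, lengths, t_max=None):
    """Turn (sum(lengths), D) into (N, t_max, D) plus an (N, t_max) validity mask."""
    lengths = np.asarray(lengths)
    n = len(lengths)
    t_max = int(lengths.max()) if t_max is None else t_max
    # start offset of each sequence inside the flat array
    starts = np.concatenate(([0], np.cumsum(lengths)[:-1]))
    steps = np.arange(t_max)
    mask = steps[None, :] < lengths[:, None]  # (N, t_max), True on real steps
    out = np.zeros((n, t_max) + flat.shape[1:], dtype=flat.dtype)
    # row of `flat` for every valid (sequence, step) pair, in row-major order
    src = (starts[:, None] + steps[None, :])[mask]
    out[mask] = flat[src]
    return out, mask


lengths = [5, 2, 4]
flat = np.arange(sum(lengths) * 2, dtype=np.float32).reshape(-1, 2)  # D = 2 for readability
padded, mask = pad_sequences(flat, lengths, t_max=5)
print(padded.shape)  # (3, 5, 2)
```

The same mask is what the loss needs later, so padded steps do not contribute to the gradient.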

etienne87 commented 7 years ago

Anyway, there is a first implementation here that works fine as long as there are not too many incomplete experiences (of length < Config.TIME_MAX).

I "solved" the issue by padding sequences in ThreadTrainer.py.

To be optimal, we would need to batch the data dynamically after the feed-forward encoder (before the LSTM), so as to feed an (N, TIME_MAX, 256) tensor to tf.nn.dynamic_rnn. However, I am not convinced the padding really slows down the process, as most experience batches should be full (sequence length TIME_MAX).

I will now test on Pong and merge with the GAE branch. If someone wants to help me understand how to improve this, you are welcome! :-)

etienne87 commented 7 years ago

Hmm, actually there was still an error in my code: I forgot to mask the loss for the padded inputs!

I propose a first fix here

Apparently this now works better (at least for CartPole-v0)

In Config.py:

    TIME_MAX = 5
    STACKED_FRAMES = 4
    IMAGE_WIDTH = 1
    IMAGE_HEIGHT = 4
    EPISODES = 4000
    ANNEALING_EPISODE_COUNT = 4000
    BETA_START = 0.01
    BETA_END = 0.01
    LEARNING_RATE_START = 0.0003
    LEARNING_RATE_END = 0.0003
    RMSPROP_DECAY = 0.99
    RMSPROP_MOMENTUM = 0.0
    RMSPROP_EPSILON = 0.1
    DUAL_RMSPROP = False
    USE_GRAD_CLIP = False
    GRAD_CLIP_NORM = 40.0 
    LOG_EPSILON = 1e-6
    TRAINING_MIN_BATCH_SIZE = 16
    USE_RNN = True
    NCELLS = 256
    MIN_POLICY = 0.0
    USE_LOG_SOFTMAX = True
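The masking fix mentioned above amounts to weighting each time step of the loss by a 0/1 mask built from the sequence lengths, so padded steps contribute nothing and the average is taken over real steps only. In TF 1.x, `tf.sequence_mask(lengths, maxlen)` builds such a mask. A NumPy sketch of the idea follows; the per-step loss values are made up and the actual GA3C loss terms are not reproduced.

```python
import numpy as np


def masked_mean(per_step_loss, lengths):
    """Average an (N, T) per-step loss over the valid (unpadded) steps only."""
    n, t = per_step_loss.shape
    mask = (np.arange(t)[None, :] < np.asarray(lengths)[:, None]).astype(per_step_loss.dtype)
    return (per_step_loss * mask).sum() / mask.sum()


# Two sequences padded to TIME_MAX = 5; the padded steps hold garbage values
loss = np.array([[1.0, 1.0, 1.0, 1.0, 1.0],
                 [2.0, 2.0, 99.0, 99.0, 99.0]])
print(masked_mean(loss, [5, 2]))  # 9 / 7, about 1.2857
print(loss.mean())                # unmasked: 30.6, dominated by the padding
```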

[Figure: ga3c_lstm_vs_ff]

wgeul commented 4 years ago

TIME_MAX

Out of interest, can I ask why you've removed this page? What were your findings regarding the performance impact of adding the LSTM?

Edit: Found your model here: https://github.com/etienne87/GA3C , thanks!