Alfredvc / paac

Open source implementation of the PAAC algorithm presented in Efficient Parallel Methods for Deep Reinforcement Learning
https://arxiv.org/abs/1705.04862

add LSTM layer #2

Open chihchiehchen opened 7 years ago

chihchiehchen commented 7 years ago

Hello,

May I ask a naive question: did you try to implement an LSTM on this architecture? Or did you already try it and find that it is not as efficient (maybe too time-consuming?) as people think?

In any case, thanks for this not-so-hardware-demanding idea/architecture.

Best, Chih-Chieh

Alfredvc commented 7 years ago

Hi Chih,

I have implemented an LSTM, and will probably release that code once I have some time. I did not do extensive tests, so I can't really speak to its performance. However, it did learn.

chihchiehchen commented 7 years ago

Hello Alfredo,

I have also begun to implement an LSTM (not finished yet, just as an exercise); maybe later I will try another architecture (an echo state network?) and see whether it is comparable to the LSTM. In any case it is good to get a comparable score even on a computer with only a 4-core CPU.

Thanks a lot, and I look forward to your good news.

Best, Chih-Chieh

Alfredvc commented 7 years ago

Just a quick tip: it may be a little tricky to implement an LSTM here, because you may have experiences from different episodes (from the same environment) in the same minibatch. This means you have to reset the state of the LSTM within a minibatch, which is not supported by TensorFlow's dynamic_rnn.

I will attach the code for the LSTM network; I hope it will help you:

import tensorflow as tf

class LSTMNetwork(Network):

    def __init__(self, conf):
        super(LSTMNetwork, self).__init__(conf)

        self.lstm_size = conf['lstm_size']
        self.lstm_state = (tf.zeros((self.emulator_counts, self.lstm_size), dtype=tf.float32),
                           (tf.zeros((self.emulator_counts, self.lstm_size), dtype=tf.float32)))

        with tf.device(self.device):
            with tf.name_scope(self.name):
                # 0.0 if the episode was over on the previous timestep, else 1.0
                self.prev_episode_over_mask_ph = tf.placeholder(tf.float32, [None, self.emulator_counts], name='episode_over_mask')
                self.lstm_state_c = tf.Variable(self.lstm_state[0], trainable=False, dtype=tf.float32)
                self.lstm_state_h = tf.Variable(self.lstm_state[1], trainable=False, dtype=tf.float32)

                self.reset_state = tf.group(tf.assign(self.lstm_state_c, self.lstm_state[0]),
                                            tf.assign(self.lstm_state_h, self.lstm_state[1]))

                stored_lstm_state_c = tf.Variable(self.lstm_state[0], trainable=False, dtype=tf.float32)
                stored_lstm_state_h = tf.Variable(self.lstm_state[1], trainable=False, dtype=tf.float32)
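                # These two variables, together with store_lstm_state / rollback_lstm_state
                # below, let the training pass rewind the LSTM state to what it was before
                # the current rollout (see the "Normal flow" comment at the bottom).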

                stored_lstm_state = tf.contrib.rnn.LSTMStateTuple(self.lstm_state_c, self.lstm_state_h)

                input_dim = self.output.get_shape().as_list()[1]

                reshaped_input = tf.reshape(self.output, [-1, self.emulator_counts, input_dim])

                max_time = tf.cast(tf.gather(tf.cast(tf.shape(reshaped_input), dtype=tf.float32), 0), dtype=tf.int32)
                self.max_time = max_time
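                # max_time = number of timesteps per environment in this minibatch
                # (the leading dimension of reshaped_input).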

                inputs_ta = tf.TensorArray(dtype=tf.float32, size=1, dynamic_size=True, name='tensor_array_inputs', clear_after_read=False)
                inputs_ta = inputs_ta.unstack(reshaped_input, 'unstack_inputs')
                episode_over_ta = tf.TensorArray(dtype=tf.float32, size=self.emulator_counts, dynamic_size=True, name='tensor_array_episode_over')
                episode_over_ta = episode_over_ta.unstack(self.prev_episode_over_mask_ph, 'unstack_episode_over_mask')

                cell = tf.contrib.rnn.BasicLSTMCell(self.lstm_size)

                def loop_fn(time, cell_output, cell_state, loop_state):
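                    # raw_rnn loop function: at time == 0 it seeds the cell from the
                    # persistent lstm_state variables; at every step it zeroes the state
                    # of any environment whose episode ended on the previous timestep.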
                    nonlocal max_time
                    emit_output = cell_output  # == None for time == 0

                    if cell_output is None:  # time == 0
                        use_cell_state = stored_lstm_state
                    else:
                        use_cell_state = cell_state

                    episode_over = tf.expand_dims(episode_over_ta.read(time), axis=1)
                    # If episode ended, next cell state should be zero.
                    next_cell_state = tf.contrib.rnn.LSTMStateTuple(tf.multiply(use_cell_state[0], episode_over),
                                       tf.multiply(use_cell_state[1], episode_over))
                    elements_finished = tf.greater_equal(time, max_time, name='elements_finished')
                    # If time = max_time then the loop will finish and next_input will not be used,
                    # this minimum effectively replaces a conditional statement by reusing the last input.
                    next_input = inputs_ta.read(tf.minimum(time, max_time-1))
                    next_loop_state = None
                    return (elements_finished, next_input, next_cell_state,
                            emit_output, next_loop_state)

                outputs_ta, final_state, _ = tf.nn.raw_rnn(cell, loop_fn)

                update_lstm_state = tf.group(tf.assign(self.lstm_state_c, final_state[0]),
                                                  tf.assign(self.lstm_state_h, final_state[1]))
                with tf.control_dependencies([update_lstm_state]):
                    self.output_stack = outputs_ta.stack()

                self.output = tf.reshape(self.output_stack, [-1, self.lstm_size])
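                # self.output now has shape [max_time * emulator_counts, lstm_size].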

                self.store_lstm_state = tf.group(tf.assign(stored_lstm_state_c, self.lstm_state_c),
                                                 tf.assign(stored_lstm_state_h, self.lstm_state_h))

                self.rollback_lstm_state = tf.group(tf.assign(self.lstm_state_c, stored_lstm_state_c),
                                                    tf.assign(self.lstm_state_h, stored_lstm_state_h))

                # Normal flow
                # 1. store_lstm_state
                # 2. output + update_lstm_state
                # 3. repeat #2 max_local_steps timesteps
                # 4. output + update_lstm_state
                # 5. rollback_lstm_state
                # 6. output + update weights + update_lstm_state
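
For anyone wiring this up: below is a rough, hypothetical sketch of how the ops above might be driven from a training loop. `sess`, `train_op`, `max_local_steps`, `network = LSTMNetwork(conf)`, and the `make_*_feed` helpers are assumptions for illustration only; they are not part of the snippet or the repository.

    # Hypothetical driver for the "normal flow" described in the comment above.

    # 1. Remember the LSTM state as it was before this rollout.
    sess.run(network.store_lstm_state)

    # 2.-4. Act for max_local_steps timesteps, feeding one timestep per environment;
    # evaluating network.output also advances the persistent state, because it has a
    # control dependency on update_lstm_state.
    for t in range(max_local_steps):
        sess.run(network.output, feed_dict=make_step_feed(t))

    # 5. Rewind the LSTM state to the value saved in step 1, so the training pass can
    # replay the whole rollout from the same starting state.
    sess.run(network.rollback_lstm_state)

    # 6. One optimization step over the full rollout; this advances the state again,
    # leaving it ready for the next rollout.
    sess.run(train_op, feed_dict=make_train_feed())
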
chihchiehchen commented 7 years ago

Hello Alfredo,

Thanks a lot, I really learned something from your reply. I will continue to work on it and maybe I can provide new information later.

Thanks a lot.

Best, Chih-Chieh

zencoding commented 7 years ago

@Alfredvc I don't understand the need to reset the LSTM with every minibatch. I would assume that we want the LSTM cell to learn across minibatches; otherwise the LSTM is being retrained on every minibatch. Am I missing something?

Alfredvc commented 7 years ago

What happens is that each minibatch contains the experience from the different environments for a set number of timesteps. Given that these environments are episodic, once an episode is over you must reset the state of the LSTM. It would make no sense to backpropagate across episodes.

If episodes only terminated at the end of a minibatch, you could just run the minibatch and then reset the state of the LSTM for the environments whose episodes have terminated. However, episodes may terminate at any point in the minibatch, meaning that you must be able to reset the state of the LSTM even within a minibatch.

So the idea is not to reset the state after each minibatch, but to reset it after the episode ends.
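
To make the per-environment reset concrete, here is a minimal standalone sketch (NumPy only, all names illustrative, not code from the repository):

    import numpy as np

    num_envs, lstm_size = 4, 8
    state_c = np.random.randn(num_envs, lstm_size)
    state_h = np.random.randn(num_envs, lstm_size)

    # 0.0 where an environment's episode ended on the previous timestep, 1.0 elsewhere
    # (the same convention as prev_episode_over_mask_ph in the code above).
    mask = np.array([[1.0], [0.0], [1.0], [1.0]])

    # Environment 1 starts its next episode from a zero state; the other environments
    # carry their state forward untouched.
    state_c *= mask
    state_h *= mask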

I hope this makes sense.

Alfredvc commented 7 years ago

Yes, that is how you would implement experience replay for the LSTM architecture, and it is similar to what the code does. However, it currently only works with on-policy data, so the "replay memory" is just the experience you have gathered since the last update.

I think I may have misunderstood what @zencoding meant by resetting the state after each minibatch. I think you may be referring to the rollback in the comment? In that case, before you do one step of optimization, as @pisiiki says, you must roll back the state of the LSTM to the state it had before it encountered the first transition in the minibatch.

zencoding commented 7 years ago

@Alfredvc thanks for the clarification, this helps. I am wondering why this is not done in most implementations that use an LSTM, for example https://github.com/zencoding/DeepRL-Agents/blob/master/A3C-Doom.ipynb, which is an on-policy A3C implementation. Is it due to the batching performed in paac?

I was trying to get MountainCar working on paac, but I gave up after numerous changes to the hyperparameters; it was just not learning anything. Reading online and looking at other GitHub implementations for MountainCar, it seems that it is not meant to be solved without experience replay (or some kind of higher exploration than on-policy). I looked at DDPG and ACER; I liked ACER, and it is closer to paac.

I am going to attempt to reimplement ACER (from https://github.com/Kaixhin/ACER) in paac after I get ACER working for MountainCar. This code will be very helpful as LSTM helps a lot when it comes to temporal learning.

Alfredvc commented 7 years ago

Sorry for the super late response, I was on vacation for a while!

The issue with the LSTM is usually resolved in one of two ways: having a dynamic batch size, or padding. Neither of these solutions is efficient with the batching done in PAAC. To implement a dynamic batch size across all environments would mean that you have to perform a step of optimization as soon as any environment terminates an episode. This may not lead to a large loss in performance if episodes terminate in around 1000 timesteps and 16 instances are used, but it is inefficient. Padding could also be used, but then an environment that has terminated must be "paused" until the next step of optimization is performed, which is again inefficient. Resetting the state of the LSTM within a batch, however, avoids the inefficiencies of the other two approaches (for PAAC).

I have done experiments with experience replay using something similar to ACER, and also with a different technique that I have not yet presented. I will probably present both of these at some point in the future.