siekmanj / r2l

Recurrent continuous reinforcement learning algorithms implemented in Pytorch.
Creative Commons Zero v1.0 Universal

LSTM Explanation #5

Closed. npitsillos closed this issue 3 years ago.

npitsillos commented 3 years ago

Hello, nice implementation of the DRL algos! I have been struggling quite a bit to understand how to incorporate an LSTM into the model and how to handle the hidden states, and your repo seemed to be the clearest of all I have seen. However, I am still not clear on the exact usage of the hidden states, so I hope you can clarify this for me. The specific part of the code I am referring to is in this file, lines 124 - 129.

Can you please clarify that part of the code, both its significance and the reasoning behind it? Why not save the hidden states as part of the buffer and sample them as well? Another question: you keep a list of hidden states which you update both after the policy update and during environment interaction, am I understanding this right? Finally, is this the single most correct way to implement an LSTM layer in a DRL algorithm?

For example, here is how I initially interpreted your code, adapted to my case with a single LSTM layer:

for sample in x:  # x is batch x features
    out, hidden = self.lstm(sample, hidden)
    append out to list of outputs
forward list of outputs through actor & critic

If you could elaborate on how to use LSTMs in general I would appreciate it a lot. Could you also point me to any other implementations or papers that inspired these, or that you based them on?

siekmanj commented 3 years ago

Hey Nikolas,

Great questions! This repo should generally not be considered optimized for performance, and I encourage you to play around with optimizations like the one you describe.

Why not save the hidden states as part of the buffer and sample them as well? Another question: you keep a list of hidden states which you update both after the policy update and during environment interaction, am I understanding this right?

When the policy is updated, the way it computes hidden states changes, because its weights are different post-update. If you collect hidden states and then sample them from the buffer, you are effectively sampling hidden states 'off-policy,' since they come from an old version of the policy. Whether or not this makes a difference to training is a question that can probably only be answered empirically.

In the current implementation, hidden states are computed fresh at every update step as well as during the exploration phase; they are not carried over from collection. I have seen some implementations which do re-use hidden states, but I am not convinced that this is a valid optimization and I would expect poor learning as a result.
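
To make that concrete, here is a minimal sketch (not code from this repo) showing why stored hidden states go stale: after a weight update, re-running the same observations produces different hidden states than the ones saved during collection.

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    cell = nn.LSTMCell(input_size=4, hidden_size=8)    # stand-in for a recurrent policy layer
    obs_seq = torch.randn(10, 4)                       # one trajectory: T x F

    def rollout_hidden(cell, obs_seq):
        h, c = torch.zeros(1, 8), torch.zeros(1, 8)
        hiddens = []
        for obs in obs_seq:                            # walk the trajectory one timestep at a time
            h, c = cell(obs.unsqueeze(0), (h, c))
            hiddens.append(h)
        return torch.stack(hiddens)

    stored = rollout_hidden(cell, obs_seq)             # hidden states at collection time (old weights)
    with torch.no_grad():                              # stand-in for a policy update changing the weights
        for p in cell.parameters():
            p.add_(0.1 * torch.randn_like(p))
    recomputed = rollout_hidden(cell, obs_seq)         # hidden states under the updated weights
    print((stored - recomputed).abs().max())           # non-zero: the stored states are now 'off-policy'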

Finally, is this the single most correct way to implement an LSTM layer in a DRL algorithm?

I'm not sure about 'most correct,' but this implementation does seem to work really well for the problems I solve in my research. I think it is 'most correct' in the sense that hidden states are not re-used and trajectories are not truncated (meaning gradients flow backwards from the end of the trajectory all the way to the beginning); re-using hidden states and truncating trajectories are both hacks I've seen in other implementations to speed up computation, and neither is strictly speaking correct.

The relevant code you mention in r2l/base.py reads as follows (with explanatory comments added):

      y = []
      # Something to note: 'x' here is a 3 dimensional tensor, T x B x F
      # 'T' is the length of the trajectories (padded so they are all the same length)
      # 'B' is the batch size
      # 'F' is the feature size
      for t, x_t in enumerate(x):                                 # iterate through each timestep, x_t is dimension B x F
        for idx, layer in enumerate(self.layers):                 # iterate through each recurrent layer (idx corresponds to 'nth' lstm layer)
          c, h = self.cells[idx], self.hidden[idx]                      
          self.hidden[idx], self.cells[idx] = layer(x_t, (h, c))  # 'layer' is an LSTMCell object which processes one timestep at a time
          x_t = self.hidden[idx]                                  # output the hidden state to the next layer (note this overwrites the original x_t)
        y.append(x_t)                                             # output the final layer's output to our list, 'y'
      x = torch.stack([x_t for x_t in y])                         # stack each timestep's final output into one T x B x H tensor

I see two potential problems with your code. First, there is no time dimension to your input tensor; it's just B x F. If this is wrapped inside another loop iterating over time, or you call this code once every timestep, then this should be fine.
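
For reference, here is a minimal sketch of what that might look like with an explicit time dimension; the sizes and the actor/critic heads are made up for illustration, not taken from this repo:

    import torch
    import torch.nn as nn

    T, B, F_in, H = 10, 4, 6, 16                       # hypothetical trajectory length, batch, feature, hidden sizes
    lstm = nn.LSTMCell(F_in, H)
    actor = nn.Linear(H, 2)                            # hypothetical actor/critic heads
    critic = nn.Linear(H, 1)

    x = torch.randn(T, B, F_in)                        # a batch of trajectories, T x B x F
    h, c = torch.zeros(B, H), torch.zeros(B, H)

    outputs = []
    for x_t in x:                                      # outer loop over the time dimension
        h, c = lstm(x_t, (h, c))                       # one timestep for the whole batch
        outputs.append(h)
    features = torch.stack(outputs)                    # T x B x H

    actions = actor(features)                          # T x B x 2
    values = critic(features)                          # T x B x 1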

Another thing to be aware of is the distinction between the LSTM module and the LSTMCell module in Pytorch: https://stackoverflow.com/questions/48187283/whats-the-difference-between-lstm-and-lstmcell
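
Roughly, the difference looks like this (a sketch with made-up sizes):

    import torch
    import torch.nn as nn

    T, B, F_in, H = 10, 4, 6, 16

    # nn.LSTM consumes a whole sequence in one call (seq_len x batch x features by default)
    lstm = nn.LSTM(F_in, H, num_layers=2)
    x = torch.randn(T, B, F_in)
    y, (h_n, c_n) = lstm(x)                            # y: T x B x H, h_n/c_n: num_layers x B x H

    # nn.LSTMCell processes a single timestep, so you write the time loop yourself,
    # which is convenient for RL where actions are produced one step at a time
    cell = nn.LSTMCell(F_in, H)
    h, c = torch.zeros(B, H), torch.zeros(B, H)
    h, c = cell(x[0], (h, c))                          # one timestep: B x F in, B x H out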

Some useful conceptual resources:
https://blog.aidangomez.ca/2016/04/17/Backpropogating-an-LSTM-A-Numerical-Example/
http://karpathy.github.io/2015/05/21/rnn-effectiveness/

Ilya Kostrikov's PPO implementation (which I believe re-uses hidden states, so beware): https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail

npitsillos commented 3 years ago

@siekmanj thank you for the fast reply. It seems I need to read up on this more, because I fail to understand the significance of timesteps here: is it only because an LSTM is present? Assuming only linear layers, does that imply no timestep dimension? In the case of an LSTM, can you also have a timestep dimension of 1, meaning you pass only 1 timestep of the trajectory in which you have taken a batch of actions? In other words, what is the correspondence between timestep and batch here? The reason I am asking is that in all my implementations I was using just B x F. Finally, does the batch of inputs need to match the batch of hidden states?

To close, I will use a toy example: if we have a batch size of 512, that means I have sampled 512 actions or states from the buffer. The input to my actor is then of dimension 512 x obs_shape, which corresponds to 512 different timesteps in the experience? At least this is how I thought of the batch in this setting.

I understand that I may not be very clear, so please expand on this if you can and we can discuss it; I am really interested in understanding how to actually handle this, as well as the intuition behind it.

npitsillos commented 3 years ago

I had a better look at the code; it seems that when there is a recurrent policy, the trajectory idx is what provides the timestep dimension? I will also download the code and play around with it to get a better feel for it.

npitsillos commented 3 years ago

@siekmanj I have downloaded and played around with the code; more specifically, I ran a PPO agent on HalfCheetah using an LSTM architecture with the other parameters at their defaults. It seems that the way you actually use the states when a recurrent policy is trained is by iterating over trajectories. I found that with num_steps 5000, 4 workers and a batch_size of 64, the BatchSampler does not return any indices for a recurrent policy.

Is there a specific combination of num_workers and batch_size that differs when using recurrent policies? The way I understood it is that when a recurrent policy is present, you need at least batch_size trajectories (meaning the length of traj_idx is at least batch_size) for the sampler to return any indices, which you then treat as timesteps x trajectories x features in the LSTM?

I can produce a script for this if that would help make it clearer.

siekmanj commented 3 years ago

@siekmanj thank you for the fast reply. It seems I need to read up on this more, because I fail to understand the significance of timesteps here: is it only because an LSTM is present? Assuming only linear layers, does that imply no timestep dimension? In the case of an LSTM, can you also have a timestep dimension of 1, meaning you pass only 1 timestep of the trajectory in which you have taken a batch of actions? In other words, what is the correspondence between timestep and batch here? The reason I am asking is that in all my implementations I was using just B x F.

In the process of collecting experience, you collect a number of 'rollouts' or 'trajectories,' which are sequences of states, actions, and rewards. The elements of such a sequence are not IID; they are not independent events, because their sequential relationship implies dependence. This is what I mean by trajectory.

A batch, on the other hand, is either a batch of samples (B x F) in the case of a feedforward NN, or a batch of trajectories (T x B x F), wherein the trajectories are independently distributed from each other.

Recurrent neural networks are only effective on sequences of data where each point has some sequential relationship with other points, not on a batch of data in which each point is independently distributed. I don't think there is any advantage in using an RNN on a trajectory whose length is 1.
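
To make the T x B x F shape concrete, here is a small sketch (not the repo's code) of padding variable-length trajectories to a common length so they can be batched:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # three hypothetical trajectories of different lengths, each t_i x F
    trajs = [torch.randn(7, 6), torch.randn(10, 6), torch.randn(4, 6)]

    batch = pad_sequence(trajs)                                # zero-pads to T x B x F, here 10 x 3 x 6
    mask = pad_sequence([torch.ones(len(t)) for t in trajs])   # 1 for real timesteps, 0 for padding
    print(batch.shape, mask.shape)                             # torch.Size([10, 3, 6]) torch.Size([10, 3])

    # losses computed on padded positions should be masked out, e.g.
    # loss = (per_timestep_loss * mask).sum() / mask.sum()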

Finally, does the batch of inputs need to match the batch of hidden states?

I'm not entirely sure what you mean here, but if you mean the hidden states must be initialized to match the batch dimension, then yes, as is done here: https://github.com/siekmanj/r2l/blob/master/policies/base.py#L119
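
In other words, the (h, c) pairs are zero-initialized with the batch dimension of whatever you are about to feed in; roughly something like this (a sketch, not the repo's exact code):

    import torch

    def init_hidden(num_layers, batch_size, hidden_size):
        # one zero-initialized (h, c) pair per recurrent layer, matching the batch dimension
        hidden = [torch.zeros(batch_size, hidden_size) for _ in range(num_layers)]
        cells = [torch.zeros(batch_size, hidden_size) for _ in range(num_layers)]
        return hidden, cells

    hidden, cells = init_hidden(num_layers=2, batch_size=512, hidden_size=128)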

To close, I will use a toy example: if we have a batch size of 512, that means I have sampled 512 actions or states from the buffer. The input to my actor is then of dimension 512 x obs_shape, which corresponds to 512 different timesteps in the experience? At least this is how I thought of the batch in this setting.

In feedforward networks, this is how batching works. Because RNNs operate on sequences though, a batch size of 512 really means you are training on batches of 512 trajectories. When you are training, you feed a tensor which is T x 512 x F to your policy, so there are T x 512 total timesteps sampled in all.

I had a better look at the code; it seems that when there is a recurrent policy, the trajectory idx is what provides the timestep dimension? I will also download the code and play around with it to get a better feel for it.

I am not sure what you mean. If you mean this traj_idx variable: https://github.com/siekmanj/r2l/blob/master/algos/ppo.py#L38, then this keeps track of how to index collected trajectories in the replay buffer.

I have downloaded and played around with the code; more specifically, I ran a PPO agent on HalfCheetah using an LSTM architecture with the other parameters at their defaults. It seems that the way you actually use the states when a recurrent policy is trained is by iterating over trajectories. I found that with num_steps 5000, 4 workers and a batch_size of 64, the BatchSampler does not return any indices for a recurrent policy.

This is actually a bug I've been meaning to fix, but it's the result of hyperparameter selection. You are collecting 50,000 timesteps and then attempting to sample 64 trajectories. However, each trajectory is up to 1000 timesteps long, so you are over-sampling your buffer and trying to sample 64x1000=64,000 timesteps. The code really should not return an empty list of indices when this happens and I should get around to fixing this, but as it stands this is the behavior, yes. Try using a batch size of 32 or 16.
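
A toy illustration of why the sampler comes back empty, assuming a torch BatchSampler over trajectory indices with drop_last=True (which is what the symptom suggests, not verified against the repo):

    from torch.utils.data import BatchSampler, SubsetRandomSampler

    num_trajectories = 50                              # roughly 50,000 timesteps / ~1000-step trajectories

    sampler = BatchSampler(SubsetRandomSampler(range(num_trajectories)),
                           batch_size=64, drop_last=True)
    print(list(sampler))                               # [] -- fewer than 64 trajectories, the partial batch is dropped

    sampler = BatchSampler(SubsetRandomSampler(range(num_trajectories)),
                           batch_size=32, drop_last=True)
    print(len(list(sampler)))                          # 1 -- a smaller batch size yields at least one full batch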

I understand that I may not be very clear, so please expand on this if you can and we can discuss it; I am really interested in understanding how to actually handle this, as well as the intuition behind it.

Anything by Andrej Karpathy will be good for intuition here. This video (or similar videos on YouTube) might also help: https://www.youtube.com/watch?v=LHXXI4-IEns

npitsillos commented 3 years ago

Thank you for the extensive feedback! It seems clearer to me now, and I think I have understood what you mean and how to proceed! If it helps, just a thought: perhaps use drop_last = True in that case to return however many trajectories there are? Again, thank you for being patient. It is my understanding that a lot of DRL algorithms are heavily dependent on implementation specifics. So, to reiterate, the only caveat with recurrent policies is that you also need to keep track of when trajectories have ended, so you have an index for batching the trajectories.

siekmanj commented 3 years ago

Yep! You need to keep track of the start and end points of each trajectory in your replay buffer to do trajectory batching correctly. I think you actually would want to use drop_last=False rather than True, but if that's what you meant, I think flipping that value around will fix your issue as well, yes.