ikostrikov / pytorch-a2c-ppo-acktr-gail

PyTorch implementation of Advantage Actor Critic (A2C), Proximal Policy Optimization (PPO), Scalable trust-region method for deep reinforcement learning using Kronecker-factored approximation (ACKTR) and Generative Adversarial Imitation Learning (GAIL).

LSTM policy #15

Closed nadavbh12 closed 6 years ago

nadavbh12 commented 6 years ago

Great work on the implementation! Very comprehensible and straightforward.

It seems you're performing two forward passes: 1) to choose an action (main.py, line 113), and 2) to evaluate the actions (main.py, line 146). Why not save the values, log_probs and entropy while selecting actions (as you did in a3c)? Are there computational benefits to performing these computations for all processes at once?

ikostrikov commented 6 years ago

In the a3c code all computations are performed on the CPU, but here there is a GPU option. At some point I compared these two strategies, and it was faster to compute a forward pass for each step, then concatenate, do an extra forward pass, and finally a backward pass on this "large batch".

The current option is probably faster because PyTorch doesn't concatenate all the inputs for a backward pass and GPU kernels are optimized for larger batches.
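
As a rough illustration of that second strategy, here is a minimal self-contained sketch: a cheap per-step forward pass to pick actions, then one large-batch forward/backward pass over the concatenated rollout at update time. The `Policy`, `act`, and `evaluate_actions` names are illustrative, not the repo's actual interfaces, and the loss is a placeholder.

```python
import torch
import torch.nn as nn

# Hypothetical feed-forward actor-critic; all names here are illustrative only.
class Policy(nn.Module):
    def __init__(self, obs_dim, num_actions, hidden=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, hidden), nn.Tanh())
        self.pi = nn.Linear(hidden, num_actions)
        self.v = nn.Linear(hidden, 1)

    def forward(self, obs):
        h = self.body(obs)
        return self.pi(h), self.v(h)

    @torch.no_grad()
    def act(self, obs):
        # Cheap per-step forward pass used only to pick actions.
        logits, _ = self(obs)
        return torch.distributions.Categorical(logits=logits).sample()

    def evaluate_actions(self, obs, actions):
        # Single large-batch forward pass over the whole rollout at update time.
        logits, values = self(obs)
        dist = torch.distributions.Categorical(logits=logits)
        return values, dist.log_prob(actions), dist.entropy().mean()

num_steps, num_procs, obs_dim, num_actions = 5, 4, 8, 3
policy = Policy(obs_dim, num_actions)

obs_buf, act_buf = [], []
for _ in range(num_steps):
    obs = torch.randn(num_procs, obs_dim)   # stand-in for vectorized env observations
    obs_buf.append(obs)
    act_buf.append(policy.act(obs))

# Concatenate the rollout and do one forward/backward pass on the "large batch".
obs_batch, act_batch = torch.cat(obs_buf), torch.cat(act_buf)
values, log_probs, entropy = policy.evaluate_actions(obs_batch, act_batch)
loss = -log_probs.mean() - 0.01 * entropy + values.pow(2).mean()  # placeholder loss
loss.backward()
```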

nadavbh12 commented 6 years ago

Interesting, that makes sense, though the current implementation makes it more difficult to implement a recurrent policy. Then again, I guess any way of "batchifying" a recurrent policy won't be pleasant in this case (considering that you need to reset the hidden vectors of different envs at different time steps).

ikostrikov commented 6 years ago

I agree, it's more difficult to implement a recurrent policy, but only marginally so.

I was going to add a recurrent policy in the second half of November anyway, but I can give you all the necessary pointers if you want to do it yourself.

nadavbh12 commented 6 years ago

No need for a recurrent policy at the moment. I was thinking of splitting the matrix after n steps wherever the game ends, then using padded sequences and running both parts. What's your plan?
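
A rough sketch of that padded-sequence idea, assuming a single process whose episode ends partway through the rollout: the rollout is split at the boundary, the segments are padded and packed, and both are run through the RNN from a fresh hidden state. All names are illustrative.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

obs_dim, hidden, num_steps = 8, 64, 10
rollout = torch.randn(num_steps, obs_dim)   # one process's observations
done_at = 6                                 # the episode ended after step 6 (illustrative)

# Split the rollout where the game ends; each segment starts from a fresh hidden state.
segments = [rollout[:done_at], rollout[done_at:]]
lengths = torch.tensor([len(s) for s in segments])
padded = pad_sequence(segments)             # (max_len, num_segments, obs_dim)
packed = pack_padded_sequence(padded, lengths, enforce_sorted=False)

gru = nn.GRU(obs_dim, hidden)
out, h_n = gru(packed)                      # both segments are run from h0 = zeros
```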

ikostrikov commented 6 years ago

The plan was not to reshape them here but to send them to the model unreshaped.

https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/main.py#L146

So the FF model reshapes them to batchify, while an RNN policy will just unroll them. We will also need to "duplicate" the RNN states, but it should be relatively simple: we just need to store the temporary states in the first loop and overwrite them in the second.
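
A minimal sketch of that interface, assuming the rollout arrives unreshaped as `(num_steps, num_processes, obs_dim)`: the feed-forward path folds time into the batch, while the recurrent path unrolls step by step and carries the hidden state. The `Base` class and its arguments are hypothetical, not the repo's code.

```python
import torch
import torch.nn as nn

class Base(nn.Module):
    """Hypothetical policy body that receives the rollout unreshaped."""
    def __init__(self, obs_dim, hidden=64, recurrent=False):
        super().__init__()
        self.recurrent = recurrent
        self.fc = nn.Linear(obs_dim, hidden)
        if recurrent:
            self.gru = nn.GRUCell(hidden, hidden)

    def forward(self, inputs, hxs=None):
        # inputs: (num_steps, num_processes, obs_dim)
        T, N, _ = inputs.shape
        if not self.recurrent:
            # FF policy: fold time into the batch dimension and batchify.
            return torch.tanh(self.fc(inputs.view(T * N, -1))), hxs
        feats = []
        for t in range(T):
            # Recurrent policy: unroll one step at a time, carrying the hidden state.
            hxs = self.gru(torch.tanh(self.fc(inputs[t])), hxs)
            feats.append(hxs)
        return torch.cat(feats), hxs        # (T * N, hidden) and the final hidden state

base = Base(obs_dim=8, recurrent=True)
features, hxs = base(torch.randn(5, 4, 8), torch.zeros(4, 64))
```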

nadavbh12 commented 6 years ago

Though how do you plan to zero the hidden vector at terminal states? As I see it, you can either:

1) Run the RNN step by step in the "num_steps" loop and zero the hiddens only for processes that reached a terminal state (in this option there's no need for the additional forward pass afterwards).
2) Split the mini-batch in two at terminal states and use padded sequences. As in: forward the first part, zero the needed hiddens, forward the second part.

I tend towards the first option for clarity, and because I don't think you'll get the speed gains you saw earlier with a recurrent policy.
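
For reference, a minimal sketch of option 1), assuming a GRU hidden state and a per-process `masks` tensor that is 0 where an episode just ended (all names are illustrative):

```python
import torch
import torch.nn as nn

num_procs, obs_dim, hidden, num_steps = 4, 8, 64, 5
gru = nn.GRUCell(obs_dim, hidden)
hxs = torch.zeros(num_procs, hidden)

for t in range(num_steps):
    obs = torch.randn(num_procs, obs_dim)       # stand-in for env observations
    dones = torch.rand(num_procs) < 0.1         # stand-in for terminal flags
    masks = (~dones).float().unsqueeze(1)       # 0.0 for processes whose episode just ended
    hxs = gru(obs, hxs * masks)                 # zero the hidden state only for those processes
```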

ikostrikov commented 6 years ago

I was going to implement it in the following way.

1) Save the recurrent states before the first loop.
2) Unroll to get rewards in the standard way.
3) In the second forward loop, retrieve the states saved in 1) and unroll again without reshaping the inputs.

This way it will be easier to reuse the same code for PPO, because we can just sample rollouts for different processes in 3).
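
A loop-level sketch of that plan, under the same illustrative names as above: snapshot the recurrent state before the rollout, unroll once without gradients to act, then restore the snapshot and unroll again with gradients for the update pass.

```python
import torch
import torch.nn as nn

num_procs, obs_dim, hidden, num_steps = 4, 8, 64, 5
gru = nn.GRUCell(obs_dim, hidden)
hxs = torch.zeros(num_procs, hidden)

# 1) Save the recurrent states as they were before the first loop.
initial_hxs = hxs.clone()

# 2) Unroll to get rewards in the standard way (acting pass, no gradients needed).
obs_buf, masks_buf = [], []
with torch.no_grad():
    for t in range(num_steps):
        obs = torch.randn(num_procs, obs_dim)   # stand-in for env observations
        masks = torch.ones(num_procs, 1)        # 0.0 wherever an episode ended
        hxs = gru(obs, hxs * masks)
        obs_buf.append(obs)
        masks_buf.append(masks)

# 3) Second forward loop: retrieve the saved states and unroll again,
#    this time with gradients and without reshaping the stored inputs.
hxs_eval = initial_hxs
feats = []
for obs, masks in zip(obs_buf, masks_buf):
    hxs_eval = gru(obs, hxs_eval * masks)
    feats.append(hxs_eval)
features = torch.stack(feats)                   # (num_steps, num_procs, hidden)
```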

ikostrikov commented 6 years ago

I'm going to add an LSTM policy this week. Do you know a good task to test it on? Most papers use FF policies only.

nadavbh12 commented 6 years ago

Papers on ViZDoom and DeepMind Lab usually use LSTMs, though I haven't tried them myself.

Kaixhin commented 6 years ago

If you're looking for domains where RNNs tend to perform better than FF models, Atari games with longer-term dependencies (rather than those which rely on quick reaction times) are an option. I believe Frostbite, Ms. Pac-Man and Seaquest are appropriate. Unfortunately, FF models may sometimes still do better on these games, so it's not a surefire way to check for differences in performance.

ikostrikov commented 6 years ago

Thanks! I will try these games.

At the moment, when I add an LSTM it doesn't even work on the same games.

nadavbh12 commented 6 years ago

Did you implement an LSTM module that masks the terminal states, or did you do something else?

ikostrikov commented 6 years ago

Yes, I can push it to a new branch.

nadavbh12 commented 6 years ago

Cool. I'll give it a look when it's up.

ikostrikov commented 6 years ago

It's in the lstm branch: https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/tree/lstm

You can run it with

python main.py --env-name "PongNoFrameskip-v4" --recurrent-policy

ikostrikov commented 6 years ago

After fixing the initialization, the recurrent policy seems to work for a2c (working on ppo as well).

See the LSTM branch.

nadavbh12 commented 6 years ago

Very nice! What was the problem with the orthogonal initialization? I can't see a diff from the PyTorch version. Was it the rows < cols case?

ikostrikov commented 6 years ago

Yes, it was the rows < cols case. It was fixed upstream in September, but the fix is still not in the version installed with conda.
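
For context, a hedged sketch of the standard workaround for that case: if the weight matrix has fewer rows than columns, draw the random matrix transposed, take the QR decomposition, and transpose back so the result still has orthonormal rows. The re-implementation below is purely illustrative; current `torch.nn.init.orthogonal_` already behaves this way.

```python
import torch

def orthogonal_(tensor, gain=1.0):
    # Illustrative re-implementation of orthogonal init handling rows < cols.
    rows = tensor.size(0)
    cols = tensor.numel() // rows
    flat = torch.randn(rows, cols)
    if rows < cols:
        flat.t_()                      # work on the transpose so we get orthonormal rows back
    q, r = torch.linalg.qr(flat)
    q *= torch.sign(torch.diagonal(r)) # make the decomposition unique
    if rows < cols:
        q.t_()
    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor

w = torch.empty(3, 7)                  # rows < cols case
orthogonal_(w)
```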

nadavbh12 commented 6 years ago

Good thing I sent a pull request :-)

ethancaballero commented 6 years ago

Any timeline on merging the lstm branch into the master branch?

ikostrikov commented 6 years ago

OpenAI recently released their LSTM PPO. I'm going to derive my code from their implementation now; it will probably take a week.

ikostrikov commented 6 years ago

A recurrent policy is in the main branch now.