Khrylx / PyTorch-RL

PyTorch implementation of Deep Reinforcement Learning: Policy Gradient methods (TRPO, PPO, A2C) and Generative Adversarial Imitation Learning (GAIL). Fast Fisher vector product TRPO.

Training a recurrent policy #4

Open erschmidt opened 6 years ago

erschmidt commented 6 years ago

I am still struggling with the implementation of a recurrent policy. The trick from #1 worked and I can now start running my RNN GAIL network. But no matter what I try, the mean reward actually decreases over time.

I am currently using the same ValueNet and Advantage Estimation as in the repository.

Do I have to change something in trpo_step to make RNN policies work?

Thank you so much!

Khrylx commented 6 years ago

The value function also needs to be an RNN, or you can pass the RNN output from the policy to the value net.
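
A minimal sketch of the second option (not code from this repo; the class and argument names are hypothetical), where the policy and value heads share one recurrent encoder, so the value function sees the same recurrent features as the policy:

```python
import torch
import torch.nn as nn

class RecurrentActorCritic(nn.Module):
    """Hypothetical sketch: one GRU encoder shared by policy and value heads."""
    def __init__(self, obs_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.rnn = nn.GRU(obs_dim, hidden_dim)           # seq-first input (T, B, obs_dim)
        self.policy_head = nn.Linear(hidden_dim, action_dim)  # action mean
        self.value_head = nn.Linear(hidden_dim, 1)             # state value
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def forward(self, obs_seq, h0=None):
        # obs_seq: (T, B, obs_dim); h0: optional initial hidden state
        feat, hT = self.rnn(obs_seq, h0)
        action_mean = self.policy_head(feat)
        value = self.value_head(feat).squeeze(-1)
        return action_mean, self.log_std.exp(), value, hT
```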

sandeepnRES commented 5 years ago

Any help regarding how to use an RNN here? Should it be applied step by step, or over a whole episode (treating the time steps as a sequence)? And how will backpropagation take place, all at once at the end of the episode? Do you have any code for this available?

Khrylx commented 5 years ago

I don't have any code for an RNN yet, but I can outline how it can be done.

Suppose the agent collected a batch of three episodes of lengths 5, 6, and 7, so the total length is 18. You need to pad these episodes to the same length, 7. You will then have an input of size 7 x 3 x d, where d is your input dimension. You pass it through an LSTM, which gives you an output of size 7 x 3 x h, where h is the output dimension. You reshape it into 21 x h and select the real episode steps (5, 6, 7) with an indexing operation, which gives you the correct output of 18 x h. You can then pass this to an MLP or any other network you want. All of these operations are achievable in PyTorch.
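
A minimal PyTorch sketch of that batching scheme (the names `obs_dim` and `hidden_dim` are placeholders, not code from this repo):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

obs_dim, hidden_dim = 10, 64                       # placeholder sizes

# Three episodes of lengths 5, 6, 7 (total 18 steps).
episodes = [torch.randn(T, obs_dim) for T in (5, 6, 7)]
lengths = torch.tensor([len(e) for e in episodes])

padded = pad_sequence(episodes)                    # (7, 3, obs_dim), zero-padded
lstm = nn.LSTM(obs_dim, hidden_dim)
out, _ = lstm(padded)                              # (7, 3, hidden_dim)

# Mask that is True only at real (non-padded) time steps.
max_len = padded.shape[0]
mask = torch.arange(max_len).unsqueeze(1) < lengths.unsqueeze(0)   # (7, 3)

flat = out.reshape(-1, hidden_dim)                 # (21, hidden_dim)
valid = flat[mask.reshape(-1)]                     # (18, hidden_dim)
# `valid` can now be fed to an MLP head; note the rows are ordered
# time-major (step 0 of all episodes, then step 1, ...).
```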

sandeepnRES commented 5 years ago

Okay, but should functions like get_log_probability, which is used in the PPO step, be updated? They compute a forward pass, so a new hidden vector is produced after every backpropagation. Or should the new hidden vector be ignored and the old hidden vector be used (the one obtained during data collection by the agent)?