rlworkgroup / garage

A toolkit for reproducible reinforcement learning research.
MIT License

Pytorch Categorical GRU Policy #2196

Open ManavR123 opened 3 years ago

ManavR123 commented 3 years ago

In this PR, I add a PyTorch version of the Categorical GRU Policy. I believe there is some in-progress work on adding RNN support in #2172; however, that PR has been in draft for a while, and I personally needed this policy, so I figured I would try to contribute it. One thing still missing is an example using this policy, but its usage isn't too different from that of other similar policies.
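Roughly, the idea is a GRU encoder whose outputs parameterize a categorical action distribution. A minimal sketch in plain PyTorch of what such a policy looks like (illustrative class and parameter names only, not the actual garage API):

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical


class MinimalCategoricalGRUPolicy(nn.Module):
    """Toy recurrent policy: a GRU encoder followed by a categorical head."""

    def __init__(self, obs_dim, action_dim, hidden_dim=32):
        super().__init__()
        self._gru = nn.GRU(obs_dim, hidden_dim, batch_first=True)
        self._head = nn.Linear(hidden_dim, action_dim)

    def forward(self, observations, hidden=None):
        # observations: (batch, time, obs_dim)
        out, hidden = self._gru(observations, hidden)
        # One categorical distribution per (batch, time) step.
        return Categorical(logits=self._head(out)), hidden


# Sampling an action while carrying the hidden state between env steps.
policy = MinimalCategoricalGRUPolicy(obs_dim=4, action_dim=2)
dist, h = policy(torch.randn(1, 1, 4))  # one env, one timestep
action = dist.sample()                  # shape (1, 1)
```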

I would appreciate any feedback on this!

@krzentner @avnishn

avnishn commented 3 years ago

Oh wait, one more thing, @ManavR123.

You'll want to include a test similar to:

tests/garage/tf/models/test_categorical_gru_model.py

The test titled "test normalized outputs" checks, via a golden-values comparison, that your module produces the correct values.

This will ensure correctness.

You can obtain these golden values by first using your module/policy in some benchmark to confirm that it works. Then run the policy and module on a simplified environment (such as a grid-world env), capture the weights of that policy/module, and write a test that compares the values of the trained policy to those golden values after a few optimization iterations.
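As a rough sketch of the golden-values pattern (reusing the illustrative MinimalCategoricalGRUPolicy above; the file name, seed, and tolerance are placeholders, not the actual garage test):

```python
import numpy as np
import torch


def capture_golden_values(path="golden_probs.npy"):
    """Run the trusted policy once and record its output probabilities."""
    torch.manual_seed(0)
    policy = MinimalCategoricalGRUPolicy(obs_dim=4, action_dim=2)
    dist, _ = policy(torch.zeros(1, 1, 4))
    np.save(path, dist.probs.detach().numpy())


def test_matches_golden_values(path="golden_probs.npy"):
    """Rebuild the policy under the same seed and compare to the stored values."""
    torch.manual_seed(0)
    policy = MinimalCategoricalGRUPolicy(obs_dim=4, action_dim=2)
    dist, _ = policy(torch.zeros(1, 1, 4))
    np.testing.assert_allclose(dist.probs.detach().numpy(),
                               np.load(path), rtol=1e-6)
```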

Does that make sense? Please ask me anything if that wasn't clear!

Thanks @avnishn

krzentner commented 3 years ago

Wow, this PR looks good. The main feedback I have is that we don't actually yet have any algorithms implemented in PyTorch that can train an RNN. Unfortunately, our VPG (and PPO and TRPO) shuffle all timesteps individually in their batching step, which prevents training an RNN correctly. I'm currently working on fixing that, but I probably won't have anything usable before sometime next week.
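To illustrate the distinction, here is a trajectory-oriented batching sketch that keeps each episode's timesteps in order so the GRU hidden state can be unrolled over the whole trajectory (an illustrative helper, not garage's actual optimizer code; a real loss would also need to mask the padded steps):

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def make_trajectory_batch(flat_obs, lengths):
    """Regroup a flat (total_timesteps, obs_dim) tensor into a padded
    (n_episodes, max_len, obs_dim) batch, keeping each episode's
    timesteps in order instead of shuffling them individually."""
    episodes = torch.split(flat_obs, lengths)  # one tensor per episode
    return pad_sequence(episodes, batch_first=True)


# Two episodes of lengths 3 and 4, flattened as the sampler returns them.
flat_obs = torch.randn(7, 4)
padded = make_trajectory_batch(flat_obs, [3, 4])  # -> (2, 4, 4)
dist, _ = MinimalCategoricalGRUPolicy(obs_dim=4, action_dim=2)(padded)
```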

The more minor feedback I have is that I think we should streamline the garage/torch/modules directory to not include modules that output distributions whenever possible, since they are just an extra layer to reason about when working with policies (which themselves have a superset of the equivalent module's API).

ryanjulian commented 3 years ago

@krzentner I'm curious: what is your implementation plan for RNNs in torch/VPG? Is it to make the optimization trajectory-oriented (as in tf/VPG) or something else?

nikhilxb commented 3 years ago

> Wow, this PR looks good. The main feedback I have is that we don't actually yet have any algorithms implemented in PyTorch that can train an RNN. Unfortunately, our VPG (and PPO and TRPO) shuffle all timesteps individually in their batching step, which prevents training an RNN correctly. I'm currently working on fixing that, but I probably won't have anything usable before sometime next week.
>
> The more minor feedback I have is that I think we should streamline the garage/torch/modules directory to not include modules that output distributions whenever possible, since they are just an extra layer to reason about when working with policies (which themselves have a superset of the equivalent module's API).

@krzentner Did this end up getting done? I'm also trying to implement an RNN policy in PyTorch but can't seem to find any working examples.