rlworkgroup / garage

A toolkit for reproducible reinforcement learning research.
MIT License

Rework garage.torch.optimizers #2177

Open krzentner opened 3 years ago

krzentner commented 3 years ago

This change does not yet pass tests, but is 90% complete.

ryanjulian commented 3 years ago

Can you add a little bit more explanation for the design here? I'm concerned about using an ADT as the blanket input to policies, which makes the interface pretty complicated even in the simplest use cases.

krzentner commented 3 years ago

The core motivation here is to provide a way for recurrent and non-recurrent policies to share the same API at optimization time. However, I definitely agree: making this change has shown me that it significantly increases the complexity of the garage.torch APIs. It doesn't add much complexity to any individual algorithm (except for TutorialVPG), but it's noticeable. In the future, I also intended for this datatype to play a role similar to state_info_spec in the TF branch (although with a very different design).
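For illustration only, here is a minimal sketch (not the PR's actual code) of what a shared input type for recurrent and non-recurrent policies could look like. The names `PolicyInputSketch`, `observations`, and `lengths` are assumptions made for this example, not identifiers from this PR:

```python
# Hypothetical sketch of a shared policy-input datatype; field names are
# assumptions for illustration, not the PolicyInput defined in this PR.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class PolicyInputSketch:
    # Flat batch of observations with shape (N, *obs_shape), usable directly
    # by non-recurrent policies.
    observations: torch.Tensor
    # Per-episode lengths summing to N. Only recurrent policies need this,
    # to know where to reset/unroll their hidden state.
    lengths: Optional[torch.Tensor] = None
```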

This PR only adds the bare minimum fields needed for recurrent policies to have reasonable .forward methods. However, we could drop the observation field on PolicyInput by instead having PolicyInput inherit from torch.Tensor. Then algorithms that only want to train stochastic non-recurrent policies (e.g. SAC) could just pass a torch.Tensor, as they do now. Alternatively, we could use a helper function at the start of every torch policy's .forward method to convert any torch.Tensor input into a PolicyInput (in SHUFFLED mode).
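A rough sketch of that second alternative, assuming the `PolicyInputSketch` type from the earlier example; `as_policy_input` is a hypothetical helper name, and "no episode structure" stands in for the SHUFFLED mode mentioned above:

```python
# Hypothetical helper for the "convert at the top of .forward" alternative.
# Not the PR's API; names and semantics are assumptions for illustration.
def as_policy_input(x):
    """Wrap a bare torch.Tensor so that non-recurrent callers (e.g. SAC) can
    keep passing plain tensors, while the policy body always sees one type."""
    if isinstance(x, PolicyInputSketch):
        return x
    # A bare tensor carries no episode boundaries, which corresponds to the
    # shuffled, time-step-independent case described above.
    return PolicyInputSketch(observations=x, lengths=None)
```

With this, a policy's .forward could start with `x = as_policy_input(x)` and then branch on whether `lengths` is present, keeping the plain-tensor call path unchanged for non-recurrent algorithms.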