krzentner opened 3 years ago
Can you add a little bit more explanation for the design here? I'm concerned about using an ADT as the blanket input to policies, which makes the interface pretty complicated even in the simplest use cases.
The core motivation here is to provide a way for recurrent and non-recurrent policies to share the same API at optimization time.
However, I definitely agree: making this change has shown me that it significantly increases the complexity of the garage.torch APIs. It doesn't result in a significant increase in complexity in any single algorithm (except for `TutorialVPG`), but it's noticeable.
In the future, I also intend for this datatype to play a role similar to `state_info_spec` in the TF branch (although with a very different design).
This PR only adds the bare minimum fields needed for recurrent policies to have reasonable `.forward` methods. However, we could replace the `observation` field on `PolicyInput` by instead having `PolicyInput` inherit from `torch.Tensor`.
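For concreteness, here's a minimal sketch of what that subclass variant might look like (the `lengths` field and the constructor shape are illustrative, not the actual garage API):

```python
import torch


# Hypothetical sketch: PolicyInput *is* an observation Tensor, with
# optional recurrent metadata riding along as attributes.
class PolicyInput(torch.Tensor):

    @staticmethod
    def __new__(cls, observations, lengths=None):
        # as_subclass re-tags the existing tensor with this class, so
        # autograd and all normal Tensor ops keep working unchanged.
        obj = observations.as_subclass(cls)
        # Episode lengths, used only by recurrent policies.
        obj.lengths = lengths
        return obj


# A non-recurrent module never needs to know the difference:
obs = torch.randn(32, 4)
policy_input = PolicyInput(obs, lengths=torch.tensor([16, 16]))
out = torch.nn.Linear(4, 8)(policy_input)  # behaves like a plain Tensor
```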
Then, algorithms that only want to train stochastic non-recurrent policies (e.g. SAC) could just pass a `torch.Tensor` (as they do now). Alternatively, we could use a helper function at the start of every torch policy's `.forward` method to convert any `torch.Tensor` input into a `PolicyInput` (in `SHUFFLED` mode), as sketched below.
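Here's a rough sketch of that helper-function alternative (the `Mode` enum and the `as_policy_input` name are made up for illustration; only `SHUFFLED` comes from the design above):

```python
import enum

import torch


class Mode(enum.Enum):
    # Independent time steps with no sequence structure.
    SHUFFLED = 'shuffled'


class PolicyInput:

    def __init__(self, observations, mode, lengths=None):
        self.observations = observations
        self.mode = mode
        self.lengths = lengths  # only meaningful for sequence modes


def as_policy_input(value):
    """Coerce a raw observation Tensor into a SHUFFLED PolicyInput."""
    if isinstance(value, PolicyInput):
        return value
    assert isinstance(value, torch.Tensor)
    return PolicyInput(observations=value, mode=Mode.SHUFFLED)


# Each policy's .forward would then start with:
#     def forward(self, policy_input):
#         policy_input = as_policy_input(policy_input)
#         obs = policy_input.observations
#         ...
```

With this approach, call sites in algorithms like SAC stay exactly as they are today; only the one-line coercion at the top of each `.forward` changes.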
This change does not yet pass tests, but is 90% complete.