axelbr closed this issue 3 years ago.
Hi @axelbr. It looks like the current documentation is a bit confusing here. In fact, the model's `forward` method can take any number of input arguments. What must be defined to use the `ModelTrainer` is the `loss(model_input, optimizer, target=None)` method. This method is called by the default `update` method, which in turn is called by `ModelTrainer` with a `TransitionBatch` as model input (note that it passes `target=None`). There are two cases here:
1. You want the flexibility to use your model class with arbitrary tensors as inputs, without needing a transition batch every time. In this case, you can have a separate wrapper model that converts the transition batch to tensors, creates the target, and then passes the correct tensor inputs and target to the model. This is the case for `GaussianMLP` (which only works with tensors) and its wrapper `OneDTransitionRewardModel`. This wrapper converts a `TransitionBatch` into `torch.cat([obs, action])`, along with some other transformations, and also creates the target from `batch.next_obs` and `batch.reward`.
2. You are fine dealing with `TransitionBatch` directly. In this case, the model itself takes care of all the data manipulation and can then pass any tensors to `forward`. This is the case for the PlaNet implementation we are currently working on; note that its `forward` method is defined as `forward(obs: torch.Tensor, action: torch.Tensor)`.
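The two cases above can be sketched in plain PyTorch. This is a minimal, self-contained sketch, not mbrl-lib code, and the exact mbrl-lib signatures may differ: `SimpleBatch` is a hypothetical stand-in for `TransitionBatch`, and `TensorMLP`, `BatchWrapper`, `BatchAwareModel`, `make_target` and `update` are hypothetical names that only mimic the roles of `GaussianMLP`, `OneDTransitionRewardModel`, a PlaNet-style model and the default update method. The target is simply `next_obs` concatenated with `reward`, without the other transformations the real wrapper applies.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class SimpleBatch:
    """Hypothetical stand-in for TransitionBatch, holding only the fields used here."""
    obs: torch.Tensor       # (n, obs_size)
    action: torch.Tensor    # (n, action_size)
    next_obs: torch.Tensor  # (n, obs_size)
    reward: torch.Tensor    # (n, 1)


def make_target(batch: SimpleBatch) -> torch.Tensor:
    # Hypothetical helper: target is next_obs and reward side by side, nothing more.
    return torch.cat([batch.next_obs, batch.reward], dim=-1)


# Case 1: a tensor-only model plus a wrapper that handles the batch.
class TensorMLP(nn.Module):
    def __init__(self, in_size: int, out_size: int, hid_size: int = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_size, hid_size), nn.ReLU(), nn.Linear(hid_size, out_size)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class BatchWrapper(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model  # registered submodule, so its parameters get trained too

    def _process_batch(self, batch: SimpleBatch) -> Tuple[torch.Tensor, torch.Tensor]:
        model_in = torch.cat([batch.obs, batch.action], dim=-1)
        return model_in, make_target(batch)

    def loss(self, batch: SimpleBatch, target: Optional[torch.Tensor] = None) -> torch.Tensor:
        model_in, batch_target = self._process_batch(batch)
        target = batch_target if target is None else target
        return F.mse_loss(self.model(model_in), target)


# Case 2: the model consumes the batch itself; forward takes two tensors.
class BatchAwareModel(nn.Module):
    def __init__(self, obs_size: int, action_size: int, hid_size: int = 32):
        super().__init__()
        self.obs_enc = nn.Linear(obs_size, hid_size)
        self.act_enc = nn.Linear(action_size, hid_size)
        self.head = nn.Linear(hid_size, obs_size + 1)  # predicts next_obs and reward

    def forward(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        return self.head(torch.relu(self.obs_enc(obs) + self.act_enc(action)))

    def loss(self, batch: SimpleBatch, target: Optional[torch.Tensor] = None) -> torch.Tensor:
        target = make_target(batch) if target is None else target
        return F.mse_loss(self(batch.obs, batch.action), target)


def update(model: nn.Module, batch: SimpleBatch, optimizer: torch.optim.Optimizer) -> float:
    """Hypothetical training step: like the trainer, it passes no explicit target."""
    model.train()
    optimizer.zero_grad()
    loss = model.loss(batch)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
obs_size, action_size, n = 4, 2, 8
batch = SimpleBatch(
    obs=torch.randn(n, obs_size),
    action=torch.randn(n, action_size),
    next_obs=torch.randn(n, obs_size),
    reward=torch.randn(n, 1),
)
wrapped = BatchWrapper(TensorMLP(obs_size + action_size, obs_size + 1))
direct = BatchAwareModel(obs_size, action_size)
for model in (wrapped, direct):
    opt = torch.optim.Adam(model.parameters(), lr=1e-2)
    print(type(model).__name__, update(model, batch, opt))
```

Both models are trained through the same `loss(batch, target=None)` entry point, so the training step does not care whether `forward` takes one tensor or two; that choice stays internal to each model, and the wrapped `TensorMLP` can still be called directly on plain tensors.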
Does this make sense? Let me know if this is still confusing; I'm happy to clarify further. BTW, if you are working with 1-D observations and actions, I suggest taking a look at `OneDTransitionRewardModel` for the batch manipulation and letting your model class focus on the modeling part, dealing only with torch tensors.
Thanks for the extensive answer; it is clearer now. I had assumed that implementations had to follow the interface more strictly. I will close this issue now, and if further questions arise, I will post them here.
Thanks for maintaining this codebase!
I am not sure how to use the `forward` function of the `Model` interface in my subclass, since the interface only accepts a single input argument. If I want to have a control input, should I 1) concatenate the obs and action tensors, 2) use the `TransitionBatch` class, or 3) introduce another function argument with a default value (such as `action=None`)? Which alternative would you prefer?
Thanks!