facebookresearch / mbrl-lib

Library for Model Based RL
MIT License

[Question] Model Interface: forward arguments #101

Closed: axelbr closed this issue 3 years ago

axelbr commented 3 years ago

I am not sure how to use the forward function of the Model interface in my subclass, as the interface only accepts a single input argument. If I want to have a control input, should I 1) concatenate the obs and action tensors, 2) use the TransitionBatch class, or 3) introduce another function argument with a default value (such as action=None)? Which alternative would you prefer?

Thanks!

axelbr commented 3 years ago
  1. model.forward(x=torch.cat([obs, action], dim=-1))
  2. model.forward(x=batch)
  3. def forward(self, x: TensorType, u: Optional[TensorType] = None): ...

luisenp commented 3 years ago

Hi @axelbr. It looks like this is a place where the current documentation is a bit confusing. In fact, for the model forward method, you can actually use any number of input arguments. What needs to be defined to use the ModelTrainer is the loss(model_input, optimizer, target=None) method. This method is called by the default update method, which is in turn called by ModelTrainer using a TransitionBatch as model input (and note that it passes target=None). Here, there are two cases:

  1. You want the flexibility to use your model class with arbitrary tensors as inputs, without needing a transition batch every time. In this case, you can have a separate wrapper model that converts the transition batch to tensors, creates the target, and then passes the correct tensor inputs and target to the model. This is the case for GaussianMLP (which only works with tensors) and its wrapper, OneDTransitionRewardModel. This wrapper converts a TransitionBatch into torch.cat([obs, action]) along with some other transformations, and also creates the target from batch.next_obs and batch.reward.

  2. You are fine dealing with TransitionBatch directly. In this case, the model itself takes care of all the data manipulation and can then pass any tensors to forward. This is the case for the PlaNet implementation we are currently working on. Note that its forward method is defined as forward(obs: torch.Tensor, action: torch.Tensor).
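
To make the call chain concrete, here is a minimal pure-Python sketch of the second case. It assumes nothing about mbrl-lib's actual classes: `Batch`, `LossRecorder`, `BatchAwareModel`, and `train_step` are hypothetical names, plain lists stand in for torch tensors, and `LossRecorder.observe` stands in for the backward pass and optimizer step. The trainer passes the whole batch with `target=None`, `update` forwards it to `loss`, and `loss` builds the target itself before calling a two-argument `forward`.

```python
from dataclasses import dataclass
from typing import List, Optional

Row = List[float]


@dataclass
class Batch:
    """Hypothetical stand-in for a TransitionBatch (plain lists, not tensors)."""
    obs: List[Row]
    act: List[Row]
    next_obs: List[Row]
    rewards: List[float]


class LossRecorder:
    """Hypothetical optimizer stand-in: records each loss it is given.
    A real PyTorch model would call loss.backward() and optimizer.step()."""

    def __init__(self) -> None:
        self.seen: List[float] = []

    def observe(self, loss_value: float) -> None:
        self.seen.append(loss_value)


class BatchAwareModel:
    """Toy model for case 2: it unpacks the batch itself, so forward()
    is free to take observations and actions as separate arguments."""

    def __init__(self, action_gain: float) -> None:
        self.action_gain = action_gain

    def forward(self, obs: List[Row], action: List[Row]) -> List[Row]:
        # Toy dynamics: next_obs = obs + action_gain * action, elementwise.
        return [[o + self.action_gain * a for o, a in zip(o_row, a_row)]
                for o_row, a_row in zip(obs, action)]

    def loss(self, batch: Batch, target: Optional[List[Row]] = None) -> float:
        # The trainer passes target=None, so the model derives the target
        # from the batch; in this toy it is simply next_obs.
        if target is None:
            target = batch.next_obs
        pred = self.forward(batch.obs, batch.act)
        errs = [(p - t) ** 2 for p_row, t_row in zip(pred, target)
                for p, t in zip(p_row, t_row)]
        return sum(errs) / len(errs)  # mean squared error

    def update(self, batch: Batch, optimizer: LossRecorder,
               target: Optional[List[Row]] = None) -> float:
        # Default-style update: compute the loss and hand it to the optimizer.
        value = self.loss(batch, target=target)
        optimizer.observe(value)
        return value


def train_step(model: BatchAwareModel, batch: Batch, optimizer: LossRecorder) -> float:
    """Hypothetical trainer step: passes the whole batch and target=None."""
    return model.update(batch, optimizer, target=None)


batch = Batch(obs=[[0.0, 1.0]], act=[[1.0, 1.0]], next_obs=[[0.5, 1.5]], rewards=[0.0])
opt = LossRecorder()
print(train_step(BatchAwareModel(action_gain=1.0), batch, opt))  # 0.25
```

Because only `loss` has to accept the batch, the signature of `forward` is left entirely to the model.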

Does this make sense? Let me know if this is still confusing; I'm happy to clarify further. BTW, if you are working with 1-D observations and actions, I suggest taking a look at OneDTransitionRewardModel for batch manipulation and letting your model class focus on the modeling part, dealing only with torch tensors.
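
The first case, and the suggestion to keep batch handling out of the model, can be sketched the same way. This is a hypothetical, pure-Python illustration, not the code of `GaussianMLP` or `OneDTransitionRewardModel`: `LinearCore` plays the role of a tensor-only model with a single-argument `forward`, while `BatchWrapper` concatenates observations and actions row by row (the analogue of `torch.cat([obs, action], dim=-1)`) and builds the target from `next_obs` and the reward. The real wrapper applies further transformations that are not reproduced here.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

Row = List[float]


@dataclass
class Batch:
    """Hypothetical stand-in for a TransitionBatch (plain lists, not tensors)."""
    obs: List[Row]
    act: List[Row]
    next_obs: List[Row]
    rewards: List[float]


class LinearCore:
    """Hypothetical tensor-only model: forward(x) knows nothing about batches."""

    def __init__(self, weights: List[Row]) -> None:
        self.weights = weights  # one weight row per output dimension

    def forward(self, x: List[Row]) -> List[Row]:
        # Linear map applied to every input row.
        return [[sum(w * xi for w, xi in zip(w_row, row)) for w_row in self.weights]
                for row in x]

    def loss(self, x: List[Row], target: List[Row]) -> float:
        pred = self.forward(x)
        errs = [(p - t) ** 2 for p_row, t_row in zip(pred, target)
                for p, t in zip(p_row, t_row)]
        return sum(errs) / len(errs)  # mean squared error


class BatchWrapper:
    """Hypothetical wrapper: turns a batch into (model_in, target) and
    delegates the actual modeling to the tensor-only core."""

    def __init__(self, core: LinearCore) -> None:
        self.core = core

    @staticmethod
    def process_batch(batch: Batch) -> Tuple[List[Row], List[Row]]:
        # Row-wise analogue of torch.cat([obs, action], dim=-1).
        model_in = [o + a for o, a in zip(batch.obs, batch.act)]
        # Target: next observation with the reward appended (no other transforms).
        target = [n + [r] for n, r in zip(batch.next_obs, batch.rewards)]
        return model_in, target

    def loss(self, batch: Batch, target: Optional[List[Row]] = None) -> float:
        model_in, built_target = self.process_batch(batch)
        if target is None:  # the trainer passes target=None
            target = built_target
        return self.core.loss(model_in, target)


core = LinearCore([[1.0, 1.0], [0.0, 0.25]])
batch = Batch(obs=[[1.0]], act=[[2.0]], next_obs=[[3.0]], rewards=[0.5])
print(BatchWrapper(core).loss(batch))  # 0.0
```

With this split, the core model can be called directly on any input rows, without constructing a batch first.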

axelbr commented 3 years ago

Thanks for the extensive answer; it is clearer now. I had assumed that implementations have to follow the interface more strictly. I will close this issue now. If further questions arise, I will post them here.

Thanks for maintaining this codebase!