facebookresearch / mbrl-lib

Library for Model Based RL
MIT License

[Bug] Issue with loading best weights in ModelTrainer #98

Closed · ethanluoyc closed this issue 3 years ago

ethanluoyc commented 3 years ago

Hi, thanks for open-sourcing this library! It has really helped clarify some of the MBRL algorithms out there.

I have a specific question regarding this line:

https://github.com/facebookresearch/mbrl-lib/blob/master/mbrl/models/model_trainer.py#L228

Maybe I am missing something, but I want to confirm whether this line actually loads the best weights according to the validation loss, as intended. I am less familiar with PyTorch (I mostly use JAX nowadays), but I was under the impression that to load model weights one calls module.load_state_dict(state_dict), not module.state_dict(state_dict). I checked the PyTorch documentation, and it seems the latter copies the module's current weights *into* state_dict rather than loading them. See

https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.state_dict
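
To illustrate what I mean, here is a minimal sketch with a toy module (not the actual mbrl-lib model; the names are just for the example):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
best_weights = {k: v.clone() for k, v in model.state_dict().items()}

with torch.no_grad():
    model.weight.add_(1.0)  # stand-in for parameter updates from further training

# Buggy pattern: state_dict() *exports* the module's current parameters and
# leaves the module itself untouched, so calling it cannot restore anything.
exported = model.state_dict()
assert not torch.equal(exported["weight"], best_weights["weight"])

# Correct pattern: load_state_dict() copies the saved tensors back into the module.
model.load_state_dict(best_weights)
assert torch.equal(model.weight, best_weights["weight"])
```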

If this is indeed a bug, I would be very interested in seeing updated results for the PETS implementation. I don't think it makes much of a difference for environments like CartPole or Pendulum, but it might for tasks such as HalfCheetah.

*Sorry for not following the issue template. I feel this is a minor issue and the full template does not apply.

luisenp commented 3 years ago

Wow, you are indeed right, this is a bug! I'll fix it and re-run some experiments. Thanks a lot for catching and reporting this, and no worries about the template :)

ethanluoyc commented 3 years ago

There are a few other things I have found that are different between the PyTorch version and the original TF version. I will create another issue.

luisenp commented 3 years ago

Hi @ethanluoyc. Sorry for taking so long to get back to you on this; the last few weeks have been hectic with other deadlines. I fixed this on a local branch and tried HalfCheetah-v2 with MBPO, but unfortunately didn't see much of a difference (a slight improvement in the average over 10 seeds). I'll merge the fix anyway, since this is clearly a bug, but I want to try PETS before doing so; will let you know how it goes. I'll also start taking a look at #99 this week. Thanks again for reporting!

ethanluoyc commented 3 years ago

Thanks for letting me know! I think it would be interesting to see the results on PETS. The original implementation does not have cross-validation enabled, but it would still be worth seeing what happens.

luisenp commented 3 years ago

It actually has an impact even without validation, since in that case the trainer saves the weights with the best training score. With PETS the situation is a bit unclear: with 10 seeds on HalfCheetah (using the original data transforms), more seeds reach ~15k reward than before, but other seeds actually end up performing worse (the model overfitting to the training data would be my guess).
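
For reference, the pattern under discussion is roughly the following; this is a sketch with hypothetical helper names, not the actual ModelTrainer code:

```python
import copy

def train_and_keep_best(model, train_epoch, evaluate, num_epochs):
    """Sketch of a keep-best-weights loop (hypothetical helpers, not ModelTrainer)."""
    best_score = float("inf")
    best_weights = None
    for _ in range(num_epochs):
        train_epoch(model)
        # With no validation set, `evaluate` can fall back to the training
        # score; the restore below then keeps the best *training* weights,
        # which is why the fix matters even without cross-validation.
        score = evaluate(model)
        if score < best_score:
            best_score = score
            best_weights = copy.deepcopy(model.state_dict())
    if best_weights is not None:
        model.load_state_dict(best_weights)  # the fix: load_state_dict, not state_dict
    return best_score
```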

In any case, I'll add the fix to master now, since it's not obviously breaking anything.

ethanluoyc commented 3 years ago

These are actually very interesting observations. Yeah, fixing it in master now sounds like a good plan!