Wow, you are indeed right, this is a bug! I'll fix that and re-run some experiments. Thanks a lot for catching this and reporting, and no worries about the template :)
There are a few other things I have found that are different between the PyTorch and the original TF versions. I will create another issue.
Hi @ethanluoyc. Sorry for taking so long to get back to you on this, the last few weeks have been hectic for me with other deadlines. I fixed this on a local branch and tried on HalfCheetah-v2 using MBPO but unfortunately didn't see much of a difference (slight improvement on the average over 10 seeds). I'll merge the fix anyway as this is clearly a bug, but I want to try on PETS before doing so; will let you know how it goes. I'll also start taking a look at #99 this week. Thanks again for reporting!
Thanks for letting me know! It would be interesting to see the results on PETS; the original implementation does not have cross-validation enabled, so I am curious what happens there.
It actually has an impact even without validation, since in that case it saves the weights from the best training score. With PETS the situation is a bit unclear: with 10 seeds on HalfCheetah (using the original data transforms), more seeds reach ~15k reward than before, but other seeds actually end up performing worse (my guess would be the model overfitting to the training data).
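To make the mechanism concrete, this is roughly the pattern involved (a minimal sketch, not the actual mbrl-lib trainer; `model.loss` and the loop structure are placeholders): the trainer snapshots the weights whenever the score improves and restores the best snapshot at the end, so a broken restore silently keeps the final-epoch weights instead.

```python
import copy


def train_keeping_best_weights(model, optimizer, batches, num_epochs):
    """Track the best epoch loss and restore the weights that achieved it."""
    best_score = float("inf")
    best_weights = None
    for _ in range(num_epochs):
        epoch_loss = 0.0
        for batch in batches:
            optimizer.zero_grad()
            loss = model.loss(batch)  # hypothetical per-batch loss method
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        if epoch_loss < best_score:
            best_score = epoch_loss
            # Deep copy so subsequent gradient updates don't mutate the snapshot.
            best_weights = copy.deepcopy(model.state_dict())
    if best_weights is not None:
        # This is where the bug lived: calling model.state_dict(best_weights)
        # here would export the current weights instead of restoring the best.
        model.load_state_dict(best_weights)
    return best_score
```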
In any case, I'll add the fix to master now, since it's not obviously breaking anything.
These are actually very interesting observations. Yeah fixing it in master now sounds like a good plan!
Hi, thanks for open-sourcing this library! It has really helped clarify some of the MBRL algorithms out there.
I have a specific question regarding this line
https://github.com/facebookresearch/mbrl-lib/blob/master/mbrl/models/model_trainer.py#L228
Maybe I am missing something, but I just want to check whether this line is actually loading the best weights according to the validation loss, as expected. I am less familiar with PyTorch (I mostly use JAX nowadays), but I was under the impression that to load the model weights one calls `module.load_state_dict(state_dict)` instead of `module.state_dict(state_dict)`. I checked the PyTorch documentation, and it seems that the second call writes the module's current weights into `state_dict` instead of loading from it. See https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.state_dict
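To make the distinction concrete, here is a small self-contained example (my own sketch, not the library's code) showing that the `state_dict(destination=...)` form exports weights and leaves the module untouched, while `load_state_dict` actually restores them:

```python
import copy

import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# Snapshot the parameters we want to restore later.
best_weights = copy.deepcopy(model.state_dict())

# Simulate training drifting the parameters away from the snapshot.
with torch.no_grad():
    model.weight.add_(1.0)

# Buggy pattern: state_dict(destination=...) *exports* the module's current
# parameters into the given mapping and returns it; the module is untouched.
exported = model.state_dict(destination=copy.deepcopy(best_weights))
assert not torch.equal(model.weight, best_weights["weight"])  # nothing restored
assert torch.equal(exported["weight"], model.weight)  # dict got overwritten

# Correct pattern: load_state_dict copies the snapshot back into the module.
model.load_state_dict(best_weights)
assert torch.equal(model.weight, best_weights["weight"])
```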
If this is indeed a bug, I would be very interested in seeing some updated results for the PETS implementation. I don't think this makes much of a difference for environments like CartPole or Pendulum, but it might for tasks such as HalfCheetah.
*Sorry for not following the issue template. I feel this is a minor issue and the full template does not apply.