Closed: JARACH-209 closed this pull request 3 years ago.
@luisenp How about we just let users pass the torch.nn.XXX activation function object itself? Then we wouldn't have to handle configuration details such as threshold values, lambda, etc.
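```python
import torch
import mbrl.models

activation = torch.nn.Threshold(0.5, 10)  # configured instance: threshold=0.5, value=10
mlp = mbrl.models.GaussianMLP(..., activation_cls=activation)  # other arguments elided
```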
This would also let users pass a custom activation. The default would be ReLU.
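A minimal sketch of the proposal, assuming a plain PyTorch MLP rather than mbrl-lib's actual `GaussianMLP`: the hypothetical helper `build_mlp` takes an already configured activation module (defaulting to ReLU) and inserts it after each hidden layer. Copying the module with `copy.deepcopy` is a design choice of this sketch, not something from the PR; it keeps layers independent when the activation has learnable parameters (e.g. `torch.nn.PReLU`).

```python
import copy
from typing import Optional

import torch
from torch import nn


def build_mlp(
    in_size: int,
    out_size: int,
    hid_size: int = 200,
    num_layers: int = 4,
    activation: Optional[nn.Module] = None,
) -> nn.Sequential:
    """Hypothetical MLP builder that accepts a configured activation module.

    `num_layers` is the number of hidden layers; ReLU is used if `activation` is None.
    """
    if activation is None:
        activation = nn.ReLU()
    layers = []
    size = in_size
    for _ in range(num_layers):
        layers.append(nn.Linear(size, hid_size))
        # Deep-copy so each hidden layer owns its activation (matters for PReLU etc.).
        layers.append(copy.deepcopy(activation))
        size = hid_size
    layers.append(nn.Linear(size, out_size))
    return nn.Sequential(*layers)


# Usage: a thresholded activation configured once, outside the model.
mlp = build_mlp(3, 2, hid_size=8, num_layers=2, activation=nn.Threshold(0.5, 10))
out = mlp(torch.randn(5, 3))
print(out.shape)  # torch.Size([5, 2])
```

Because the caller constructs the activation, any constructor arguments (threshold, value, negative slope, ...) are handled by PyTorch itself, and a custom `nn.Module` works the same way.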
@JARACH-209 I left a few more comments, but I like your suggestions. Once we finalize the changes, we should also update the relevant configurations in conf/examples and the example notebooks to use the new API. Finally, we should add a unit test in this file to check that the model is created correctly.
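A sketch of what such a test could check, written against a small hypothetical builder (`make_mlp`, not mbrl-lib's API) so the example runs on its own: the default activation is ReLU, a configured `torch.nn.Threshold` keeps its arguments, each hidden layer gets its own activation module, and the forward pass has the expected shape. A real test in mbrl-lib would construct `GaussianMLP` instead.

```python
import copy
import unittest

import torch
from torch import nn


def make_mlp(in_size, out_size, hid_size=4, num_layers=2, activation=None):
    # Hypothetical stand-in for the model under test.
    activation = nn.ReLU() if activation is None else activation
    layers, size = [], in_size
    for _ in range(num_layers):
        layers += [nn.Linear(size, hid_size), copy.deepcopy(activation)]
        size = hid_size
    layers.append(nn.Linear(size, out_size))
    return nn.Sequential(*layers)


class TestActivationConfig(unittest.TestCase):
    def test_default_is_relu(self):
        model = make_mlp(3, 2)
        acts = [m for m in model if not isinstance(m, nn.Linear)]
        self.assertEqual(len(acts), 2)
        self.assertTrue(all(isinstance(a, nn.ReLU) for a in acts))

    def test_configured_activation_is_used(self):
        model = make_mlp(3, 2, activation=nn.Threshold(0.5, 10))
        acts = [m for m in model if isinstance(m, nn.Threshold)]
        self.assertEqual(len(acts), 2)
        self.assertIsNot(acts[0], acts[1])  # one module per hidden layer
        self.assertEqual((acts[0].threshold, acts[0].value), (0.5, 10))

    def test_forward_shape(self):
        model = make_mlp(3, 2, activation=nn.Threshold(0.5, 10))
        self.assertEqual(model(torch.zeros(7, 3)).shape, (7, 2))
```

Saved as a `test_*.py` file, this can be collected by `python -m pytest`, which also runs `unittest.TestCase` classes.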
@luisenp Passing the class seemed a bit tricky to me, but I hope the recent commit suffices. I replaced activation_cls() with configured activation function instances, because the former only works for activation functions whose constructors take no arguments.
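To illustrate the limitation described above in plain PyTorch: calling a stored class with no arguments works for `torch.nn.ReLU` but raises a `TypeError` for `torch.nn.Threshold`, whose constructor requires `threshold` and `value`, whereas a configured instance carries its arguments with it. The helper `instantiate_from_class` is hypothetical, and `functools.partial` is shown only as one possible alternative (an assumption, not something proposed in this PR) that keeps the `activation_cls()` call style while binding arguments.

```python
import functools

import torch
from torch import nn


def instantiate_from_class(activation_cls):
    # Hypothetical helper mimicking the activation_cls() pattern: called with no arguments.
    return activation_cls()


# Works for argument-free activations.
relu = instantiate_from_class(nn.ReLU)

# Fails for activations whose constructor needs arguments.
try:
    instantiate_from_class(nn.Threshold)
except TypeError as err:
    print("TypeError:", err)

# A configured instance already carries its arguments: y = x if x > 0.5 else 10.
threshold = nn.Threshold(0.5, 10)
print(threshold(torch.tensor([0.2, 1.0])))

# Alternative: bind the arguments with functools.partial and still call it with ().
threshold_factory = functools.partial(nn.Threshold, 0.5, 10)
print(instantiate_from_class(threshold_factory))
```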
@JARACH-209 did you delete this to make another PR from your master branch?
@luisenp Yes, it became quite confusing. I will open another PR that includes a unit test.
Types of changes
Motivation and Context / Related issue
Added activation functions
https://github.com/facebookresearch/mbrl-lib/issues/41#issue-812559086
How Has This Been Tested (if it applies)
Not yet.
Checklist