facebookresearch / mbrl-lib

Library for Model Based RL
MIT License

Add configurable activation function to MLP models #41 #124

Closed JARACH-209 closed 3 years ago

JARACH-209 commented 3 years ago

Types of changes

Motivation and Context / Related issue

Added activation functions

https://github.com/facebookresearch/mbrl-lib/issues/41#issue-812559086

How Has This Been Tested (if it applies)

Not yet!

Checklist

JARACH-209 commented 3 years ago

@luisenp How about we just let users pass the torch.nn.XXX activation function object itself? That way we wouldn't have to worry about handling configuration details such as threshold values, lambda, etc.

activation = torch.nn.Threshold(0.5, 10)
mlp = mbrl.models.GaussianMLP(..., activation_cls=activation)

This would also allow users to pass a custom activation. It defaults to ReLU.
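A minimal sketch of the calling pattern being proposed, using plain-Python stand-ins for the torch activations (the `make_mlp` name and the `GaussianMLP` signature are illustrative here, not the mbrl-lib API):

```python
# Plain-Python stand-ins for torch.nn activations; illustrative only.
class ReLU:
    def __call__(self, x):
        return max(x, 0.0)

class Threshold:
    def __init__(self, threshold, value):
        self.threshold, self.value = threshold, value

    def __call__(self, x):
        # Mirrors torch.nn.Threshold semantics: pass x through above the
        # threshold, otherwise emit the replacement value.
        return x if x > self.threshold else self.value

def make_mlp(activation=None):
    # Use the configured instance if given, else default to ReLU.
    return activation if activation is not None else ReLU()

act = make_mlp(Threshold(0.5, 10))
assert act(0.4) == 10      # below the threshold -> replacement value
assert act(0.6) == 0.6     # above the threshold -> passed through
assert make_mlp()(-1.0) == 0.0   # default ReLU clamps negatives
```

The key design point is that the caller configures the activation (threshold, value, lambda, ...) before handing it over, so the model code never needs to know about those parameters.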

luisenp commented 3 years ago

Left a few more comments @JARACH-209, but I like your suggestions. Once we finalize the changes, we should also update the relevant configurations in conf/examples, and also the example notebooks, to use the new API. And finally, we should add a unit test in this file to check that the model is created correctly.
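A hypothetical shape for the unit test being requested, checking that the model keeps the activation it was given and falls back to a default otherwise (`build_mlp`, `Model`, and `Identity` are stand-in names, not the mbrl-lib API or its actual test file):

```python
# Stand-in model factory: records the activation it is constructed with.
class Identity:
    def __call__(self, x):
        return x

class Model:
    def __init__(self, activation):
        self.activation = activation

def build_mlp(activation=None):
    # Default to Identity when no activation is configured.
    return Model(activation if activation is not None else Identity())

def test_default_activation():
    assert isinstance(build_mlp().activation, Identity)

def test_custom_activation_is_kept():
    act = Identity()
    # The exact configured instance should be stored, not a fresh copy.
    assert build_mlp(act).activation is act

test_default_activation()
test_custom_activation_is_kept()
```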

JARACH-209 commented 3 years ago

@luisenp passing the class seemed a bit tricky to me. Please see the recent commit, which I hope suffices. I have replaced activation_cls() with pre-configured activation function instances, because the former approach only works for activation functions that take no constructor arguments.
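A sketch of the distinction being described, with plain-Python stand-ins (the `build_from_class` / `build_from_instance` names are illustrative, not the mbrl-lib code): calling `activation_cls()` fails for activations whose constructor needs arguments, whereas a configured instance carries its parameters with it and can be deep-copied per layer.

```python
import copy

class ReLU:
    def __call__(self, x):
        return max(x, 0.0)

class Threshold:
    def __init__(self, threshold, value):  # requires constructor arguments
        self.threshold, self.value = threshold, value

    def __call__(self, x):
        return x if x > self.threshold else self.value

def build_from_class(activation_cls):
    # Only works when the class can be constructed with no arguments.
    return activation_cls()

def build_from_instance(activation, num_layers=2):
    # Deep-copy the configured instance so each layer owns its own
    # activation rather than sharing one object.
    return [copy.deepcopy(activation) for _ in range(num_layers)]

build_from_class(ReLU)            # fine: ReLU() takes no arguments
try:
    build_from_class(Threshold)   # TypeError: __init__ needs arguments
except TypeError:
    pass

layers = build_from_instance(Threshold(0.5, 10))
assert layers[0] is not layers[1]        # independent per-layer copies
assert layers[0].threshold == 0.5        # configuration is preserved
```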

luisenp commented 3 years ago

@JARACH-209 did you delete this to make another PR from your master branch?

JARACH-209 commented 3 years ago

@luisenp yes, it became quite confusing. I will open another PR with a unit test added.