Closed: jan1854 closed this 2 years ago.
Hi @JanS97, this is great, good catch! Thanks a lot for the fix. Do you mind adding a unit test along the lines of the first code snippet? (and probably best to use `torch.allclose` instead of `==` for the first check)
Yes, sure, I can add a unit test. Actually, I think we should just check `t_original is t_new`, since both variables should refer to the same object.
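The identity check is stricter than a value comparison. The small sketch below uses illustrative tensor names. A clone holds equal values but is a distinct object. A PyTorch in-place method (trailing underscore), such as `Tensor.normal_()`, returns the very tensor it modified.

```python
import torch

t_original = torch.zeros(5)

# A clone holds equal values but is a distinct tensor object.
t_copy = t_original.clone()
print(torch.allclose(t_original, t_copy))  # True: values match
print(t_copy is t_original)                # False: different objects

# A PyTorch in-place method (trailing underscore) returns the tensor it modified.
t_new = t_original.normal_(mean=0.0, std=1.0)
print(t_new is t_original)                 # True: same object
```

So `is` can only pass when the function really works on its input, while `torch.allclose` would also pass for an unrelated tensor that happens to hold the same values.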
Sounds good. Thanks!
I added the test. It checks both conditions from the first code snippet.
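The exact assertions in the merged test are not shown in this thread. The sketch below is one way such a test could look.

It makes several assumptions:
- The helper name `check_truncated_normal_inplace` is hypothetical.
- The truncation bounds are assumed to be `mean ± 2 * std`, matching the [-2, 2] range for a standard Gaussian.
- PyTorch's built-in in-place initializer `torch.nn.init.trunc_normal_` is used as a stand-in. In mbrl-lib one would pass `mbrl.util.math.truncated_normal_` instead, assuming it accepts `(tensor, mean, std)`.

```python
import torch


def check_truncated_normal_inplace(fn, shape=(1000,), mean=0.0, std=1.0):
    """Hypothetical test helper.

    `fn(tensor, mean, std)` is expected to fill `tensor` in place with samples
    truncated to [mean - 2 * std, mean + 2 * std] and to return that same tensor.
    """
    torch.manual_seed(0)
    t_original = torch.empty(shape)
    t_new = fn(t_original, mean, std)
    # The returned tensor must be the input tensor itself, not a copy.
    assert t_new is t_original
    # Every value written into the input must lie within the truncation bounds.
    assert bool((t_original >= mean - 2 * std).all())
    assert bool((t_original <= mean + 2 * std).all())


def torch_trunc_normal(tensor, mean, std):
    # Stand-in for the function under test: PyTorch's built-in in-place initializer.
    return torch.nn.init.trunc_normal_(
        tensor, mean=mean, std=std, a=mean - 2 * std, b=mean + 2 * std
    )


check_truncated_normal_inplace(torch_trunc_normal)
```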
LGTM. Merging now. Thanks @JanS97!
Types of changes
Motivation and Context / Related issue
The current implementation of `mbrl.util.math.truncated_normal_()` is not (entirely) in-place. This problem can be seen with the following code snippet.

With the current implementation, both outputs are `False`. If the implementation were in-place, both outputs should be `True`. The problem is that `torch.where()`, used in `truncated_normal_()`, is not an in-place operation: it returns a new tensor instead of modifying its input. As a result, `t_original` contains values sampled from a regular Gaussian, while `t_new` contains values sampled from the truncated Gaussian.

How Has This Been Tested (if it applies)
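The failure mode described above can be reproduced with a simplified sketch. This is not the actual mbrl-lib source.

The function name `truncated_normal_rebinding` is hypothetical. The resampling loop and the `mean ± 2 * std` bounds are assumptions chosen to match the description.

The first call, `torch.nn.init.normal_`, does write into the caller's tensor. Each `torch.where` call, however, allocates a new tensor and only rebinds the local name.

```python
import torch


def truncated_normal_rebinding(tensor, mean=0.0, std=1.0):
    """Hypothetical sketch of the bug: out-of-range samples are replaced via torch.where."""
    torch.nn.init.normal_(tensor, mean=mean, std=std)  # in place: fills the caller's tensor
    while True:
        cond = (tensor < mean - 2 * std) | (tensor > mean + 2 * std)
        if not bool(cond.any()):
            break
        # torch.where allocates a NEW tensor and the local name is rebound to it,
        # so the caller's tensor keeps its original, untruncated samples.
        tensor = torch.where(cond, torch.randn_like(tensor) * std + mean, tensor)
    return tensor


torch.manual_seed(0)
t_original = torch.empty(10_000)
t_new = truncated_normal_rebinding(t_original)
print(t_new is t_original)                  # False
print(torch.allclose(t_original, t_new))    # False
print(bool(t_original.abs().max() > 2.0))   # True: untruncated values remain
```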
After the changes, the code snippet above results in the following histogram, which shows that the values are sampled from a Gaussian truncated to [-2, 2].
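Below is one possible in-place variant, not necessarily the change made in this PR. The name `truncated_normal_inplace` and the `mean ± 2 * std` bounds are assumptions.

Out-of-range entries are resampled and written back through boolean-mask assignment, which modifies the caller's storage. The body runs under `torch.no_grad()`, as PyTorch's own `torch.nn.init` functions do, so it also works on leaf tensors that require grad.

```python
import torch


def truncated_normal_inplace(tensor, mean=0.0, std=1.0):
    """Hypothetical in-place version: resampled values are written back into `tensor`."""
    with torch.no_grad():
        torch.nn.init.normal_(tensor, mean=mean, std=std)
        while True:
            cond = (tensor < mean - 2 * std) | (tensor > mean + 2 * std)
            n_out = int(cond.sum())
            if n_out == 0:
                break
            # Boolean-mask assignment writes into the caller's storage in place,
            # unlike `tensor = torch.where(...)`, which only rebinds the local name.
            tensor[cond] = (
                torch.randn(n_out, dtype=tensor.dtype, device=tensor.device) * std + mean
            )
    return tensor


torch.manual_seed(0)
t_original = torch.empty(10_000)
t_new = truncated_normal_inplace(t_original)
print(t_new is t_original)                  # True
print(bool(t_original.abs().max() <= 2.0))  # True
```

Keeping `torch.where` and copying its result back with `tensor.copy_(...)` would be another way to make the function in-place.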
Checklist