facebookresearch / mbrl-lib

Library for Model Based RL
MIT License

[Bug-fix] Made truncated_normal_ completely in-place #141

Closed jan1854 closed 2 years ago

jan1854 commented 2 years ago

Motivation and Context / Related issue

The current implementation of mbrl.util.math.truncated_normal_() is not (entirely) in-place. This problem can be seen with the following code snippet.

import torch

import mbrl.util.math

t_original = torch.empty(1000)
t_new = mbrl.util.math.truncated_normal_(t_original)

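# True only if the returned tensor holds the same values as the input tensor.
print((t_original == t_new).all().item())
# True only if the input tensor itself was truncated to [-2, 2].
print(((t_original <= 2.0).all() and (t_original >= -2.0).all()).item())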

With the current implementation, both outputs are False. If the implementation were in-place, both outputs would be True. The problem is that torch.where(), used in truncated_normal_(), returns a new tensor rather than modifying its input. As a result, t_original contains values sampled from a regular Gaussian, while t_new contains values sampled from the truncated Gaussian.
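To illustrate the difference, here is a minimal sketch of an in-place variant. It is not the library's actual code: the function name truncated_normal_sketch_ and its resampling loop are assumptions made for illustration. The key point is that masked assignment writes into the caller's tensor, whereas assigning the result of torch.where() to the local name only rebinds that name.

```python
import torch


def truncated_normal_sketch_(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0
) -> torch.Tensor:
    """Hypothetical illustration, not the mbrl-lib implementation.

    Fills `tensor` in-place with Gaussian samples and resamples every value
    outside [mean - 2 * std, mean + 2 * std] until none remain.
    """
    with torch.no_grad():
        tensor.normal_(mean, std)  # in-place sampling
        while True:
            out_of_bounds = (tensor < mean - 2 * std) | (tensor > mean + 2 * std)
            num_bad = int(out_of_bounds.sum().item())
            if num_bad == 0:
                break
            # Masked assignment modifies the existing storage. Writing
            # `tensor = torch.where(...)` here would instead create a new
            # tensor and leave the caller's tensor untouched.
            tensor[out_of_bounds] = (
                torch.randn(num_bad, dtype=tensor.dtype, device=tensor.device) * std
                + mean
            )
    return tensor
```

With a function like this, the returned tensor is the same object as the argument, so both checks from the snippet above hold.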

How Has This Been Tested (if it applies)

import torch
import matplotlib.pyplot as plt

import mbrl.util.math

t_original = torch.empty(5000000)
t_new = mbrl.util.math.truncated_normal_(t_original)
assert (t_original == t_new).all().item()

plt.hist(t_original.cpu().numpy(), 100)
plt.show()

After the changes, the code snippet above results in the following histogram, which shows that the values are sampled from a Gaussian truncated to [-2, 2].

[Image: truncated_normal_results]

luisenp commented 2 years ago

Hi @JanS97, this is great, good catch! Thanks a lot for the fix. Do you mind adding a unit test along the lines of the first code snippet? (and probably best to use torch.allclose instead of == for the first check)

jan1854 commented 2 years ago

Yes, sure, I can add a unit test. Actually, I think we should just check t_original is t_new, since both variables should refer to the same object.

luisenp commented 2 years ago

Sounds good. Thanks!

jan1854 commented 2 years ago

I added the test. It checks both conditions from the first code snippet.
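
For reference, a check along these lines could look like the sketch below. It is not the test that was added to the repository: the helper name check_truncated_normal_in_place is hypothetical, and it takes the function under test as an argument so that the sketch stays self-contained.

```python
import torch


def check_truncated_normal_in_place(fn, size: int = 1000) -> None:
    """Hypothetical test helper; `fn` is the initializer under test."""
    t_original = torch.empty(size)
    t_new = fn(t_original)
    # An in-place function must return the very same tensor object.
    assert t_new is t_original
    # The input tensor itself must contain only values within [-2, 2].
    assert (t_original >= -2.0).all() and (t_original <= 2.0).all()
```

In the library's test suite, the call would presumably be check_truncated_normal_in_place(mbrl.util.math.truncated_normal_).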

luisenp commented 2 years ago

LGTM. Merging now. Thanks @JanS97 !