toshas / torch_truncnorm

Truncated Normal Distribution in PyTorch
BSD 3-Clause "New" or "Revised" License

Invalid value when calling log_prob after sample #7

Open louisabraham opened 1 year ago

louisabraham commented 1 year ago

My code looks like this:

    m = TruncatedNormal(loc, scale, 0, 1)
    action_pt = m.sample()
    return m.log_prob(action_pt)

It looks like action_pt can take the value 1.0, which causes log_prob to raise an error:

    111     def log_prob(self, value):
    112         if self._validate_args:
--> 113             self._validate_sample(value)
    114         return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value**2) * 0.5
    115 

~/.pyenv/versions/3.8.8/lib/python3.8/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
    291         valid = support.check(value)
    292         if not valid.all():
--> 293             raise ValueError(

I don't know whether the problem is:

  1. that the value 1.0 should never be sampled, or
  2. that 1.0 lies inside the valid interval and should not be rejected as out of support.

toshas commented 1 year ago

The exception is raised by the support check, which means that 1.0 does not fall inside the support interval. Since loc and scale aren't given in the snippet, it is hard to say whether this is a precision issue or incorrect usage of the parameters. The interface was designed to follow the conventions of the similar scipy function.
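One way to tell a precision problem from a parameter mistake is to look at the standardized bounds a = (lo - loc) / scale and b = (hi - loc) / scale, and at the truncated probability mass Z = Phi(b) - Phi(a). The sketch below is a hypothetical diagnostic (truncation_report is not part of this library) that computes them in float64 with Python's statistics.NormalDist. As a rule of thumb, if Z is only a few float32 machine epsilons, float32 evaluations of Phi and its inverse in the tails cannot resolve the interval well.

```python
from statistics import NormalDist

STD_NORMAL = NormalDist()   # standard normal N(0, 1), evaluated in float64
FLOAT32_EPS = 2.0 ** -23    # machine epsilon of IEEE-754 float32


def truncation_report(loc, scale, lo, hi):
    """Standardized bounds a, b and the truncated mass Z = Phi(b) - Phi(a)."""
    a = (lo - loc) / scale
    b = (hi - loc) / scale
    mass = STD_NORMAL.cdf(b) - STD_NORMAL.cdf(a)
    return {"a": a, "b": b, "mass": mass, "mass_in_float32_eps": mass / FLOAT32_EPS}


# Parameters of the reproducible example in this thread
print(truncation_report(2.0, 0.2, 0.0, 1.0))
# An interval around the bulk of the distribution, for contrast
print(truncation_report(0.5, 0.2, 0.0, 1.0))
```

For loc = 2 and scale = 0.2 with bounds [0, 1], a = -10, b = -5 and Z is about 2.9e-7, roughly 2.4 float32 epsilons; with loc = 0.5 the mass is close to 1.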

louisabraham commented 1 year ago

Here is a reproducible example:

    import torch  # TruncatedNormal itself is imported from this repository

    m = TruncatedNormal(torch.full((1000,), 2.), torch.full((1000,), .2), 0, 1)
    m.log_prob(m.sample())

Some values are more than 1 and some are -inf:

tensor([0.9667, 0.9819, 0.9819, 0.9160, 0.9819, 0.9974, 0.9411, 0.9929, 0.9667,
        0.9819, 0.9667, 0.9160, 0.9667,   -inf, 0.9560,   -inf, 0.9560, 0.9160,
        0.9160, 0.9751, 0.9160, 0.9411, 1.0015,...
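The repeated values and the -inf and greater-than-1 entries in this output are consistent with float32 rounding during inverse-CDF sampling. The sketch below assumes, without checking the library source, that samples are drawn as loc + scale * sqrt(2) * erfinv(2p - 1) with p uniform on [Phi(a), Phi(b)], and models only the rounding of 2p - 1 to float32; everything else runs in float64 via statistics.NormalDist. The function name simulate_float32_icdf_samples is hypothetical.

```python
import math
import random
import struct
from statistics import NormalDist

STD_NORMAL = NormalDist()  # standard normal, evaluated in float64


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def simulate_float32_icdf_samples(loc, scale, lo, hi, n, seed=0):
    """Inverse-CDF sampling in float64, except that the erfinv argument
    2p - 1 is rounded to float32 (a hypothetical model of the library)."""
    a, b = (lo - loc) / scale, (hi - loc) / scale
    phi_a, phi_b = STD_NORMAL.cdf(a), STD_NORMAL.cdf(b)
    rng = random.Random(seed)
    samples = []
    for _ in range(n):
        p = phi_a + rng.random() * (phi_b - phi_a)  # uniform on [Phi(a), Phi(b))
        u = to_float32(2.0 * p - 1.0)               # the lossy float32 step
        p_rounded = (u + 1.0) / 2.0                 # probability actually inverted
        if p_rounded <= 0.0:
            z = -math.inf                           # erfinv(-1) = -inf
        elif p_rounded >= 1.0:
            z = math.inf                            # erfinv(1) = +inf
        else:
            z = STD_NORMAL.inv_cdf(p_rounded)
        samples.append(loc + scale * z)
    return samples


xs = simulate_float32_icdf_samples(2.0, 0.2, 0.0, 1.0, n=1000)
finite = [x for x in xs if math.isfinite(x)]
print("distinct finite values:", sorted(set(round(x, 4) for x in finite)))
print("-inf count:", len(xs) - len(finite), "above 1:", sum(x > 1.0 for x in finite))
```

Under this model, with loc = 2, scale = 0.2 and bounds [0, 1], 2p - 1 can only take about ten float32 values inside the interval, so the samples collapse onto a handful of repeated levels; the lowest level maps to erfinv(-1) = -inf and the highest lands just above 1.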
toshas commented 1 year ago

The first thing I'd try is plugging these values into the unit test here https://github.com/toshas/torch_truncnorm/blob/main/tests/test.py#L97 and checking whether it passes the comparison against scipy. If it doesn't, there is a bug.

louisabraham commented 1 year ago

I added the line self._test_numerical(2.0, 0.2, 0.0, 1.0) to the tests.

Running them gives:

    ======================================================================
    FAIL: test_simple (__main__.Tests)
    ----------------------------------------------------------------------
    Traceback (most recent call last):
      File "testa.py", line 103, in test_simple
        self._test_numerical(2.0, 0.2, 0.0, 1.0)
      File "testa.py", line 74, in _test_numerical
        self.assertRelativelyEqual(mean_sc, mean_pt)
      File "testa.py", line 66, in assertRelativelyEqual
        raise self.failureException(msg)
    AssertionError: array(0.96269921) != array(1.0022793, dtype=float32) within tol=1e-06 abs=1e-05 (rel=0.03949006605978869 diff=0.039580075041381724)

    ======================================================================
    FAIL: test_support (__main__.Tests)
    ----------------------------------------------------------------------
    Traceback (most recent call last):
      File "testa.py", line 131, in test_support
        self.assertEqual(
    AssertionError: 'Expected value argument (Tensor of shape [157 chars]10.0' != 'The value argument must be within the support'
    + The value argument must be within the support
    - Expected value argument (Tensor of shape ()) to be within the support (Interval(lower_bound=-1.0, upper_bound=2.0)) of the distribution TruncatedNormal(a: -1.0, b: 2.0), but found invalid values:
    - -10.0

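The second failure is not caused by my added line: test_support expects the message 'The value argument must be within the support', but the installed PyTorch raises a different one. You might want to fix it in a separate issue. The first failure IS a bug.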
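The size of the mean error in test_simple can be reproduced with a simple model. Assuming the library evaluates the standard normal CDF as 0.5 * (1 + erf(x / sqrt(2))) in float32, Phi(-5), which is about 2.87e-7, gets rounded to a multiple of 2^-25, and the truncated mean loc + scale * (phi(a) - phi(b)) / Z inherits that error. The sketch compares this with the erfc form 0.5 * erfc(-x / sqrt(2)), a standard way to avoid the cancellation in the lower tail; the erfc variant is shown only for comparison and is not something proposed in this thread. All function names are hypothetical.

```python
import math
import struct

SQRT2 = math.sqrt(2.0)


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def std_normal_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def cdf_erf_float32(z):
    """0.5 * (1 + erf(z / sqrt2)) with erf and the sum rounded to float32."""
    return 0.5 * to_float32(1.0 + to_float32(math.erf(z / SQRT2)))


def cdf_erfc(z):
    """0.5 * erfc(-z / sqrt2): no cancellation in 1 + erf for z << 0."""
    return 0.5 * math.erfc(-z / SQRT2)


def truncnorm_mean(loc, scale, lo, hi, cdf):
    """Mean of N(loc, scale^2) truncated to [lo, hi], using the given CDF."""
    a, b = (lo - loc) / scale, (hi - loc) / scale
    mass = cdf(b) - cdf(a)
    return loc + scale * (std_normal_pdf(a) - std_normal_pdf(b)) / mass


print(truncnorm_mean(2.0, 0.2, 0.0, 1.0, cdf_erfc))         # float64-accurate CDF
print(truncnorm_mean(2.0, 0.2, 0.0, 1.0, cdf_erf_float32))  # float32-rounded CDF
```

In this model the erfc-based mean agrees with scipy's 0.96269921 from the failing test, while the float32 erf-based one comes out at about 1.0023, close to the 1.0022793 reported for PyTorch.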

Wu-Chenyang commented 1 year ago

The cause seems to be that extreme values in the icdf computation should be clamped; see TorchRL's truncated normal implementation: https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py
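The clamping idea can be sketched in pure Python. This is not TorchRL's code: the function sample_with_clamping, the choice eps = 2^-24, and where the clamps are applied are assumptions for illustration. The sampler uses a simple model of the precision loss (inverse-CDF sampling in float64 with the erfinv argument 2p - 1 rounded to float32) and adds two clamps: the probability passed to the inverse CDF is kept inside [eps, 1 - eps], and the final sample is clipped to the support [lo, hi].

```python
import math
import random
import struct
from statistics import NormalDist

STD_NORMAL = NormalDist()  # standard normal, evaluated in float64


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sample_with_clamping(loc, scale, lo, hi, n, eps=2.0 ** -24, seed=0):
    """Hypothetical inverse-CDF sampler with float32 rounding of 2p - 1,
    an input clamp to [eps, 1 - eps] and an output clamp to [lo, hi]."""
    a, b = (lo - loc) / scale, (hi - loc) / scale
    phi_a, phi_b = STD_NORMAL.cdf(a), STD_NORMAL.cdf(b)
    rng = random.Random(seed)
    samples = []
    for _ in range(n):
        p = phi_a + rng.random() * (phi_b - phi_a)
        p_rounded = (to_float32(2.0 * p - 1.0) + 1.0) / 2.0  # lossy float32 step
        p_clamped = min(max(p_rounded, eps), 1.0 - eps)      # never exactly 0 or 1
        x = loc + scale * STD_NORMAL.inv_cdf(p_clamped)
        samples.append(min(max(x, lo), hi))                  # stay inside [lo, hi]
    return samples


xs = sample_with_clamping(2.0, 0.2, 0.0, 1.0, n=1000)
print(min(xs), max(xs), len(set(xs)))
```

In this model the clamps remove the -inf and out-of-support samples, but the samples still take only a handful of distinct values, because 2p - 1 is rounded to float32 before either clamp is applied.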