louisabraham opened this issue 1 year ago
The exception is raised by the support check, meaning that 1.0 doesn't fall within the support interval. Since loc and scale aren't given in the snippet, it is hard to say whether this is a precision issue or incorrect usage of the parameters. The interface was designed to follow the conventions of the similar scipy function (scipy.stats.truncnorm).
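To make the scipy comparison concrete: scipy.stats.truncnorm takes its truncation bounds in standardized units, (a - loc) / scale and (b - loc) / scale, rather than as absolute values. The sketch below assumes that the 0 and 1 in the example further down are absolute bounds of the support. The helper names are hypothetical, and the mean uses the standard closed form for a truncated normal, so the check runs without scipy.

```python
import math
from typing import Tuple


def truncnorm_standardized_bounds(loc: float, scale: float, a: float, b: float) -> Tuple[float, float]:
    """Convert absolute bounds [a, b] into the standardized bounds scipy.stats.truncnorm expects."""
    return (a - loc) / scale, (b - loc) / scale


def std_normal_pdf(z: float) -> float:
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def std_normal_cdf(z: float) -> float:
    # erfc keeps relative accuracy deep in the lower tail, where 0.5 * (1 + erf(z / sqrt 2)) cancels to 0.
    return 0.5 * math.erfc(-z / math.sqrt(2.0))


def truncnorm_mean(loc: float, scale: float, a: float, b: float) -> float:
    """Closed-form mean of N(loc, scale**2) truncated to the absolute interval [a, b]."""
    alpha, beta = truncnorm_standardized_bounds(loc, scale, a, b)
    mass = std_normal_cdf(beta) - std_normal_cdf(alpha)
    return loc + scale * (std_normal_pdf(alpha) - std_normal_pdf(beta)) / mass


alpha, beta = truncnorm_standardized_bounds(2.0, 0.2, 0.0, 1.0)
print(alpha, beta)                          # standardized bounds for the example below
print(truncnorm_mean(2.0, 0.2, 0.0, 1.0))   # compare with scipy's mean for the same parameters
```

Under this assumption the closed form gives roughly 0.96270 for loc=2, scale=0.2, [0, 1]. With scipy installed, scipy.stats.truncnorm(alpha, beta, loc=2.0, scale=0.2).mean() should agree; passing 0 and 1 directly as scipy's a and b would describe a different distribution.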
Here is a reproducible example:
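import torch
# TruncatedNormal(loc, scale, a, b): import it from torch_truncnorm (or torchrl's copy), depending on your setup
m = TruncatedNormal(torch.full((1000,), 2.), torch.full((1000,), .2), 0, 1)
m.log_prob(m.sample())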
Some values are greater than 1 (outside the truncation interval [0, 1]) and some are -inf:
tensor([0.9667, 0.9819, 0.9819, 0.9160, 0.9819, 0.9974, 0.9411, 0.9929, 0.9667,
0.9819, 0.9667, 0.9160, 0.9667, -inf, 0.9560, -inf, 0.9560, 0.9160,
0.9160, 0.9751, 0.9160, 0.9411, 1.0015,...
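The same few values recur in the output above (0.9160, 0.9667 and 0.9819 each appear several times), which hints at float32 quantization. The sketch below assumes the sampler evaluates Φ(z) = 0.5 * (1 + erf(z / √2)) and inverts it through erfinv(2p - 1) in float32, as inverse-CDF samplers commonly do; that is an assumption to check against the actual implementation. It emulates float32 by rounding through struct and lists every sample value such a sampler can produce for loc=2, scale=0.2, [a, b] = [0, 1].

```python
import math
import struct
from statistics import NormalDist


def f32(x: float) -> float:
    """Round a Python float to the nearest IEEE-754 float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def phi_f32(z: float) -> float:
    """Phi(z) = 0.5 * (1 + erf(z / sqrt 2)) with every intermediate rounded to float32.

    Assumes a correctly rounded float32 erf; a real kernel may differ by one ulp.
    """
    e = f32(math.erf(f32(z / math.sqrt(2.0))))
    return f32(0.5 * f32(1.0 + e))


loc, scale, a, b = 2.0, 0.2, 0.0, 1.0
alpha, beta = (a - loc) / scale, (b - loc) / scale  # standardized bounds: -10 and -5

p_lo, p_hi = phi_f32(alpha), phi_f32(beta)
ulp = 2.0 ** -24  # gap between consecutive float32 values just above -1.0

# For p in [p_lo, p_hi], the float32 value of 2p - 1 must be -1 + k * ulp for an integer k.
k_max = round(2.0 * (p_hi - p_lo) / ulp)
reachable = []
for k in range(k_max + 1):
    p = k * ulp / 2.0  # chosen so that 2p - 1 == -1 + k * ulp
    # inv_cdf stands in for sqrt(2) * erfinv(2p - 1); erfinv(-1) is -inf.
    z = -math.inf if p == 0.0 else NormalDist().inv_cdf(p)
    reachable.append(loc + scale * z)

print("Phi(alpha) in float32:", p_lo)
print("Phi(beta) in float32:", p_hi, "vs float64:", 0.5 * math.erfc(-beta / math.sqrt(2.0)))
print("reachable samples:", ", ".join(f"{x:.4f}" for x in reachable))
```

In this emulation Φ(-10) cancels to exactly 0 and Φ(-5) is off by about 4%, so 2p - 1 can take only eleven float32 values. The one at exactly -1 maps to -inf and the top one lands slightly above 1 (about 1.0015). The inverse here is computed in float64, so a real float32 erfinv may differ in the last printed digit.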
One thing I'd try first is to plug these values into the unit test here https://github.com/toshas/torch_truncnorm/blob/main/tests/test.py#L97 and see whether it passes the check against scipy. If not, there is a bug.
I added the line self._test_numerical(2.0, 0.2, 0.0, 1.0) to the test.
It gives:
======================================================================
FAIL: test_simple (__main__.Tests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "testa.py", line 103, in test_simple
    self._test_numerical(2.0, 0.2, 0.0, 1.0)
  File "testa.py", line 74, in _test_numerical
    self.assertRelativelyEqual(mean_sc, mean_pt)
  File "testa.py", line 66, in assertRelativelyEqual
    raise self.failureException(msg)
AssertionError: array(0.96269921) != array(1.0022793, dtype=float32) within tol=1e-06 abs=1e-05 (rel=0.03949006605978869 diff=0.039580075041381724)
======================================================================
FAIL: test_support (__main__.Tests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "testa.py", line 131, in test_support
    self.assertEqual(
AssertionError: 'Expected value argument (Tensor of shape [157 chars]10.0' != 'The value argument must be within the support'
+ The value argument must be within the support
- Expected value argument (Tensor of shape ()) to be within the support (Interval(lower_bound=-1.0, upper_bound=2.0)) of the distribution TruncatedNormal(a: -1.0, b: 2.0), but found invalid values:
- -10.0
The second failure is not caused by the line I added; you might want to fix it in a separate issue. The first one IS a bug.
The cause seems to be that extreme values in the icdf function are not clamped, and they should be. See https://github.com/pytorch/rl/blob/main/torchrl/modules/distributions/truncated_normal.py
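Following the suggestion to clamp the extreme values, here is a minimal standard-library sketch of an inverse-CDF sampler that avoids both failure modes. It is not the fix used in torch_truncnorm or torchrl. It makes two changes of its own: it computes Φ through erfc in float64, so Φ(-10) stays positive instead of cancelling to 0, and it clamps both the probability passed to the inverse CDF and the resulting sample. sample_truncnorm is a hypothetical name, and statistics.NormalDist.inv_cdf stands in for erfinv.

```python
import math
import random
import sys
from statistics import NormalDist

_STD_NORMAL = NormalDist()  # mean 0, sigma 1


def std_normal_cdf(z: float) -> float:
    # erfc keeps relative accuracy in the lower tail, unlike 0.5 * (1 + erf(z / sqrt 2)).
    return 0.5 * math.erfc(-z / math.sqrt(2.0))


def sample_truncnorm(loc: float, scale: float, a: float, b: float, rng: random.Random) -> float:
    """Hypothetical helper: one inverse-CDF draw from N(loc, scale**2) truncated to [a, b]."""
    alpha, beta = (a - loc) / scale, (b - loc) / scale
    p_lo, p_hi = std_normal_cdf(alpha), std_normal_cdf(beta)
    p = p_lo + rng.random() * (p_hi - p_lo)
    # Clamp p into the open interval (0, 1): inv_cdf raises on 0 and 1, much as erfinv returns +/-inf there.
    p = min(max(p, sys.float_info.min), 1.0 - sys.float_info.epsilon / 2.0)
    x = loc + scale * _STD_NORMAL.inv_cdf(p)
    # Rounding can still leave x a hair outside [a, b]; clamp it back into the support.
    return min(max(x, a), b)


rng = random.Random(0)
samples = [sample_truncnorm(2.0, 0.2, 0.0, 1.0, rng) for _ in range(20_000)]
print(min(samples), max(samples))
print(sum(samples) / len(samples))
```

For the parameters in this issue every sample stays inside [0, 1] and the sample mean lands close to scipy's 0.96269921. Precision would still run out for a bound deep in the upper tail, where 1 - Φ(β) is tiny; a fuller version would switch to the upper-tail probability there.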
My code looks like:
It looks like action_pt can take the value 1.0 and causes log_prob to raise an error. I don't know if the error is: