toshas / torch_truncnorm

Truncated Normal Distribution in PyTorch
BSD 3-Clause "New" or "Revised" License

Incorrect mean predictions for distributions with a loc much smaller than a #9

Closed TheAeryan closed 1 year ago

TheAeryan commented 1 year ago

I have observed that an incorrect mean value is returned (by calling the mean method) when the distribution has a loc value that is much smaller than its lower bound a. More specifically, it predicts a mean lower than a, even though the mean of a truncated Gaussian always lies within the bound interval [a,b].

Example:

>>> from truncated_gaussian import TruncatedNormal
>>> TruncatedNormal(0,1,-1,1).mean
tensor(0.)
>>> TruncatedNormal(-0.5,1,-1,1).mean
tensor(-0.1437)
>>> TruncatedNormal(-1,1,-1,1).mean
tensor(-0.2772)
>>> TruncatedNormal(-2,1,-1,1).mean
tensor(-0.4900)
>>> TruncatedNormal(-3,1,-1,1).mean
tensor(-0.6294)
>>> TruncatedNormal(-4,1,-1,1).mean
tensor(-0.7173)
>>> TruncatedNormal(-5,1,-1,1).mean
tensor(-0.7797)
>>> TruncatedNormal(-6,1,-1,1).mean
tensor(-1.0114)
>>> TruncatedNormal(-7,1,-1,1).mean
tensor(-6.9490)
>>> TruncatedNormal(-10,1,-1,1).mean
tensor(-10.)
>>> TruncatedNormal(-100000,1,-1,1).mean
tensor(-100000.)

As you can see, when mu (the loc parameter of the distribution) is -6 or below, the predicted mean falls below the bound a=-1, even though this should not happen.
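(For reference, a correct value for the loc=-10 case can be computed with only the Python standard library. This is a sketch, not code from the repository: `ref_truncnorm_mean` is a hypothetical helper that uses the textbook mean formula, but with the normalizer rewritten via upper-tail probabilities Q(x) = erfc(x/sqrt(2))/2 so it stays accurate when both standardized bounds sit far in the right tail, as they do here.)

```python
import math

def ref_truncnorm_mean(mu, sigma, lo, hi):
    """Reference mean of N(mu, sigma^2) truncated to [lo, hi].

    Textbook formula mu + sigma*(phi(a) - phi(b)) / (Phi(b) - Phi(a)),
    with the normalizer computed as Q(a) - Q(b) from upper tails
    Q(x) = erfc(x/sqrt(2))/2. This avoids the 1 - (tiny) rounding that
    loses the tail mass; it is accurate when both standardized bounds
    are in the right tail (a, b >= 0), as in the failing examples.
    """
    a = (lo - mu) / sigma
    b = (hi - mu) / sigma
    phi = lambda x: math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    Q = lambda x: 0.5 * math.erfc(x / math.sqrt(2.0))
    return mu + sigma * (phi(a) - phi(b)) / (Q(a) - Q(b))

print(ref_truncnorm_mean(-10, 1, -1, 1))  # ~ -0.8915, inside [-1, 1]
```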

toshas commented 1 year ago

Thanks for your interest in the package! The API of this code follows SciPy's. A few relevant test cases can be found here: https://github.com/toshas/torch_truncnorm/blob/main/tests/test.py

After importing the compatible scipy wrapper from the test source code, I managed to get the same values as my code:

IN: TruncatedNormalSC(-4,1,-1,1).mean
OUT: -0.7173056200576982

Check out the reference explaining the meaning of the parameters here: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.truncnorm.html

Let me know if there are any remaining concerns!

TheAeryan commented 1 year ago

Hi, thank you very much for your quick response.

I'm not sure I understand your answer. Firstly, in order to use your code I have simply downloaded the TruncatedNormal.py script and imported it from Python. Is this wrong?

Secondly, in your response you predict the mean of TruncatedNormalSC(-4,1,-1,1) as -0.7173, but this is the same value that appears in my example. I don't know whether this value is correct, but at first glance it does not seem wrong, as it lies between a=-1 and b=1.

Could you try calculating the mean of another distribution such as TruncatedNormal(-10,1,-1,1).mean? If the scipy wrapper works the same as your TruncatedNormal.py script, then it should incorrectly predict its mean to be -10 (even though it cannot be lower than -1).

toshas commented 1 year ago

Thanks for the clarification; I initially misunderstood which case fails. I was able to reproduce the issue with TruncatedNormal(-10,1,-1,1).mean and TruncatedNormal(torch.tensor(-10, dtype=torch.float64),1.,-1.,1.).mean. When inspecting the intermediate values in the TruncatedStandardNormal.__init__ function, these are the values:

a = tensor(9., dtype=torch.float64) 
b = tensor(11., dtype=torch.float64) 
_little_phi_a = tensor(1.0280e-18, dtype=torch.float64)
_little_phi_b = tensor(2.1188e-27, dtype=torch.float64)
_Z = tensor(2.2204e-16, dtype=torch.float64) 

I guess the current code does not use the available floating-point precision efficiently when computing the phis and Z. If you have suggestions on how this could be improved, I'd be happy to consider a fix.
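(The loss of Z can be reproduced outside the class. The sketch below uses only the standard math module, mirroring but not using the repository's code: a naive CDF built from erf rounds to exactly 1.0 at both bounds, so the difference is meaningless, while the upper-tail form via erfc keeps the true mass.)

```python
import math

a, b = 9.0, 11.0  # the standardized bounds reported above for loc=-10

Phi = lambda x: 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))  # naive CDF
Q = lambda x: 0.5 * math.erfc(x / math.sqrt(2.0))           # upper tail

# Naive: Phi(9) and Phi(11) both round to 1.0 in float64, so the
# normalizer is destroyed by cancellation (torch's erf path instead
# leaves ~2.2e-16, which is equally meaningless).
z_naive = Phi(b) - Phi(a)

# Stable: Q(9) - Q(11) subtracts two well-separated tiny numbers.
z_stable = Q(a) - Q(b)

print(z_naive)   # ~0.0
print(z_stable)  # ~1.13e-19, the true probability mass in [9, 11]
```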

TheAeryan commented 1 year ago

Hi,

Thank you for your interest. I was discussing this issue today with my collaborator at MIT-IBM, who knows more than I do about numerical stability. He told me about a few techniques for performing these operations in a stable manner, although I still need to look into them (this is the first time I have tried to solve this kind of issue). My plan is to spend the next few days obtaining a numerically stable truncated Gaussian. Then I can submit a pull request, if that's fine with you :)

A few clarifications about the code would be useful, though. 1) I don't completely understand the wrapper thing. The way I am using your code is simply downloading your TruncatedNormal.py script and importing the TruncatedNormal class. Is this not the intended way to use your code? 2) What does the following code do?

>>> little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan)
>>> little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan)

Best Regards, Carlos


guicho271828 commented 1 year ago

Hi, I am a mentor of @TheAeryan. I agree that the issue is numerical precision, especially where small floats are subtracted from one another. Specifically, these all look dangerous:

The numerical issue is noted on Wikipedia too: https://en.wikipedia.org/wiki/Truncated_normal_distribution#Two_sided_truncation

Two directions for a fix are:

TheAeryan commented 1 year ago

I will look into those and see if I can re-implement them in Python.

TheAeryan commented 1 year ago

Hi again,

For the past few days I have been implementing a numerically stable version based on github.com/cossio/TruncatedNormal.jl/blob/master/src/tnmean.jl. The issue is that I have implemented it from scratch, so it is not backwards-compatible with your code (among other things, I have not implemented methods for sampling or for calculating the entropy). For this reason, I have not submitted a pull request but instead created a new GitHub repo: github.com/TheAeryan/stable-truncated-gaussian.
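(A minimal sketch of the kind of rescaling such a stable implementation relies on, under stated assumptions: it uses SciPy's scaled complementary error function erfcx(x) = exp(x^2)*erfc(x) and the symmetry reduction found in tnmean.jl, but it is not the actual code of either repository, and it only computes the mean.)

```python
import math
from scipy.special import erfcx

SQRT2 = math.sqrt(2.0)
SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)

def _std_truncnorm_mean(a, b):
    """Mean of the standard normal truncated to [a, b] (assumes a < b)."""
    if abs(a) > abs(b):
        # Symmetry E[a, b] = -E[-b, -a] reduces everything to |a| <= |b|.
        return -_std_truncnorm_mean(-b, -a)
    delta = 0.5 * (a - b) * (a + b)  # (a^2 - b^2) / 2, <= 0 after the flip
    if a <= 0.0 <= b:
        # erf(b/sqrt2) - erf(a/sqrt2) adds same-sign magnitudes here,
        # so it is well conditioned.
        num = -math.expm1(delta) * math.exp(-0.5 * a * a)
        den = math.erf(b / SQRT2) - math.erf(a / SQRT2)
    else:
        # 0 < a <= b: factor out exp(-a^2/2) via erfcx so neither tail
        # probability underflows; exp(delta) <= 1 cannot overflow.
        num = -math.expm1(delta)
        den = erfcx(a / SQRT2) - math.exp(delta) * erfcx(b / SQRT2)
    return SQRT_2_OVER_PI * num / den

def truncnorm_mean(mu, sigma, lo, hi):
    """Mean of N(mu, sigma^2) truncated to [lo, hi]."""
    a = (lo - mu) / sigma
    b = (hi - mu) / sigma
    return mu + sigma * _std_truncnorm_mean(a, b)

print(truncnorm_mean(-0.5, 1, -1, 1))     # ~ -0.1437, matching the thread
print(truncnorm_mean(-10, 1, -1, 1))      # ~ -0.8915, now inside [-1, 1]
print(truncnorm_mean(-100000, 1, -1, 1))  # ~ -0.99999, still inside [-1, 1]
```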

Additionally, I would like to clarify that your code works just fine as long as the mu parameter is inside the interval [a,b] given by the bounds. Problems arise only when mu is outside this interval (or when a and b are very close to each other).