toshas / torch_truncnorm

Truncated Normal Distribution in PyTorch
BSD 3-Clause "New" or "Revised" License

Plot of pdf seems to suggest that the truncation is faulty #1

Closed: pierresegonne closed this issue 3 years ago

pierresegonne commented 3 years ago

Hey! Thanks for sharing this!

To briefly test the truncated normal distribution, I wanted to verify the pdf, so I wrote the following snippet:

import matplotlib.pyplot as plt
import torch

from sggm.vae_model_helper import TruncatedNormal

# Truncation bounds and parameters of the underlying normal
a = 0
b = 1
mu = 0.8
std = 0.5

p = TruncatedNormal(loc=mu, scale=std, a=a, b=b, validate_args=True)

# Evaluate the pdf on a grid that extends beyond the support [a, b]
x = torch.linspace(-0.5, 1.5, 100)

fig, ax = plt.subplots()
ax.plot(x, p.log_prob(x).exp().flatten())

plt.show()

This results in the plot below. Did I completely miss the point, or is the truncation not happening?

Thanks :)

(Figure_1: plot of the evaluated pdf)

toshas commented 3 years ago

Thanks for your report! The reason for this behavior is that the support check doesn't trigger here: https://github.com/toshas/torch_truncnorm/blob/main/TruncatedNormal.py#L90-L93

The values within the support should be correct. You can add a check at that point and either return -inf or raise an exception.
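For illustration, here is a minimal sketch of the -inf variant; the log_prob_with_support_check wrapper is hypothetical and not part of this repository:

import math
import torch

def log_prob_with_support_check(dist, value, a, b):
    # Hypothetical wrapper: keep the original log-probabilities inside [a, b]
    # and assign -inf (zero density) to values outside the support.
    inside = (value >= a) & (value <= b)
    log_p = dist.log_prob(value)
    return torch.where(inside, log_p, torch.full_like(log_p, -math.inf))

Wrapping the snippet above this way would make the plotted density drop to zero outside [0, 1], which is what a truncated distribution should show.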

pierresegonne commented 3 years ago

Ah right! Thanks for the answer!

Do you want me to propose a fix for this? :)

toshas commented 3 years ago

@pierresegonne thank you, a pull request would be great! I think it'd make sense to copy the behavior of other distributions with limited support, such as Beta.
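For context, this is roughly how a built-in distribution with limited support behaves when validate_args is enabled: evaluating log_prob outside the support raises an error. A small illustrative snippet, not tied to this repository:

import torch
from torch.distributions import Beta

beta = Beta(2.0, 2.0, validate_args=True)

# Inside the unit-interval support, log_prob evaluates normally.
print(beta.log_prob(torch.tensor(0.3)))

# Outside the support, validation rejects the value with a ValueError.
try:
    beta.log_prob(torch.tensor(1.5))
except ValueError as err:
    print(err)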

toshas commented 3 years ago

Committed https://github.com/toshas/torch_truncnorm/commit/7ad2e22e2a26a6cb53234befa06754367b3ad0e0, with refinements in the subsequent commits. Thanks for the initial pull request; I will close it now along with this ticket.