toshas / torch_truncnorm

Truncated Normal Distribution in PyTorch
BSD 3-Clause "New" or "Revised" License

torch_truncnorm

Truncated Normal distribution in PyTorch. The module provides:

Why

I needed to differentiate with respect to the parameters of the distribution and found that a truncated normal distribution is not bundled in torch.distributions as of PyTorch 1.6.0.

Known issues

icdf is numerically unstable; as a consequence, so is rsample. This issue is also seen in torch.distributions.normal.Normal, so it is sort of normal (ba-dum-tss).
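
To illustrate where this kind of instability can come from, here is a minimal pure-Python sketch of inverse-CDF sampling for a truncated standard normal. It is not this module's code: `truncnorm_icdf` is a hypothetical helper, and it uses the standard library's `statistics.NormalDist` in place of torch so that it runs without dependencies. When the truncation bounds sit far in a tail, both CDF values round to 1.0 in double precision, the interval's probability mass becomes exactly zero, and the inverse CDF has nothing left to invert. Single precision saturates even closer to the mean.

```python
from statistics import NormalDist, StatisticsError

std_normal = NormalDist(0.0, 1.0)


def truncnorm_icdf(p, a, b):
    """Inverse CDF of a standard normal truncated to [a, b] (illustrative helper)."""
    cdf_a = std_normal.cdf(a)
    cdf_b = std_normal.cdf(b)
    # Map p in (0, 1) onto the CDF mass between the bounds, then invert.
    return std_normal.inv_cdf(cdf_a + p * (cdf_b - cdf_a))


# Bounds around the mode: the median of the truncated distribution is ~0.
x = truncnorm_icdf(0.5, -1.0, 1.0)

# Bounds far in the upper tail: cdf(9.0) and cdf(10.0) both round to 1.0,
# so the mass between them is lost entirely.
tail_mass = std_normal.cdf(10.0) - std_normal.cdf(9.0)

# With zero mass, the argument to inv_cdf is exactly 1.0, which NormalDist rejects.
try:
    truncnorm_icdf(0.5, 9.0, 10.0)
    tail_ok = True
except StatisticsError:
    tail_ok = False
```

Since rsample is described above as inheriting the instability of icdf, the same loss of tail mass would be expected to affect sampling when the bounds are far from the mean.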

Tests

CUDA_VISIBLE_DEVICES=0 python -m tests.test

Links

https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf