pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

`TransformedDistribution` support too broad when using `AffineTransform` transformation? #1756

Closed: arik-shurygin closed this issue 4 months ago

arik-shurygin commented 4 months ago

The following code reproduces the issue I am running into.

```python
import numpyro.distributions as dist

transformed_dist = dist.TransformedDistribution(dist.Beta(5, 5), dist.transforms.AffineTransform(loc=1, scale=1))

print("Beta Dist Support")
print(transformed_dist.base_dist.support)

print("AffineTransform codomain")
print(transformed_dist.transforms[-1].codomain)

print("TransformedDist support")
print(transformed_dist.support)
```

When I use this TransformedDistribution, my NUTS sampler initially draws values outside the support I would expect for the shifted Beta distribution. I believe this is because transformed_dist.support is Real() when it should be numpyro.distributions.constraints._Interval(1.0, 2.0).

1) Is this expected behavior? 2) Could this cause the initial NUTS samples to fall outside the expected range?

Love the package, thanks guys!

fehiepsi commented 4 months ago

The affine transform has a domain parameter. You can set it to unit_interval.

arik-shurygin commented 4 months ago

Apologies for the oversight; this solves my issue. Thanks for the quick response.