JuliaStats / Distributions.jl

A Julia package for probability distributions and associated functions.

`truncated(d, l, Inf)` fails with AD #1910

Open penelopeysm opened 1 day ago

penelopeysm commented 1 day ago

Using `truncated` with ±Inf as the bounds tends to lead to NaNs when using automatic differentiation:

julia> using Distributions; f(s) = logpdf(truncated(Normal(0.0, s[1]), 0, +Inf), 2)
f (generic function with 1 method)

julia> import ForwardDiff; ForwardDiff.gradient(f, [0.1])
1-element Vector{Float64}:
 NaN
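
One plausible way a NaN like this can arise (an assumption about the mechanism, not something confirmed in this thread): differentiating the normal CDF at an infinite bound by the chain rule multiplies a density of exactly zero by an infinite inner derivative, and `0.0 * Inf` is NaN in IEEE arithmetic. A hand-written sketch in plain Julia, with no AD package involved:

```julia
# Chain rule for Φ((x - μ) / s) with respect to s, evaluated at x = Inf.
# All names here are illustrative; this does not call into Distributions.jl.
x, μ, s = Inf, 0.0, 0.1

z     = (x - μ) / s                  # standardised bound, infinite
dz_ds = -(x - μ) / s^2               # inner derivative, also infinite
dΦ_dz = exp(-z^2 / 2) / sqrt(2π)     # standard normal density at z, exactly zero
dΦ_ds = dΦ_dz * dz_ds                # zero times infinity

println(isnan(dΦ_ds))
```

Mathematically the derivative is zero, since the CDF at +Inf is 1 for every `s`, but floating-point arithmetic cannot recover that from the two factors.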

The fact that NaNs are returned isn't so much a problem on its own; the issue is more that it creates an easy but non-obvious trap for users to fall into, e.g. in #1189, but I've also seen it in other Turing.jl contexts.

A cheap fix might be:

diff --git a/src/truncate.jl b/src/truncate.jl
index 48d62b01..bf8d379e 100644
--- a/src/truncate.jl
+++ b/src/truncate.jl
@@ -62,6 +62,8 @@ end
 truncated(d::UnivariateDistribution, ::Nothing, ::Nothing) = d
 function truncated(d::UnivariateDistribution, l::T, u::T) where {T <: Real}
     l <= u || error("the lower bound must be less or equal than the upper bound")
+    l == -Inf && return truncated(d, nothing, u)
+    u == Inf && return truncated(d, l, nothing)

     # (log)lcdf = (log) P(X < l) where X ~ d
     loglcdf = _logcdf_noninclusive(d, l)

I recognise that in principle it isn't really the job of Distributions to fix this, but the patch is so small that it shouldn't be a maintenance burden, so I figured it was worth suggesting 🙂

devmotion commented 1 day ago

This is a known issue, and this fix was suggested in e.g. #1730 (also related: #1467). I think the type instability would be quite an unfortunate consequence of such a change. Maybe an alternative that would avoid this problem would be to throw an exception or show a warning when a non-`nothing` bound is equal to the endpoint of the support.
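
To illustrate the type-instability concern, here is a toy sketch. `ToyDist`, `ToyTruncated`, `toy_truncated`, and `toy_truncated_patched` are hypothetical stand-ins, not Distributions.jl's actual types or methods. The point is only that a value-based early return makes the return type depend on the runtime value of a bound rather than on the argument types:

```julia
# Hypothetical stand-ins; none of these names exist in Distributions.jl.
struct ToyDist end
struct ToyTruncated{L,U}
    lower::L
    upper::U
end

# Dispatch on `nothing`: the return type follows from the argument types alone.
toy_truncated(::ToyDist, l::Real, ::Nothing) = ToyTruncated(l, nothing)
toy_truncated(::ToyDist, l::T, u::T) where {T<:Real} = ToyTruncated(l, u)

# Value-based shortcut, analogous in spirit to the proposed patch:
# the return type now depends on whether `u` happens to equal Inf.
function toy_truncated_patched(d::ToyDist, l::T, u::T) where {T<:Real}
    u == Inf && return toy_truncated(d, l, nothing)
    return ToyTruncated(l, u)
end

argtypes = (ToyDist, Float64, Float64)
stable  = only(Base.return_types(toy_truncated, argtypes))
patched = only(Base.return_types(toy_truncated_patched, argtypes))

println(isconcretetype(stable))   # inferred return type of the dispatch-only version
println(isconcretetype(patched))  # inferred return type of the value-based version
```

With the value-based branch, inference can no longer pin down a single concrete return type for `Float64` bounds, which is the kind of instability referred to above.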

penelopeysm commented 23 hours ago

Hmm, I clearly didn't dig back far enough when searching issues. Agreed on the importance of type stability, especially given the context that this isn't an issue with Distributions itself.

Even a warning feels like a suboptimal solution, as I assume truncated distributions with ±Inf bounds work perfectly fine within Distributions itself, and the warning would be noise to anyone who was just doing that. It's only the usage with other packages where one might like to be warned.

devmotion commented 22 hours ago

It's generally preferable to use `nothing` instead of an endpoint of the support of the untruncated distribution. The former allows us to optimize the calculations with, and possibly even the return type of, `truncated`.
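
As a minimal sketch of that recommendation, the example from the original report can be rewritten so that the missing upper bound is `nothing` rather than `+Inf`. The function names `g` and `h` are illustrative, and the keyword form assumes a Distributions.jl version that supports `lower`/`upper` keywords:

```julia
using Distributions
import ForwardDiff

# Same log-density as in the report, with the absent upper bound given as `nothing`.
g(s) = logpdf(truncated(Normal(0.0, s[1]), 0, nothing), 2)

# Keyword form: only the lower bound is specified.
h(s) = logpdf(truncated(Normal(0.0, s[1]); lower=0), 2)

grad_g = ForwardDiff.gradient(g, [0.1])
grad_h = ForwardDiff.gradient(h, [0.1])

println(all(isfinite, grad_g))
println(grad_g ≈ grad_h)
```

Because no infinite bound ever enters the computation, the gradient should come out finite here instead of NaN.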