TuringLang / DistributionsAD.jl

Automatic differentiation of Distributions using Tracker, Zygote, ForwardDiff and ReverseDiff
MIT License

ForwardDiff.jl + derivative of parameters of a truncated distribution = NaNs everywhere #242

Open torfjelde opened 1 year ago

torfjelde commented 1 year ago

The following is currently the case:

julia> using Distributions, ForwardDiff

julia> h(θ) = logpdf(truncated(Normal(θ, 1), 0, Inf), 1.0)
h (generic function with 1 method)

julia> ForwardDiff.derivative(h, rand())
NaN

julia> g(θ) = logpdf(truncated(Normal(θ, 1), 1e-6, 1000), 1.0)
g (generic function with 1 method)

julia> ForwardDiff.derivative(g, rand())
-0.2321285832954859

IIRC, this has come up before? It comes down to the use of the cdf in the computation of the truncated log-pdf, which causes issues when a bound is infinite.
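For reference, here is a rough sketch of where the cdf enters. The notation ($f$, $F$, $\Phi$, $\varphi$, $a$, $b$, $\sigma$) is introduced here for illustration, and the last remark is a likely explanation of the NaN rather than a trace through the actual Distributions or ForwardDiff code paths.

$$
\log p(x \mid \theta) = \log f(x;\theta) - \log\bigl(F(b;\theta) - F(a;\theta)\bigr),
\qquad F(t;\theta) = \Phi\!\left(\frac{t-\theta}{\sigma}\right)
$$

For the example above ($a = 0$, $b = \infty$, $\sigma = 1$, $x = 1$) the normalizer is $1 - \Phi(-\theta) = \Phi(\theta)$, so the derivative itself is perfectly finite:

$$
\frac{d}{d\theta}\log p(1 \mid \theta) = (1-\theta) - \frac{\varphi(\theta)}{\Phi(\theta)}
$$

Forward-mode AD does not see this simplification. It differentiates the standardization $(b-\theta)/\sigma$ term by term, and with $b = \infty$ the quotient rule contains a product of the form $\infty \cdot 0$ (infinite value times the zero partial of $\sigma$), which is NaN in floating-point arithmetic and then propagates to the result.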

@sethaxen did we talk about this over Slack at some point? Feel like there was a thread about this issue.

devmotion commented 1 year ago

This has come up quite a few times, but fortunately the solution is easy: use NaN-safe mode in ForwardDiff (by default, it returns incorrect results for infinite values with zero partials) or use the keyword argument syntax of truncated (truncated(Normal(...); lower=0)). The latter has the additional advantage that it avoids undesired promotions and, in the future, will let you dispatch on left- and right-truncated distributions (I opened a PR to Distributions for this a few days ago).
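A minimal sketch of the keyword-argument variant, assuming a Distributions version that supports truncated(d; lower=...). The names h_pos, h_kw and dh_exact are made up for this example, and dh_exact is a hand-derived closed form for this one special case, not something either package provides. NaN-safe mode is a ForwardDiff setting and is not shown here.

```julia
using Distributions, ForwardDiff

# Positional form from the report: the upper bound is an explicit Inf.
h_pos(θ) = logpdf(truncated(Normal(θ, 1), 0, Inf), 1.0)

# Keyword form suggested above: only the lower bound is given,
# so no infinite bound enters the computation.
h_kw(θ) = logpdf(truncated(Normal(θ, 1); lower=0), 1.0)

# Hand-derived derivative for this special case (σ = 1, x = 1, lower = 0),
# used only as a sanity check (hypothetical helper, not a library function).
dh_exact(θ) = (1 - θ) - pdf(Normal(), θ) / cdf(Normal(), θ)

θ = 0.5
@show ForwardDiff.derivative(h_pos, θ)  # the problematic call from the report
@show ForwardDiff.derivative(h_kw, θ)   # keyword form
@show dh_exact(θ)                       # closed form for comparison
```

Comparing the last two printed values is a quick way to check that the keyword form gives the derivative you expect before switching a larger model over to it.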